use super::{BuildConfig, Index};
use crate::{ApiConfig, QueryConfig};
use crate::error::Error;
use lazy_static::lazy_static;
use ragit_fs::{WriteMode, read_string, write_bytes, write_string};
use ragit_pdl::JsonType;
use serde_json::Value;
use std::collections::HashMap;
/// Identifies which of the three config files (build, query or api) a config key is stored in.
///
/// `set_config_by_key_worker` maps each variant to the matching `get_*_config_path` call.
// NOTE(review): `Api` is matched on but never constructed in this file, which is
// presumably why `dead_code` is allowed here -- confirm before removing the attribute.
#[allow(dead_code)]
enum ConfigType {
/// The key lives in the file at `get_build_config_path()`.
Build,
/// The key lives in the file at `get_query_config_path()`.
Query,
/// The key lives in the file at `get_api_config_path()`.
Api,
}
lazy_static! {
    /// Keys that may be missing from the config files on disk.
    ///
    /// Each key maps to the default value that is reported when no config file
    /// contains it, and to the config file the key is written to when it is set
    /// for the first time.
    static ref NEWLY_ADDED_CONFIGS: HashMap<String, (Value, ConfigType)> = HashMap::from([
        (String::from("super_rerank"), (Value::Bool(false), ConfigType::Query)),
        (String::from("enable_rag"), (Value::Bool(true), ConfigType::Query)),
        (String::from("summary_after_build"), (Value::Bool(false), ConfigType::Build)),
    ]);

    /// Keys that are rejected with `Error::DeprecatedConfig`.
    ///
    /// Each key maps to the hint that is attached to the error.
    static ref DEPRECATED_CONFIGS: HashMap<String, String> = HashMap::from([
        (
            String::from("max_titles"),
            String::from("Please use `max_summaries` and `max_retrieval`."),
        ),
        (
            String::from("api_key"),
            String::from("Please set an environment variable or edit `models.json`."),
        ),
    ]);
}
impl Index {
pub fn get_config_by_key(&self, key: String) -> Result<Value, Error> {
if let Some(message) = DEPRECATED_CONFIGS.get(&key) {
return Err(Error::DeprecatedConfig {
key,
message: message.to_string(),
});
}
for path in [
self.get_build_config_path()?,
self.get_api_config_path()?,
self.get_query_config_path()?,
] {
let j = read_string(&path)?;
let j = serde_json::from_str::<Value>(&j)?;
match j {
Value::Object(obj) => match obj.get(&key) {
Some(v) => { return Ok(v.clone()) },
_ => {},
},
_ => {
return Err(Error::JsonTypeError {
expected: JsonType::Object,
got: (&j).into(),
});
},
}
}
if let Some((value, _)) = NEWLY_ADDED_CONFIGS.get(&key) {
Ok(value.clone())
}
else {
Err(Error::InvalidConfigKey(key))
}
}
/// Collects every non-deprecated config key together with its current value.
///
/// The build, api and query config files are merged in that order, and keys
/// from `NEWLY_ADDED_CONFIGS` that no file contains are added with their
/// default value. The result comes out of a `HashMap`, so its order is
/// unspecified.
///
/// # Errors
///
/// - `Error::JsonTypeError` if a config file does not hold a JSON object.
/// - Any error from resolving, reading or parsing the config files.
pub fn get_all_configs(&self) -> Result<Vec<(String, Value)>, Error> {
    let mut merged = HashMap::new();

    let config_paths = [
        self.get_build_config_path()?,
        self.get_api_config_path()?,
        self.get_query_config_path()?,
    ];

    for path in config_paths {
        let raw = read_string(&path)?;
        let parsed = serde_json::from_str::<Value>(&raw)?;

        let Value::Object(entries) = &parsed else {
            return Err(Error::JsonTypeError {
                expected: JsonType::Object,
                got: (&parsed).into(),
            });
        };

        // NOTE(review): a key present in more than one file keeps the value of the
        // LAST file read here, whereas `get_config_by_key` returns the FIRST match.
        merged.extend(
            entries
                .iter()
                .filter(|(name, _)| !DEPRECATED_CONFIGS.contains_key(*name))
                .map(|(name, setting)| (name.to_string(), setting.clone())),
        );
    }

    // Defaults only fill gaps; they never override a value read from disk.
    for (name, (fallback, _)) in NEWLY_ADDED_CONFIGS.iter() {
        merged
            .entry(name.to_string())
            .or_insert_with(|| fallback.clone());
    }

    Ok(merged.into_iter().collect())
}
pub fn set_config_by_key(&mut self, key: String, value: String) -> Result<Option<String>, Error> {
if let Some(message) = DEPRECATED_CONFIGS.get(&key) {
return Err(Error::DeprecatedConfig {
key,
message: message.to_string(),
});
}
let mut tmp_json_cache = HashMap::new();
for json_file in [
self.get_build_config_path()?,
self.get_api_config_path()?,
self.get_query_config_path()?,
] {
tmp_json_cache.insert(
json_file.clone(),
read_string(&json_file)?,
);
}
match self.set_config_by_key_worker(key, value) {
Ok(previous_value) => Ok(previous_value),
Err(e) => {
for (path, content) in tmp_json_cache.iter() {
write_string(
path,
content,
WriteMode::CreateOrTruncate,
)?;
}
Err(e)
},
}
}
fn set_config_by_key_worker(&mut self, key: String, value: String) -> Result<Option<String>, Error> { let mut updated = false;
let mut previous_value = None;
for path in [
self.get_build_config_path()?,
self.get_api_config_path()?,
self.get_query_config_path()?,
] {
let j = read_string(&path)?;
let mut j = serde_json::from_str::<Value>(&j)?;
match &mut j {
Value::Object(obj) => match obj.get(&key) {
Some(original_value) => {
let original_type = JsonType::from(original_value);
let new_value = original_type.parse(&value)?;
previous_value = obj.get(&key).map(|value| value.to_string());
obj.insert(
key.clone(),
new_value,
);
updated = true;
write_bytes(
&path,
&serde_json::to_vec_pretty(&j)?,
WriteMode::CreateOrTruncate,
)?;
break;
},
None => {
continue;
},
},
_ => {
return Err(Error::JsonTypeError {
expected: JsonType::Object,
got: (&j).into(),
});
},
}
}
if !updated {
if let Some((default_value, config)) = NEWLY_ADDED_CONFIGS.get(&key) {
let config_path = match config {
ConfigType::Build => self.get_build_config_path()?,
ConfigType::Query => self.get_query_config_path()?,
ConfigType::Api => self.get_api_config_path()?,
};
let j = read_string(&config_path)?;
let mut j = serde_json::from_str::<Value>(&j)?;
let Value::Object(obj) = &mut j else { unreachable!() };
obj.insert(key.clone(), JsonType::from(default_value).parse(&value)?);
write_bytes(
&config_path,
&serde_json::to_vec_pretty(&j)?,
WriteMode::CreateOrTruncate,
)?;
}
else {
return Err(Error::InvalidConfigKey(key));
}
}
self.build_config = serde_json::from_str::<BuildConfig>(
&read_string(&self.get_build_config_path()?)?,
)?;
self.query_config = serde_json::from_str::<QueryConfig>(
&read_string(&self.get_query_config_path()?)?,
)?;
self.api_config = serde_json::from_str::<ApiConfig>(
&read_string(&self.get_api_config_path()?)?,
)?;
Ok(previous_value)
}
}