use crate::llms::gemini;
use anyhow::Result;
use std::{env, path::PathBuf};
use toml;
/// Reads the TOML file at `config_file` and returns owned copies of its
/// top-level `app` and `llm` tables, in that order.
///
/// # Errors
///
/// Returns an error when the file cannot be read, when its contents are not
/// valid TOML, or when either the `app` or the `llm` entry is missing or is
/// not a table. The `app` entry is validated before the `llm` entry.
pub fn parse_gitory_config_file(
    config_file: &PathBuf,
) -> Result<(
    toml::map::Map<String, toml::Value>,
    toml::map::Map<String, toml::Value>,
)> {
    let raw_toml = std::fs::read_to_string(config_file)?;
    let document: toml::Table = toml::from_str(&raw_toml)?;

    // Presence is checked first, then the table shape, so each failure mode
    // keeps its own message.
    let app_section = document
        .get("app")
        .ok_or_else(|| anyhow::anyhow!("`app` section in the `gitory_config.toml` is missing"))?
        .as_table()
        .ok_or_else(|| {
            anyhow::anyhow!("Unable to parse `app` section from `gitory_config.toml`")
        })?;

    let llm_section = document
        .get("llm")
        .ok_or_else(|| anyhow::anyhow!("`llm` section in the `gitory_config.toml` is missing"))?
        .as_table()
        .ok_or_else(|| {
            anyhow::anyhow!("Unable to parse `llm` section from `gitory_config.toml`")
        })?;

    // The parsed document is dropped on return, so hand back owned tables.
    Ok((app_section.to_owned(), llm_section.to_owned()))
}
/// Builds the Gemini API settings and generation settings from TOML values.
///
/// `api_config` must be a table with the string entries `endpoint`, `key` and
/// `model`. Each entry holds the *name* of an environment variable; the value
/// of that variable is what is stored in the returned
/// [`gemini::models::GeminiAPIConfig`].
///
/// `generation_config` is optional. When it is `None`, or when the value is
/// not a table, the default [`gemini::models::GenerationConfig`] is returned.
/// Otherwise the optional keys `max_output_tokens`, `temperature`, `top_p`,
/// `top_k`, `stop_sequences` and `candidate_count` are read; keys that are
/// absent are left as `None`.
///
/// # Errors
///
/// Returns an error (instead of panicking) when:
/// - `api_config` is not a table;
/// - `endpoint`, `key` or `model` is absent or is not a string;
/// - a referenced environment variable cannot be read;
/// - a generation key is present but has the wrong TOML type;
/// - an integer generation key does not fit the target integer type;
/// - `stop_sequences` contains an element that is not a string.
pub fn load_llm_configs_for_gemini(
    api_config: &toml::Value,
    generation_config: &Option<&toml::Value>,
) -> Result<(
    gemini::models::GeminiAPIConfig,
    gemini::models::GenerationConfig,
)> {
    let api_table = api_config.as_table().ok_or_else(|| {
        anyhow::anyhow!("The `llm.api_config` section in the `gitory_config.toml` is missing")
    })?;

    // Fields are evaluated top to bottom, so the first failing key is
    // reported in the order endpoint -> key -> model.
    let api_config = gemini::models::GeminiAPIConfig {
        endpoint: gemini_env_setting(api_table, "endpoint")?,
        key: gemini_env_setting(api_table, "key")?,
        model_name: gemini_env_setting(api_table, "model")?,
    };

    // A missing value and a non-table value both fall back to the defaults.
    let generation_table = (*generation_config).and_then(|value| value.as_table());

    let generation_config = match generation_table {
        Some(table) => {
            let max_output_tokens = gemini_optional_integer::<i32>(table, "max_output_tokens")?;
            let temperature = gemini_optional_float(table, "temperature")?;
            let top_p = gemini_optional_float(table, "top_p")?;
            let top_k = gemini_optional_integer::<i32>(table, "top_k")?;

            let stop_sequences = match table.get("stop_sequences") {
                None => None,
                Some(value) => {
                    let entries = value
                        .as_array()
                        .ok_or_else(|| anyhow::anyhow!("Unable to parse `stop_sequences` key"))?;
                    // Collecting into `Result` stops at the first non-string entry.
                    let sequences = entries
                        .iter()
                        .map(|entry| {
                            entry.as_str().map(str::to_owned).ok_or_else(|| {
                                anyhow::anyhow!(
                                    "Unable to parse `stop_sequences` key: every entry must be a string"
                                )
                            })
                        })
                        .collect::<Result<Vec<String>>>()?;
                    Some(sequences)
                }
            };

            let candidate_count = gemini_optional_integer::<u32>(table, "candidate_count")?;

            gemini::models::GenerationConfig {
                max_output_tokens,
                temperature,
                top_p,
                top_k,
                stop_sequences,
                candidate_count,
            }
        }
        None => gemini::models::GenerationConfig::default(),
    };

    Ok((api_config, generation_config))
}

/// Looks up `key` in `table` as a string naming an environment variable and
/// returns the value of that variable.
fn gemini_env_setting(table: &toml::Table, key: &str) -> Result<String> {
    // `get` is used rather than indexing because indexing a TOML map panics
    // when the key is absent.
    let variable_name = table
        .get(key)
        .and_then(|value| value.as_str())
        .ok_or_else(|| anyhow::anyhow!("`{}` is not specified in the config file", key))?;

    env::var(variable_name)
        .map_err(|err| anyhow::anyhow!("{} not set in .env file: {}", variable_name, err))
}

/// Reads the optional integer `key` from `table`, converting it to `T` with a
/// range check so out-of-range values are rejected rather than truncated.
fn gemini_optional_integer<T>(table: &toml::Table, key: &str) -> Result<Option<T>>
where
    T: std::convert::TryFrom<i64>,
{
    match table.get(key) {
        None => Ok(None),
        Some(value) => {
            let raw = value
                .as_integer()
                .ok_or_else(|| anyhow::anyhow!("Unable to parse `{}` key", key))?;
            T::try_from(raw).map(Some).map_err(|_| {
                anyhow::anyhow!("Unable to parse `{}` key: value {} is out of range", key, raw)
            })
        }
    }
}

/// Reads the optional float `key` from `table`, narrowing it to `f32`.
fn gemini_optional_float(table: &toml::Table, key: &str) -> Result<Option<f32>> {
    match table.get(key) {
        None => Ok(None),
        Some(value) => value
            .as_float()
            // Narrowing f64 -> f32 only loses precision; it cannot wrap.
            .map(|number| Some(number as f32))
            .ok_or_else(|| anyhow::anyhow!("Unable to parse `{}` key", key)),
    }
}