// gitory-cli 0.1.0
//
// Build a story for your project based on your git history
use crate::llms::gemini;
use anyhow::Result;
use std::{env, path::PathBuf};
use toml;

/// Reads the TOML file at `config_file` and returns clones of its `app` and
/// `llm` tables, in that order.
///
/// # Errors
///
/// Returns an error when the file cannot be read, when its contents are not
/// valid TOML, or when either the `app` or the `llm` entry is absent or is
/// not a table.
pub fn parse_gitory_config_file(
    config_file: &PathBuf,
) -> Result<(
    toml::map::Map<String, toml::Value>,
    toml::map::Map<String, toml::Value>,
)> {
    let raw_toml = std::fs::read_to_string(config_file)?;
    let document: toml::Table = toml::from_str(&raw_toml)?;

    // `app` is validated before `llm`, so a file missing both reports `app`.
    let app_section = document
        .get("app")
        .ok_or_else(|| anyhow::anyhow!("`app` section in the `gitory_config.toml` is missing"))?
        .as_table()
        .ok_or_else(|| {
            anyhow::anyhow!("Unable to parse `app` section from `gitory_config.toml`")
        })?;

    let llm_section = document
        .get("llm")
        .ok_or_else(|| anyhow::anyhow!("`llm` section in the `gitory_config.toml` is missing"))?
        .as_table()
        .ok_or_else(|| {
            anyhow::anyhow!("Unable to parse `llm` section from `gitory_config.toml`")
        })?;

    // Both sections borrow from `document`, which is dropped on return, so
    // owned copies are handed back to the caller.
    Ok((app_section.clone(), llm_section.clone()))
}

/// Looks up `field` in `table`, treats its string value as the *name* of an
/// environment variable, and returns the value of that variable.
fn env_value_for_config_key(
    table: &toml::map::Map<String, toml::Value>,
    field: &str,
) -> Result<String> {
    // `.get` (rather than indexing) so a missing key is an error, not a panic.
    let var_name = table
        .get(field)
        .and_then(|value| value.as_str())
        .ok_or_else(|| anyhow::anyhow!("`{}` is not specified in the config file", field))?;

    env::var(var_name)
        .map_err(|err| anyhow::anyhow!("{} not set in .env file: {}", var_name, err))
}

/// Reads the optional integer `field` from `table`, converting it to `T` with
/// a range check instead of a truncating cast.
fn optional_integer_setting<T>(
    table: &toml::map::Map<String, toml::Value>,
    field: &str,
) -> Result<Option<T>>
where
    T: std::convert::TryFrom<i64>,
{
    let value = match table.get(field) {
        Some(value) => value,
        None => return Ok(None),
    };

    let raw = value
        .as_integer()
        .ok_or_else(|| anyhow::anyhow!("Unable to parse `{}` key", field))?;

    T::try_from(raw)
        .map(Some)
        .map_err(|_| anyhow::anyhow!("`{}` value {} is out of range", field, raw))
}

/// Reads the optional float `field` from `table`, narrowed to `f32`.
fn optional_float_setting(
    table: &toml::map::Map<String, toml::Value>,
    field: &str,
) -> Result<Option<f32>> {
    let value = match table.get(field) {
        Some(value) => value,
        None => return Ok(None),
    };

    value
        .as_float()
        // TOML floats are f64; the narrowing to f32 is intentional.
        .map(|number| Some(number as f32))
        .ok_or_else(|| anyhow::anyhow!("Unable to parse `{}` key", field))
}

/// Builds the Gemini API and generation configs from the TOML config values.
///
/// `api_config` must be a table whose `endpoint`, `key` and `model` entries
/// are strings naming environment variables; the values of those variables
/// populate the returned `GeminiAPIConfig`.
///
/// `generation_config` is optional. When it is `None`, or is not a table, the
/// default `GenerationConfig` is returned. Otherwise the keys
/// `max_output_tokens`, `temperature`, `top_p`, `top_k`, `stop_sequences` and
/// `candidate_count` are read; each is optional and keeps its default when
/// absent.
///
/// # Errors
///
/// Returns an error (instead of panicking) when:
/// - `api_config` is not a table,
/// - `endpoint`, `key` or `model` is missing or not a string,
/// - a referenced environment variable is not set,
/// - a generation key is present with the wrong type, an integer value does
///   not fit its target type, or `stop_sequences` holds a non-string entry.
pub fn load_llm_configs_for_gemini(
    api_config: &toml::Value,
    generation_config: &Option<&toml::Value>,
) -> Result<(
    gemini::models::GeminiAPIConfig,
    gemini::models::GenerationConfig,
)> {
    let api_table = api_config.as_table().ok_or_else(|| {
        anyhow::anyhow!("The `llm.api_config` section in the `gitory_config.toml` is missing")
    })?;

    let endpoint = env_value_for_config_key(api_table, "endpoint")?;
    let key = env_value_for_config_key(api_table, "key")?;
    // The config key is `model`; only the struct field is called `model_name`.
    let model_name = env_value_for_config_key(api_table, "model")?;

    let api_config = gemini::models::GeminiAPIConfig {
        endpoint,
        key,
        model_name,
    };

    let generation_table = generation_config.and_then(|value| value.as_table());

    let generation_config = match generation_table {
        Some(table) => {
            let max_output_tokens = optional_integer_setting(table, "max_output_tokens")?;
            let temperature = optional_float_setting(table, "temperature")?;
            let top_p = optional_float_setting(table, "top_p")?;
            let top_k = optional_integer_setting(table, "top_k")?;

            let stop_sequences = match table.get("stop_sequences") {
                None => None,
                Some(value) => {
                    let entries = value
                        .as_array()
                        .ok_or_else(|| anyhow::anyhow!("Unable to parse `stop_sequences` key"))?;

                    // Collecting into `Result` stops at the first non-string entry.
                    let sequences = entries
                        .iter()
                        .map(|entry| {
                            entry.as_str().map(str::to_owned).ok_or_else(|| {
                                anyhow::anyhow!(
                                    "Unable to parse `stop_sequences` key: every entry must be a string"
                                )
                            })
                        })
                        .collect::<Result<Vec<String>>>()?;

                    Some(sequences)
                }
            };

            let candidate_count = optional_integer_setting(table, "candidate_count")?;

            gemini::models::GenerationConfig {
                max_output_tokens,
                temperature,
                top_p,
                top_k,
                stop_sequences,
                candidate_count,
            }
        }
        // No generation section, or one that is not a table: use the defaults.
        None => gemini::models::GenerationConfig::default(),
    };

    Ok((api_config, generation_config))
}