traiy_core 0.0.13

A utility that serves AI suggestions based on user-provided guidelines and, optionally, context.
Documentation
//! Builder module
//!
//! This module provides functionality to build and configure LLM (Large Language Model) providers.
//! It supports different LLM providers like Google and OpenAI, and allows building models
//! based on user-specified actions, API keys, and configurations.

use crate::cli::Action;
use crate::enums::LlmProvider;
use anyhow::Result;
use llm::LLMProvider as ExternalLLMProvider;
use llm::builder::LLMBackend;
use llm::builder::LLMBuilder;

/// Builds an LLM provider based on the specified action and configuration.
///
/// # Arguments
///
/// * `action` - The action to perform (e.g., recommend, enhance).
/// * `api_key` - The API key for the LLM provider.
/// * `model` - The name of the LLM model to use.
/// * `max_tokens` - The maximum number of tokens to generate.
/// * `temperature` - The temperature to use for generation.
/// * `system_msg` - The system message to send to the LLM.
/// * `llm_provider` - The LLM provider to use (e.g., Google, OpenAI).
///
/// # Returns
///
/// A `Result` containing a boxed `ExternalLLMProvider` trait object, or an error if the provider could not be built.
pub fn build_model(
    action: &Action,
    api_key: String,
    model: String,
    max_tokens: u32,
    temperature: f32,
    system_msg: String,
    llm_provider: &LlmProvider,
) -> Result<Box<dyn ExternalLLMProvider>> {
    match action {
        Action::Recommend { .. } => match llm_provider {
            LlmProvider::Google => {
                let model = LLMBuilder::new()
                    .backend(LLMBackend::Google)
                    .api_key(api_key)
                    .model(model.to_string())
                    .max_tokens(max_tokens)
                    .temperature(temperature)
                    .system(system_msg)
                    .build()?;
                Ok(model)
            }
            LlmProvider::Openai => {
                let model = LLMBuilder::new()
                    .backend(LLMBackend::OpenAI)
                    .api_key(api_key)
                    .model(model.to_string())
                    .max_tokens(max_tokens)
                    .temperature(temperature)
                    .system(system_msg)
                    .build()?;
                Ok(model)
            }
        },
        Action::Enhance { .. } => {
            unimplemented!("Not yet implmented!")
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::enums::LlmModel;
    use anyhow::Result;

    /// Builds a `Recommend` action whose `llm_provider` field is `llm_provider`.
    ///
    /// `build_model` only matches on the `Action` variant (`Action::Recommend { .. }`),
    /// so the remaining fields are placeholders.
    fn recommend_action(llm_provider: LlmProvider) -> Action {
        Action::Recommend {
            input_csv: "input.csv".into(),
            guidelines_csv: "guidelines.csv".into(),
            context_csv: None,
            num_recommendations: None,
            llm_provider,
            model: LlmModel::GeminiFlash2,
            max_tokens: None,
            temperature: None,
        }
    }

    #[test]
    fn test_build_model_google_recommend() -> Result<()> {
        let action = recommend_action(LlmProvider::Google);

        // `?` propagates the build error, so a failing test reports its cause
        // instead of a bare `assertion failed: result.is_ok()`.
        build_model(
            &action,
            "test_api_key".to_string(),
            "gemini-2.0-flash".to_string(),
            100,
            0.5,
            "Test system message".to_string(),
            &LlmProvider::Google,
        )?;
        Ok(())
    }

    #[test]
    fn test_build_model_openai_recommend() -> Result<()> {
        let action = recommend_action(LlmProvider::Openai);

        build_model(
            &action,
            "test_api_key".to_string(),
            "gpt-4o-mini".to_string(),
            100,
            0.5,
            "Test system message".to_string(),
            &LlmProvider::Openai,
        )?;
        Ok(())
    }
}