1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use strum::IntoEnumIterator;
use strum_macros::{AsRefStr, EnumIter};
mod claude;
mod ollama;
mod openai;

// For testing
pub mod test;

/// Identifies which LLM backend to use.
///
/// `EnumIter` lets callers enumerate every variant (see
/// `get_available_providers`), and `AsRefStr` together with
/// `serialize_all = "lowercase"` supplies the lowercase string form that the
/// `Display` impl prints (e.g. `OpenAI` -> `"openai"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter, AsRefStr)]
#[strum(serialize_all = "lowercase")]
pub enum LLMProviderType {
    /// Backend implemented in the `openai` module.
    OpenAI,
    /// Backend implemented in the `claude` module.
    Claude,
    /// Backend implemented in the `ollama` module.
    Ollama,
    /// Backend implemented in the `test` module (marked "For testing" above).
    Test,
}

impl fmt::Display for LLMProviderType {
    /// Writes the provider's string name, i.e. the same text `as_ref()` returns.
    ///
    /// The name is emitted through [`fmt::Formatter::pad`] so that width, fill,
    /// alignment and precision flags supplied by the caller (for example
    /// `format!("{:<10}", provider)` when printing a table) are honoured.
    /// With a plain `{}` the output is exactly the name, unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `write!(f, "{}", ..)` would build a fresh format spec and silently
        // drop the caller's padding/alignment; `pad` applies them.
        f.pad(self.as_ref())
    }
}

impl FromStr for LLMProviderType {
    type Err = anyhow::Error;

    /// Parses a provider name, ignoring letter case.
    ///
    /// Recognised names are `"openai"`, `"claude"`, `"ollama"` and `"test"`.
    ///
    /// # Errors
    ///
    /// Returns an "Unsupported provider" error, quoting the input exactly as
    /// given, when the lowercased input matches none of the names above.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Normalise once so the comparison below is case-insensitive.
        let normalized = s.to_lowercase();

        let provider = match normalized.as_str() {
            "openai" => Self::OpenAI,
            "claude" => Self::Claude,
            "ollama" => Self::Ollama,
            "test" => Self::Test,
            _ => return Err(anyhow::anyhow!("Unsupported provider: {}", s)),
        };

        Ok(provider)
    }
}


/// Common interface implemented by every LLM backend.
///
/// Implementors must be `Send + Sync`. The `#[async_trait]` attribute is
/// applied so the trait can declare an `async fn`.
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Generates a message from the given system prompt and user prompt.
    ///
    /// Returns the generated text on success, or an error from the
    /// implementation on failure. How the two prompts are combined and sent
    /// is up to each implementor and is not visible from this trait.
    async fn generate_message(&self, system_prompt: &str, user_prompt: &str) -> Result<String>;
}

/// Static, descriptive information about a provider, as returned by
/// `get_provider_metadata`.
///
/// Every field is either `'static` data or a plain value, so the struct is
/// cheap to clone; `Debug`/`PartialEq` are derived so metadata can be logged
/// and compared like the other public types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderMetadata {
    /// Name of the provider.
    pub name: &'static str,
    /// Model identifier used when none is configured explicitly.
    pub default_model: &'static str,
    /// Default token limit for the provider (exact unit/scope is defined by
    /// each provider module — confirm there before relying on it).
    pub default_token_limit: usize,
    /// Whether the provider needs an API key to operate.
    pub requires_api_key: bool,
}

/// Runtime configuration handed to a provider constructor via
/// `create_provider`.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs or panic messages.
#[derive(Clone)]
pub struct LLMProviderConfig {
    /// Secret API key for the provider. Redacted in `Debug` output.
    pub api_key: String,
    /// Model identifier to use.
    pub model: String,
    /// Extra provider-specific settings as string key/value pairs.
    pub additional_params: HashMap<String, String>,
}

impl fmt::Debug for LLMProviderConfig {
    /// Formats the config like the derived `Debug` would, except that a
    /// non-empty `api_key` is shown as `"[REDACTED]"`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Keep "empty" distinguishable from "set" so a missing key is still
        // easy to diagnose, without ever revealing the secret itself.
        let api_key = if self.api_key.is_empty() {
            ""
        } else {
            "[REDACTED]"
        };

        f.debug_struct("LLMProviderConfig")
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}

/// Builds the provider implementation selected by `provider_type`.
///
/// `config` is moved into the chosen provider's constructor.
///
/// # Errors
///
/// Propagates whatever error the selected provider's `new` returns.
pub fn create_provider(
    provider_type: LLMProviderType,
    config: LLMProviderConfig,
) -> Result<Box<dyn LLMProvider + Send + Sync>> {
    // Each arm yields a different concrete type; the annotation makes every
    // arm coerce to the same boxed trait object.
    let provider: Box<dyn LLMProvider + Send + Sync> = match provider_type {
        LLMProviderType::OpenAI => Box::new(openai::OpenAIProvider::new(config)?),
        LLMProviderType::Claude => Box::new(claude::ClaudeProvider::new(config)?),
        LLMProviderType::Ollama => Box::new(ollama::OllamaProvider::new(config)?),
        LLMProviderType::Test => Box::new(test::TestLLMProvider::new(config)?),
    };

    Ok(provider)
}

/// Returns the metadata for the given provider by delegating to the
/// `get_metadata` function of the matching provider module.
pub fn get_provider_metadata(provider_type: &LLMProviderType) -> ProviderMetadata {
    // The enum is `Copy`, so match on the value rather than the reference.
    match *provider_type {
        LLMProviderType::OpenAI => openai::get_metadata(),
        LLMProviderType::Claude => claude::get_metadata(),
        LLMProviderType::Ollama => ollama::get_metadata(),
        LLMProviderType::Test => test::get_metadata(),
    }
}

/// Lists every `LLMProviderType` variant, in the order the enum's iterator
/// yields them.
pub fn get_available_providers() -> Vec<LLMProviderType> {
    let mut providers = Vec::new();
    providers.extend(LLMProviderType::iter());
    providers
}