// git_iris/llm_providers/mod.rs

use anyhow::Result;
use async_trait::async_trait;
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use strum::IntoEnumIterator;
use strum_macros::{AsRefStr, EnumIter};
8mod claude;
9mod ollama;
10mod openai;
11
12// For testing
13pub mod test;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter, AsRefStr)]
16#[strum(serialize_all = "lowercase")]
17pub enum LLMProviderType {
18    OpenAI,
19    Claude,
20    Ollama,
21    Test,
22}
23
24impl fmt::Display for LLMProviderType {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        write!(f, "{}", self.as_ref())
27    }
28}
29
30impl FromStr for LLMProviderType {
31    type Err = anyhow::Error;
32
33    fn from_str(s: &str) -> Result<Self, Self::Err> {
34        match s.to_lowercase().as_str() {
35            "openai" => Ok(Self::OpenAI),
36            "claude" => Ok(Self::Claude),
37            "ollama" => Ok(Self::Ollama),
38            "test" => Ok(Self::Test),
39            _ => Err(anyhow::anyhow!("Unsupported provider: {}", s)),
40        }
41    }
42}
43
44#[async_trait]
45pub trait LLMProvider: Send + Sync {
46    async fn generate_message(&self, system_prompt: &str, user_prompt: &str) -> Result<String>;
47}
48
/// Static description of a provider: its name, defaults and requirements.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so values can be logged, copied
/// and compared like other public types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderMetadata {
    /// Provider name.
    pub name: &'static str,
    /// Default model name for this provider.
    pub default_model: &'static str,
    /// Default token limit for this provider.
    pub default_token_limit: usize,
    /// Whether this provider needs an API key.
    pub requires_api_key: bool,
}
55
/// Runtime configuration handed to a provider's constructor.
///
/// `Debug` is implemented by hand so the API key never appears in logs or
/// panic messages. All other fields are printed normally.
#[derive(Clone)]
pub struct LLMProviderConfig {
    /// Secret credential for the provider. Presumably empty for providers whose
    /// metadata has `requires_api_key: false`; verify against callers.
    pub api_key: String,
    /// Model identifier to use.
    pub model: String,
    /// Extra provider-specific settings as key/value strings.
    pub additional_params: HashMap<String, String>,
}

impl fmt::Debug for LLMProviderConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show whether a key is set, but never its value.
        let api_key = if self.api_key.is_empty() {
            ""
        } else {
            "[REDACTED]"
        };
        f.debug_struct("LLMProviderConfig")
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
62
63pub fn create_provider(
64    provider_type: LLMProviderType,
65    config: LLMProviderConfig,
66) -> Result<Box<dyn LLMProvider + Send + Sync>> {
67    match provider_type {
68        LLMProviderType::OpenAI => Ok(Box::new(openai::OpenAIProvider::new(config))),
69        LLMProviderType::Claude => Ok(Box::new(claude::ClaudeProvider::new(config))),
70        LLMProviderType::Ollama => Ok(Box::new(ollama::OllamaProvider::new(config))),
71        LLMProviderType::Test => Ok(Box::new(test::TestLLMProvider::new(config))),
72    }
73}
74
75pub fn get_provider_metadata(provider_type: &LLMProviderType) -> ProviderMetadata {
76    match provider_type {
77        LLMProviderType::OpenAI => openai::get_metadata(),
78        LLMProviderType::Claude => claude::get_metadata(),
79        LLMProviderType::Ollama => ollama::get_metadata(),
80        LLMProviderType::Test => test::get_metadata(),
81    }
82}
83
84pub fn get_available_providers() -> Vec<LLMProviderType> {
85    LLMProviderType::iter().collect()
86}