git_iris/llm_providers/
mod.rs1use anyhow::Result;
2use async_trait::async_trait;
3use std::collections::HashMap;
4use std::fmt;
5use std::str::FromStr;
6use strum::IntoEnumIterator;
7use strum_macros::{AsRefStr, EnumIter};
8mod claude;
9mod ollama;
10mod openai;
11
12pub mod test;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter, AsRefStr)]
16#[strum(serialize_all = "lowercase")]
17pub enum LLMProviderType {
18 OpenAI,
19 Claude,
20 Ollama,
21 Test,
22}
23
24impl fmt::Display for LLMProviderType {
25 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26 write!(f, "{}", self.as_ref())
27 }
28}
29
30impl FromStr for LLMProviderType {
31 type Err = anyhow::Error;
32
33 fn from_str(s: &str) -> Result<Self, Self::Err> {
34 match s.to_lowercase().as_str() {
35 "openai" => Ok(Self::OpenAI),
36 "claude" => Ok(Self::Claude),
37 "ollama" => Ok(Self::Ollama),
38 "test" => Ok(Self::Test),
39 _ => Err(anyhow::anyhow!("Unsupported provider: {}", s)),
40 }
41 }
42}
43
44#[async_trait]
45pub trait LLMProvider: Send + Sync {
46 async fn generate_message(&self, system_prompt: &str, user_prompt: &str) -> Result<String>;
47}
48
/// Static descriptive information about a provider.
///
/// Instances are returned by each provider module's `get_metadata()` function
/// via [`get_provider_metadata`].
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so the metadata can be
/// logged, copied and compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderMetadata {
    /// Name of the provider.
    pub name: &'static str,
    /// Name of the provider's default model.
    pub default_model: &'static str,
    /// Default token limit for the provider.
    pub default_token_limit: usize,
    /// Whether the provider requires an API key.
    pub requires_api_key: bool,
}
55
/// Configuration handed to a provider constructor by [`create_provider`].
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// configuration with `{:?}` (for example in a log line) never prints the
/// API key.
#[derive(Clone)]
pub struct LLMProviderConfig {
    /// API key for the provider; redacted in `Debug` output.
    pub api_key: String,
    /// Name of the model to use.
    pub model: String,
    /// Extra string key/value parameters; their meaning is provider-specific.
    pub additional_params: HashMap<String, String>,
}

impl fmt::Debug for LLMProviderConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Keep "is a key set at all?" visible for diagnostics without
        // revealing the key itself.
        let api_key = if self.api_key.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("LLMProviderConfig")
            .field("api_key", &api_key)
            .field("model", &self.model)
            .field("additional_params", &self.additional_params)
            .finish()
    }
}
62
63pub fn create_provider(
64 provider_type: LLMProviderType,
65 config: LLMProviderConfig,
66) -> Result<Box<dyn LLMProvider + Send + Sync>> {
67 match provider_type {
68 LLMProviderType::OpenAI => Ok(Box::new(openai::OpenAIProvider::new(config))),
69 LLMProviderType::Claude => Ok(Box::new(claude::ClaudeProvider::new(config))),
70 LLMProviderType::Ollama => Ok(Box::new(ollama::OllamaProvider::new(config))),
71 LLMProviderType::Test => Ok(Box::new(test::TestLLMProvider::new(config))),
72 }
73}
74
75pub fn get_provider_metadata(provider_type: &LLMProviderType) -> ProviderMetadata {
76 match provider_type {
77 LLMProviderType::OpenAI => openai::get_metadata(),
78 LLMProviderType::Claude => claude::get_metadata(),
79 LLMProviderType::Ollama => ollama::get_metadata(),
80 LLMProviderType::Test => test::get_metadata(),
81 }
82}
83
84pub fn get_available_providers() -> Vec<LLMProviderType> {
85 LLMProviderType::iter().collect()
86}