// mofa_foundation/llm/provider.rs

//! LLM Provider Trait
//!
//! Defines the LLM provider interface, supporting multiple LLM backends.

5use super::types::*;
6use async_trait::async_trait;
7use futures::Stream;
8use std::pin::Pin;
9
/// Streaming response type: a pinned, boxed stream that yields
/// chat-completion chunks (or errors) as they arrive from the backend.
pub type ChatStream = Pin<Box<dyn Stream<Item = LLMResult<ChatCompletionChunk>> + Send>>;
12
13/// LLM 提供商 trait
14///
15/// 所有 LLM 后端(OpenAI、Anthropic、本地模型等)都需要实现此 trait
16///
17/// # 示例
18///
19/// ```rust,ignore
20/// use mofa_foundation::llm::{LLMProvider, ChatCompletionRequest, LLMResult};
21///
22/// struct MyLLMProvider {
23///     api_key: String,
24///     base_url: String,
25/// }
26///
27/// #[async_trait::async_trait]
28/// impl LLMProvider for MyLLMProvider {
29///     fn name(&self) -> &str {
30///         "my-llm"
31///     }
32///
33///     async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
34///         // 实现 API 调用
35///         todo!()
36///     }
37/// }
38/// ```
39#[async_trait]
40pub trait LLMProvider: Send + Sync {
41    /// 获取提供商名称
42    fn name(&self) -> &str;
43
44    /// 获取默认模型
45    fn default_model(&self) -> &str {
46        ""
47    }
48
49    /// 获取支持的模型列表
50    fn supported_models(&self) -> Vec<&str> {
51        vec![]
52    }
53
54    /// 检查是否支持某个模型
55    fn supports_model(&self, model: &str) -> bool {
56        self.supported_models().contains(&model)
57    }
58
59    /// 检查是否支持流式输出
60    fn supports_streaming(&self) -> bool {
61        true
62    }
63
64    /// 检查是否支持工具调用
65    fn supports_tools(&self) -> bool {
66        true
67    }
68
69    /// 检查是否支持视觉(图片输入)
70    fn supports_vision(&self) -> bool {
71        false
72    }
73
74    /// 检查是否支持 embedding
75    fn supports_embedding(&self) -> bool {
76        false
77    }
78
79    /// 发送 Chat Completion 请求
80    async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse>;
81
82    /// 发送 Chat Completion 流式请求
83    async fn chat_stream(&self, _request: ChatCompletionRequest) -> LLMResult<ChatStream> {
84        // 默认实现:不支持流式
85        Err(LLMError::ProviderNotSupported(format!(
86            "Provider {} does not support streaming",
87            self.name()
88        )))
89    }
90
91    /// 发送 Embedding 请求
92    async fn embedding(&self, _request: EmbeddingRequest) -> LLMResult<EmbeddingResponse> {
93        Err(LLMError::ProviderNotSupported(format!(
94            "Provider {} does not support embedding",
95            self.name()
96        )))
97    }
98
99    /// 健康检查
100    async fn health_check(&self) -> LLMResult<bool> {
101        Ok(true)
102    }
103
104    /// 获取模型信息
105    async fn get_model_info(&self, _model: &str) -> LLMResult<ModelInfo> {
106        Err(LLMError::ProviderNotSupported(format!(
107            "Provider {} does not support model info",
108            self.name()
109        )))
110    }
111}
112
/// Metadata describing a single model offered by a provider.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelInfo {
    /// Model ID
    pub id: String,
    /// Model name
    pub name: String,
    /// Model description
    pub description: Option<String>,
    /// Context window size (tokens)
    pub context_window: Option<u32>,
    /// Maximum number of output tokens
    pub max_output_tokens: Option<u32>,
    /// Training-data cutoff date
    pub training_cutoff: Option<String>,
    /// Supported capabilities
    pub capabilities: ModelCapabilities,
}
131
/// Feature flags describing what a model can do.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ModelCapabilities {
    /// Supports streaming output
    pub streaming: bool,
    /// Supports tool/function calling
    pub tools: bool,
    /// Supports vision (image input)
    pub vision: bool,
    /// Supports JSON mode
    pub json_mode: bool,
    /// Supports JSON Schema
    pub json_schema: bool,
}
146
/// LLM configuration used by [`ProviderFactory`] to build a provider.
///
/// Unknown keys deserialize into `extra` via `#[serde(flatten)]`, so
/// provider-specific settings survive a round trip.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    /// Provider name
    pub provider: String,
    /// API key
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// API base URL
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Default model
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_model: Option<String>,
    /// Default temperature
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_temperature: Option<f32>,
    /// Default maximum token count
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_max_tokens: Option<u32>,
    /// Request timeout (seconds)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_secs: Option<u64>,
    /// Retry count
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
    /// Extra, provider-specific configuration
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
177
178impl Default for LLMConfig {
179    fn default() -> Self {
180        Self {
181            provider: "openai".to_string(),
182            api_key: None,
183            base_url: None,
184            default_model: None,
185            default_temperature: Some(0.7),
186            default_max_tokens: Some(4096),
187            timeout_secs: Some(60),
188            max_retries: Some(3),
189            extra: std::collections::HashMap::new(),
190        }
191    }
192}
193
194impl LLMConfig {
195    /// 创建 OpenAI 配置
196    pub fn openai(api_key: impl Into<String>) -> Self {
197        Self {
198            provider: "openai".to_string(),
199            api_key: Some(api_key.into()),
200            base_url: Some("https://api.openai.com/v1".to_string()),
201            default_model: Some("gpt-4".to_string()),
202            ..Default::default()
203        }
204    }
205
206    /// 创建 Anthropic 配置
207    pub fn anthropic(api_key: impl Into<String>) -> Self {
208        Self {
209            provider: "anthropic".to_string(),
210            api_key: Some(api_key.into()),
211            base_url: Some("https://api.anthropic.com".to_string()),
212            default_model: Some("claude-3-sonnet-20240229".to_string()),
213            ..Default::default()
214        }
215    }
216
217    /// 创建本地 Ollama 配置
218    pub fn ollama(model: impl Into<String>) -> Self {
219        Self {
220            provider: "ollama".to_string(),
221            api_key: None,
222            base_url: Some("http://localhost:11434".to_string()),
223            default_model: Some(model.into()),
224            ..Default::default()
225        }
226    }
227
228    /// 创建兼容 OpenAI API 的配置
229    pub fn openai_compatible(
230        base_url: impl Into<String>,
231        api_key: impl Into<String>,
232        model: impl Into<String>,
233    ) -> Self {
234        Self {
235            provider: "openai-compatible".to_string(),
236            api_key: Some(api_key.into()),
237            base_url: Some(base_url.into()),
238            default_model: Some(model.into()),
239            ..Default::default()
240        }
241    }
242
243    /// 设置模型
244    pub fn model(mut self, model: impl Into<String>) -> Self {
245        self.default_model = Some(model.into());
246        self
247    }
248
249    /// 设置温度
250    pub fn temperature(mut self, temp: f32) -> Self {
251        self.default_temperature = Some(temp);
252        self
253    }
254
255    /// 设置最大 token 数
256    pub fn max_tokens(mut self, tokens: u32) -> Self {
257        self.default_max_tokens = Some(tokens);
258        self
259    }
260}
261
262// ============================================================================
263// 可扩展的 Provider 注册表
264// ============================================================================
265
266use std::collections::HashMap;
267use std::sync::Arc;
268use tokio::sync::RwLock;
269
/// Factory function type: builds a boxed [`LLMProvider`] from an [`LLMConfig`].
pub type ProviderFactory = Box<dyn Fn(LLMConfig) -> LLMResult<Box<dyn LLMProvider>> + Send + Sync>;
272
/// LLM provider registry.
///
/// Used to register factories and create/cache LLM provider instances.
pub struct LLMRegistry {
    // Named factory functions used to build provider instances on demand.
    factories: RwLock<HashMap<String, ProviderFactory>>,
    // Cached, ready-to-use provider instances keyed by name.
    providers: RwLock<HashMap<String, Arc<dyn LLMProvider>>>,
}
280
281impl LLMRegistry {
282    /// 创建新的注册表
283    pub fn new() -> Self {
284        Self {
285            factories: RwLock::new(HashMap::new()),
286            providers: RwLock::new(HashMap::new()),
287        }
288    }
289
290    /// 注册 Provider 工厂
291    pub async fn register_factory<F>(&self, name: &str, factory: F)
292    where
293        F: Fn(LLMConfig) -> LLMResult<Box<dyn LLMProvider>> + Send + Sync + 'static,
294    {
295        let mut factories = self.factories.write().await;
296        factories.insert(name.to_string(), Box::new(factory));
297    }
298
299    /// 创建 Provider 实例
300    pub async fn create(&self, config: LLMConfig) -> LLMResult<Arc<dyn LLMProvider>> {
301        let factories = self.factories.read().await;
302        let factory = factories
303            .get(&config.provider)
304            .ok_or_else(|| LLMError::ProviderNotSupported(config.provider.clone()))?;
305
306        let provider = factory(config)?;
307        Ok(Arc::from(provider))
308    }
309
310    /// 注册并缓存 Provider 实例
311    pub async fn register(&self, name: &str, provider: Arc<dyn LLMProvider>) {
312        let mut providers = self.providers.write().await;
313        providers.insert(name.to_string(), provider);
314    }
315
316    /// 获取已注册的 Provider
317    pub async fn get(&self, name: &str) -> Option<Arc<dyn LLMProvider>> {
318        let providers = self.providers.read().await;
319        providers.get(name).cloned()
320    }
321
322    /// 列出所有已注册的 Provider 名称
323    pub async fn list_providers(&self) -> Vec<String> {
324        let providers = self.providers.read().await;
325        providers.keys().cloned().collect()
326    }
327
328    /// 列出所有可用的 Provider 工厂名称
329    pub async fn list_factories(&self) -> Vec<String> {
330        let factories = self.factories.read().await;
331        factories.keys().cloned().collect()
332    }
333}
334
335impl Default for LLMRegistry {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341// ============================================================================
342// 全局 Registry(可选)
343// ============================================================================
344
345use std::sync::OnceLock;
346
347static GLOBAL_REGISTRY: OnceLock<LLMRegistry> = OnceLock::new();
348
349/// 获取全局 LLM 注册表
350pub fn global_registry() -> &'static LLMRegistry {
351    GLOBAL_REGISTRY.get_or_init(LLMRegistry::new)
352}