mofa_foundation/llm/plugin.rs

//! LLM Plugin Adapter
//!
//! Wraps an LLM provider as a MoFA plugin.

use super::provider::{LLMConfig, LLMProvider};
use super::types::*;
use mofa_kernel::plugin::{
    AgentPlugin, PluginContext, PluginMetadata, PluginPriority, PluginResult, PluginState,
    PluginType,
};
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

/// LLM plugin
///
/// Wraps an LLM provider as a plugin for the MoFA framework.
///
/// # Example
///
/// ```rust,ignore
/// use mofa_foundation::llm::{LLMPlugin, LLMConfig};
///
/// // Create the LLM plugin
/// let plugin = LLMPlugin::new("my-llm", provider);
///
/// // Use it as an AgentPlugin
/// agent.add_plugin(Box::new(plugin));
/// ```
pub struct LLMPlugin {
    metadata: PluginMetadata,
    state: PluginState,
    provider: Arc<dyn LLMProvider>,
    config: LLMConfig,
    stats: RwLock<LLMStats>,
}

/// LLM usage statistics
#[derive(Debug, Default)]
struct LLMStats {
    total_requests: u64,
    total_tokens: u64,
    total_prompt_tokens: u64,
    total_completion_tokens: u64,
    failed_requests: u64,
    avg_latency_ms: f64,
}

impl LLMPlugin {
    /// Creates a new LLM plugin
    pub fn new(id: &str, provider: Arc<dyn LLMProvider>) -> Self {
        let metadata = PluginMetadata::new(id, provider.name(), PluginType::LLM)
            .with_description(&format!("LLM provider: {}", provider.name()))
            .with_priority(PluginPriority::High)
            .with_capability("chat")
            .with_capability("text-generation");

        Self {
            metadata,
            state: PluginState::Unloaded,
            provider,
            config: LLMConfig::default(),
            stats: RwLock::new(LLMStats::default()),
        }
    }

    /// Creates a plugin with a custom configuration
    pub fn with_config(id: &str, provider: Arc<dyn LLMProvider>, config: LLMConfig) -> Self {
        let mut plugin = Self::new(id, provider);
        plugin.config = config;
        plugin
    }
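
    // Usage sketch for `with_config` (hedged: assumes `LLMConfig` derives
    // `Default` and that the fields read elsewhere in this file are
    // constructible here; the model name is illustrative only):
    //
    //     let config = LLMConfig {
    //         default_model: Some("gpt-4o-mini".to_string()),
    //         default_temperature: Some(0.2),
    //         ..LLMConfig::default()
    //     };
    //     let plugin = LLMPlugin::with_config("my-llm", provider, config);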

    /// Returns the underlying provider
    pub fn provider(&self) -> &Arc<dyn LLMProvider> {
        &self.provider
    }

    /// Returns the LLM configuration
    pub fn llm_config(&self) -> &LLMConfig {
        &self.config
    }

    /// Sends a chat completion request and records usage statistics
    pub async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let start = std::time::Instant::now();

        let result = self.provider.chat(request).await;

        // Update statistics
        let mut stats = self.stats.write().await;
        stats.total_requests += 1;

        match &result {
            Ok(response) => {
                if let Some(usage) = &response.usage {
                    stats.total_tokens += usage.total_tokens as u64;
                    stats.total_prompt_tokens += usage.prompt_tokens as u64;
                    stats.total_completion_tokens += usage.completion_tokens as u64;
                }
                // Incremental mean: new_avg = (old_avg * (n - 1) + latency) / n,
                // where n is the updated request count. Note that n includes
                // failed requests, even though only successful requests
                // contribute a latency sample.
                let latency = start.elapsed().as_millis() as f64;
                stats.avg_latency_ms = (stats.avg_latency_ms * (stats.total_requests - 1) as f64
                    + latency)
                    / stats.total_requests as f64;
            }
            Err(_) => {
                stats.failed_requests += 1;
            }
        }

        result
    }

    /// Simple question-and-answer using the configured defaults
    pub async fn ask(&self, question: &str) -> LLMResult<String> {
        let model = self
            .config
            .default_model
            .clone()
            .unwrap_or_else(|| self.provider.default_model().to_string());

        let request = ChatCompletionRequest::new(model)
            .user(question)
            .temperature(self.config.default_temperature.unwrap_or(0.7))
            .max_tokens(self.config.default_max_tokens.unwrap_or(4096));

        let response = self.chat(request).await?;

        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| LLMError::Other("No content in response".to_string()))
    }
}
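
// Usage sketch for `ask` (hedged: `provider` is any `Arc<dyn LLMProvider>`,
// e.g. the `MockLLMProvider` defined at the bottom of this file):
//
//     let plugin = LLMPlugin::new("my-llm", provider);
//     let answer = plugin.ask("What is MoFA?").await?;
//     println!("{answer}");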

#[async_trait::async_trait]
impl AgentPlugin for LLMPlugin {
    fn metadata(&self) -> &PluginMetadata {
        &self.metadata
    }

    fn state(&self) -> PluginState {
        self.state.clone()
    }

    async fn load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
        self.state = PluginState::Loading;

        // Read LLM settings from the plugin context configuration
        if let Some(api_key) = ctx.config.get_string("api_key") {
            self.config.api_key = Some(api_key);
        }
        if let Some(base_url) = ctx.config.get_string("base_url") {
            self.config.base_url = Some(base_url);
        }
        if let Some(model) = ctx.config.get_string("model") {
            self.config.default_model = Some(model);
        }

        self.state = PluginState::Loaded;
        Ok(())
    }

    async fn init_plugin(&mut self) -> PluginResult<()> {
        // Health check
        self.provider.health_check().await.map_err(|e| {
            self.state = PluginState::Error(e.to_string());
            anyhow::anyhow!("LLM health check failed: {}", e)
        })?;

        Ok(())
    }

    async fn start(&mut self) -> PluginResult<()> {
        self.state = PluginState::Running;
        Ok(())
    }

    async fn stop(&mut self) -> PluginResult<()> {
        self.state = PluginState::Paused;
        Ok(())
    }

    async fn unload(&mut self) -> PluginResult<()> {
        self.state = PluginState::Unloaded;
        Ok(())
    }

    async fn execute(&mut self, input: String) -> PluginResult<String> {
        // Simple mode: treat the input directly as a user message
        self.ask(&input)
            .await
            .map_err(|e| anyhow::anyhow!("LLM execution failed: {}", e))
    }

    fn stats(&self) -> HashMap<String, serde_json::Value> {
        // Note: use try_read to avoid blocking; if the lock is contended,
        // return an empty map rather than wait.
        let stats = match self.stats.try_read() {
            Ok(s) => s,
            Err(_) => return HashMap::new(),
        };

        let mut result = HashMap::new();
        result.insert(
            "total_requests".to_string(),
            serde_json::json!(stats.total_requests),
        );
        result.insert(
            "total_tokens".to_string(),
            serde_json::json!(stats.total_tokens),
        );
        result.insert(
            "total_prompt_tokens".to_string(),
            serde_json::json!(stats.total_prompt_tokens),
        );
        result.insert(
            "total_completion_tokens".to_string(),
            serde_json::json!(stats.total_completion_tokens),
        );
        result.insert(
            "failed_requests".to_string(),
            serde_json::json!(stats.failed_requests),
        );
        result.insert(
            "avg_latency_ms".to_string(),
            serde_json::json!(stats.avg_latency_ms),
        );
        result
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
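
// Lifecycle sketch (hedged: the order follows the `AgentPlugin` methods
// implemented above; `ctx` is a framework-provided `PluginContext`):
//
//     plugin.load(&ctx).await?;        // reads api_key / base_url / model
//     plugin.init_plugin().await?;     // provider health check
//     plugin.start().await?;           // state -> Running
//     let out = plugin.execute("hello".to_string()).await?;
//     plugin.stop().await?;            // state -> Paused
//     plugin.unload().await?;          // state -> Unloaded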

// ============================================================================
// LLM Capability Extension
// ============================================================================

/// LLM capability trait
///
/// Provides agents with the ability to interact with an LLM.
#[async_trait::async_trait]
pub trait LLMCapability: Send + Sync {
    /// Returns the LLM provider, if one is configured
    fn llm_provider(&self) -> Option<&Arc<dyn LLMProvider>>;

    /// Simple question-and-answer
    async fn llm_ask(&self, question: &str) -> LLMResult<String> {
        let provider = self
            .llm_provider()
            .ok_or_else(|| LLMError::ConfigError("LLM provider not configured".to_string()))?;

        let request = ChatCompletionRequest::new(provider.default_model()).user(question);

        let response = provider.chat(request).await?;

        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| LLMError::Other("No content in response".to_string()))
    }

    /// Question-and-answer with a system prompt
    async fn llm_ask_with_system(&self, system: &str, question: &str) -> LLMResult<String> {
        let provider = self
            .llm_provider()
            .ok_or_else(|| LLMError::ConfigError("LLM provider not configured".to_string()))?;

        let request = ChatCompletionRequest::new(provider.default_model())
            .system(system)
            .user(question);

        let response = provider.chat(request).await?;

        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| LLMError::Other("No content in response".to_string()))
    }

    /// Sends a complete chat request
    async fn llm_chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let provider = self
            .llm_provider()
            .ok_or_else(|| LLMError::ConfigError("LLM provider not configured".to_string()))?;

        provider.chat(request).await
    }
}
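
// Implementation sketch (hedged: `MyAgent` is a hypothetical type, not part
// of this crate; only `llm_provider` must be supplied, since the other trait
// methods have default implementations):
//
//     struct MyAgent {
//         llm: Option<Arc<dyn LLMProvider>>,
//     }
//
//     #[async_trait::async_trait]
//     impl LLMCapability for MyAgent {
//         fn llm_provider(&self) -> Option<&Arc<dyn LLMProvider>> {
//             self.llm.as_ref()
//         }
//     }
//
//     // let reply = agent.llm_ask_with_system("You are terse.", "Hi").await?;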

// ============================================================================
// Mock Provider (for testing)
// ============================================================================

/// Mock LLM provider (for testing)
pub struct MockLLMProvider {
    name: String,
    responses: RwLock<Vec<String>>,
    default_response: String,
}

impl MockLLMProvider {
    /// Creates a mock provider
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            responses: RwLock::new(Vec::new()),
            default_response: "This is a mock response.".to_string(),
        }
    }

    /// Sets the default response
    pub fn with_default_response(mut self, response: impl Into<String>) -> Self {
        self.default_response = response.into();
        self
    }

    /// Queues a canned response (returned in insertion order)
    pub async fn add_response(&self, response: impl Into<String>) {
        let mut responses = self.responses.write().await;
        responses.push(response.into());
    }
}
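
// Usage sketch (hedged: queued responses are drained in FIFO order by `chat`
// below, after which the default response is returned indefinitely):
//
//     let mock = Arc::new(MockLLMProvider::new("mock"));
//     mock.add_response("first reply").await;
//     let plugin = LLMPlugin::new("test-llm", mock);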

#[async_trait::async_trait]
impl LLMProvider for MockLLMProvider {
    fn name(&self) -> &str {
        &self.name
    }

    fn default_model(&self) -> &str {
        "mock-model"
    }

    fn supported_models(&self) -> Vec<&str> {
        vec!["mock-model", "mock-model-large"]
    }

    fn supports_streaming(&self) -> bool {
        false
    }

    fn supports_tools(&self) -> bool {
        true
    }

    async fn chat(&self, _request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        // Pop the next queued response, or fall back to the default
        let content = {
            let mut responses = self.responses.write().await;
            if responses.is_empty() {
                self.default_response.clone()
            } else {
                responses.remove(0)
            }
        };

        Ok(ChatCompletionResponse {
            id: format!("mock-{}", uuid::Uuid::now_v7()),
            object: "chat.completion".to_string(),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            model: "mock-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage::assistant(content),
                finish_reason: Some(FinishReason::Stop),
                logprobs: None,
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
            }),
            system_fingerprint: None,
        })
    }
}
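
// A minimal test sketch (hedged: assumes `tokio` with the `macros` and `rt`
// features is available for `#[tokio::test]`, and that
// `ChatCompletionResponse::content` returns the assistant text as used in
// `ask` above).
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_provider_drains_queue_then_falls_back() {
        let provider = MockLLMProvider::new("mock").with_default_response("fallback");
        provider.add_response("first").await;

        // First call consumes the queued response.
        let resp = provider
            .chat(ChatCompletionRequest::new("mock-model").user("hi"))
            .await
            .unwrap();
        assert_eq!(
            resp.content().map(|s| s.to_string()),
            Some("first".to_string())
        );

        // Subsequent calls fall back to the default response.
        let resp = provider
            .chat(ChatCompletionRequest::new("mock-model").user("hi again"))
            .await
            .unwrap();
        assert_eq!(
            resp.content().map(|s| s.to_string()),
            Some("fallback".to_string())
        );
    }
}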