// mofa_foundation/llm/plugin.rs

use super::provider::{LLMConfig, LLMProvider};
use super::types::*;
use mofa_kernel::plugin::{
    AgentPlugin, PluginContext, PluginMetadata, PluginPriority, PluginResult, PluginState,
    PluginType,
};
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

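/// Plugin wrapper that adapts an [`LLMProvider`] to the `AgentPlugin`
/// lifecycle: it forwards chat requests to the provider and tracks usage
/// statistics (request counts, token totals, average latency).
///
/// A minimal usage sketch, assuming an async (tokio) context; the
/// `MockLLMProvider` defined below works as the provider in tests:
///
/// ```ignore
/// let provider = Arc::new(MockLLMProvider::new("mock"));
/// let plugin = LLMPlugin::new("mock-llm", provider);
/// let answer = plugin.ask("ping").await?;
/// ```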
pub struct LLMPlugin {
    metadata: PluginMetadata,
    state: PluginState,
    provider: Arc<dyn LLMProvider>,
    config: LLMConfig,
    stats: RwLock<LLMStats>,
}

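/// Aggregated usage counters, updated after every chat call.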
#[derive(Debug, Default)]
struct LLMStats {
    total_requests: u64,
    total_tokens: u64,
    total_prompt_tokens: u64,
    total_completion_tokens: u64,
    failed_requests: u64,
    avg_latency_ms: f64,
}

impl LLMPlugin {
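    /// Creates a plugin for the given provider with a default `LLMConfig`.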
    pub fn new(id: &str, provider: Arc<dyn LLMProvider>) -> Self {
        let metadata = PluginMetadata::new(id, provider.name(), PluginType::LLM)
            .with_description(&format!("LLM provider: {}", provider.name()))
            .with_priority(PluginPriority::High)
            .with_capability("chat")
            .with_capability("text-generation");

        Self {
            metadata,
            state: PluginState::Unloaded,
            provider,
            config: LLMConfig::default(),
            stats: RwLock::new(LLMStats::default()),
        }
    }

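    /// Creates a plugin with an explicit `LLMConfig` instead of the default.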
    pub fn with_config(id: &str, provider: Arc<dyn LLMProvider>, config: LLMConfig) -> Self {
        let mut plugin = Self::new(id, provider);
        plugin.config = config;
        plugin
    }

    pub fn provider(&self) -> &Arc<dyn LLMProvider> {
        &self.provider
    }

    pub fn llm_config(&self) -> &LLMConfig {
        &self.config
    }

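    /// Sends a chat completion request to the provider, recording token
    /// usage, failures, and a running average latency along the way.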
    pub async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let start = std::time::Instant::now();

        let result = self.provider.chat(request).await;

        let mut stats = self.stats.write().await;
        stats.total_requests += 1;

        match &result {
            Ok(response) => {
                if let Some(usage) = &response.usage {
                    stats.total_tokens += usage.total_tokens as u64;
                    stats.total_prompt_tokens += usage.prompt_tokens as u64;
                    stats.total_completion_tokens += usage.completion_tokens as u64;
                }
                let latency = start.elapsed().as_millis() as f64;
                stats.avg_latency_ms = (stats.avg_latency_ms * (stats.total_requests - 1) as f64
                    + latency)
                    / stats.total_requests as f64;
            }
            Err(_) => {
                stats.failed_requests += 1;
            }
        }

        result
    }

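    /// One-shot convenience helper: sends `question` as a single user
    /// message using the configured (or provider default) model and returns
    /// the assistant's reply text.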
    pub async fn ask(&self, question: &str) -> LLMResult<String> {
        let model = self
            .config
            .default_model
            .clone()
            .unwrap_or_else(|| self.provider.default_model().to_string());

        let request = ChatCompletionRequest::new(model)
            .user(question)
            .temperature(self.config.default_temperature.unwrap_or(0.7))
            .max_tokens(self.config.default_max_tokens.unwrap_or(4096));

        let response = self.chat(request).await?;

        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| LLMError::Other("No content in response".to_string()))
    }
}

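// Lifecycle mapping: `load` applies config overrides, `init_plugin` runs a
// provider health check, `start`/`stop` toggle Running/Paused, and `unload`
// resets the plugin to Unloaded.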
#[async_trait::async_trait]
impl AgentPlugin for LLMPlugin {
    fn metadata(&self) -> &PluginMetadata {
        &self.metadata
    }

    fn state(&self) -> PluginState {
        self.state.clone()
    }

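    /// Applies `api_key`, `base_url`, and `model` overrides from the plugin
    /// context configuration before marking the plugin as loaded.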
    async fn load(&mut self, ctx: &PluginContext) -> PluginResult<()> {
        self.state = PluginState::Loading;

        if let Some(api_key) = ctx.config.get_string("api_key") {
            self.config.api_key = Some(api_key);
        }
        if let Some(base_url) = ctx.config.get_string("base_url") {
            self.config.base_url = Some(base_url);
        }
        if let Some(model) = ctx.config.get_string("model") {
            self.config.default_model = Some(model);
        }

        self.state = PluginState::Loaded;
        Ok(())
    }

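    /// Verifies the provider is reachable; on failure the plugin state is
    /// set to `Error` and the error is propagated.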
    async fn init_plugin(&mut self) -> PluginResult<()> {
        self.provider.health_check().await.map_err(|e| {
            self.state = PluginState::Error(e.to_string());
            anyhow::anyhow!("LLM health check failed: {}", e)
        })?;

        Ok(())
    }

    async fn start(&mut self) -> PluginResult<()> {
        self.state = PluginState::Running;
        Ok(())
    }

    async fn stop(&mut self) -> PluginResult<()> {
        self.state = PluginState::Paused;
        Ok(())
    }

    async fn unload(&mut self) -> PluginResult<()> {
        self.state = PluginState::Unloaded;
        Ok(())
    }

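    /// Generic plugin entry point: treats `input` as a user question and
    /// returns the model's reply.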
    async fn execute(&mut self, input: String) -> PluginResult<String> {
        self.ask(&input)
            .await
            .map_err(|e| anyhow::anyhow!("LLM execution failed: {}", e))
    }

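    /// Snapshot of the usage counters. Uses `try_read` so this synchronous
    /// method never blocks; if the lock is contended, an empty map is
    /// returned instead.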
    fn stats(&self) -> HashMap<String, serde_json::Value> {
        let stats = match self.stats.try_read() {
            Ok(s) => s,
            Err(_) => return HashMap::new(),
        };

        let mut result = HashMap::new();
        result.insert(
            "total_requests".to_string(),
            serde_json::json!(stats.total_requests),
        );
        result.insert(
            "total_tokens".to_string(),
            serde_json::json!(stats.total_tokens),
        );
        result.insert(
            "total_prompt_tokens".to_string(),
            serde_json::json!(stats.total_prompt_tokens),
        );
        result.insert(
            "total_completion_tokens".to_string(),
            serde_json::json!(stats.total_completion_tokens),
        );
        result.insert(
            "failed_requests".to_string(),
            serde_json::json!(stats.failed_requests),
        );
        result.insert(
            "avg_latency_ms".to_string(),
            serde_json::json!(stats.avg_latency_ms),
        );
        result
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}

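/// Capability mixin for agents that may hold an LLM provider. The default
/// methods offer one-shot Q&A helpers and raw chat access on top of the
/// `llm_provider()` accessor.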
#[async_trait::async_trait]
pub trait LLMCapability: Send + Sync {
    fn llm_provider(&self) -> Option<&Arc<dyn LLMProvider>>;

    async fn llm_ask(&self, question: &str) -> LLMResult<String> {
        let provider = self
            .llm_provider()
            .ok_or_else(|| LLMError::ConfigError("LLM provider not configured".to_string()))?;

        let request = ChatCompletionRequest::new(provider.default_model()).user(question);

        let response = provider.chat(request).await?;

        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| LLMError::Other("No content in response".to_string()))
    }

    async fn llm_ask_with_system(&self, system: &str, question: &str) -> LLMResult<String> {
        let provider = self
            .llm_provider()
            .ok_or_else(|| LLMError::ConfigError("LLM provider not configured".to_string()))?;

        let request = ChatCompletionRequest::new(provider.default_model())
            .system(system)
            .user(question);

        let response = provider.chat(request).await?;

        response
            .content()
            .map(|s| s.to_string())
            .ok_or_else(|| LLMError::Other("No content in response".to_string()))
    }

    async fn llm_chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let provider = self
            .llm_provider()
            .ok_or_else(|| LLMError::ConfigError("LLM provider not configured".to_string()))?;

        provider.chat(request).await
    }
}

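/// In-memory [`LLMProvider`] for tests. Queued responses are returned in
/// FIFO order; when the queue is empty, a fixed default response is used.
///
/// A minimal sketch of scripting replies in a test:
///
/// ```ignore
/// let mock = MockLLMProvider::new("mock").with_default_response("pong");
/// mock.add_response("first reply").await;
/// ```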
pub struct MockLLMProvider {
    name: String,
    responses: RwLock<Vec<String>>,
    default_response: String,
}

impl MockLLMProvider {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            responses: RwLock::new(Vec::new()),
            default_response: "This is a mock response.".to_string(),
        }
    }

    pub fn with_default_response(mut self, response: impl Into<String>) -> Self {
        self.default_response = response.into();
        self
    }

    pub async fn add_response(&self, response: impl Into<String>) {
        let mut responses = self.responses.write().await;
        responses.push(response.into());
    }
}

#[async_trait::async_trait]
impl LLMProvider for MockLLMProvider {
    fn name(&self) -> &str {
        &self.name
    }

    fn default_model(&self) -> &str {
        "mock-model"
    }

    fn supported_models(&self) -> Vec<&str> {
        vec!["mock-model", "mock-model-large"]
    }

    fn supports_streaming(&self) -> bool {
        false
    }

    fn supports_tools(&self) -> bool {
        true
    }

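    /// Pops the next scripted response (or falls back to the default) and
    /// wraps it in an OpenAI-style completion with fixed token usage.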
    async fn chat(&self, _request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let content = {
            let mut responses = self.responses.write().await;
            if responses.is_empty() {
                self.default_response.clone()
            } else {
                responses.remove(0)
            }
        };

        Ok(ChatCompletionResponse {
            id: format!("mock-{}", uuid::Uuid::now_v7()),
            object: "chat.completion".to_string(),
            created: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            model: "mock-model".to_string(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage::assistant(content),
                finish_reason: Some(FinishReason::Stop),
                logprobs: None,
            }],
            usage: Some(Usage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
            }),
            system_fingerprint: None,
        })
    }
}