Skip to main content

mofa_foundation/persistence/
plugin.rs

1//! 持久化插件
2//!
3//! 提供与 LLMAgent 集成的持久化功能
4
5use super::entities::*;
6use super::traits::*;
7use crate::llm::types::LLMResponseMetadata;
8use crate::llm::{LLMError, LLMResult};
9use mofa_kernel::plugin::{
10    AgentPlugin, PluginContext, PluginMetadata, PluginResult, PluginState, PluginType,
11};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use tracing::{debug, info};
16use uuid::Uuid;
17
18/// 持久化上下文
19///
20/// 提供对持久化功能的便捷访问
21pub struct PersistenceContext<S>
22where
23    S: MessageStore + ApiCallStore + SessionStore + Send + Sync + 'static,
24{
25    store: Arc<S>,
26    user_id: Uuid,
27    agent_id: Uuid,
28    tenant_id: Uuid,
29    session_id: Uuid,
30}
31
32impl<S> PersistenceContext<S>
33where
34    S: MessageStore + ApiCallStore + SessionStore + Send + Sync + 'static,
35{
36    /// 创建新的持久化上下文
37    pub async fn new(
38        store: Arc<S>,
39        user_id: Uuid,
40        tenant_id: Uuid,
41        agent_id: Uuid,
42    ) -> LLMResult<Self> {
43        let session = ChatSession::new(user_id, agent_id);
44        store
45            .create_session(&session)
46            .await
47            .map_err(|e| LLMError::Other(e.to_string()))?;
48
49        Ok(Self {
50            store,
51            user_id,
52            agent_id,
53            tenant_id,
54            session_id: session.id,
55        })
56    }
57
58    /// 从现有会话创建上下文
59    pub fn from_session(
60        store: Arc<S>,
61        user_id: Uuid,
62        agent_id: Uuid,
63        tenant_id: Uuid,
64        session_id: Uuid,
65    ) -> Self {
66        Self {
67            store,
68            user_id,
69            agent_id,
70            tenant_id,
71            session_id,
72        }
73    }
74
75    /// 获取会话 ID
76    pub fn session_id(&self) -> Uuid {
77        self.session_id
78    }
79
80    /// 保存用户消息
81    pub async fn save_user_message(&self, content: impl Into<String>) -> LLMResult<Uuid> {
82        let message = LLMMessage::new(
83            self.session_id,
84            self.agent_id,
85            self.user_id,
86            self.tenant_id,
87            MessageRole::User,
88            MessageContent::text(content),
89        );
90        let id = message.id;
91
92        self.store
93            .save_message(&message)
94            .await
95            .map_err(|e| LLMError::Other(e.to_string()))?;
96
97        Ok(id)
98    }
99
100    /// 保存助手消息
101    pub async fn save_assistant_message(&self, content: impl Into<String>) -> LLMResult<Uuid> {
102        let message = LLMMessage::new(
103            self.session_id,
104            self.agent_id,
105            self.user_id,
106            self.tenant_id,
107            MessageRole::Assistant,
108            MessageContent::text(content),
109        );
110        let id = message.id;
111
112        self.store
113            .save_message(&message)
114            .await
115            .map_err(|e| LLMError::Other(e.to_string()))?;
116
117        Ok(id)
118    }
119
120    /// 获取会话消息历史
121    pub async fn get_history(&self) -> LLMResult<Vec<LLMMessage>> {
122        self.store
123            .get_session_messages(self.session_id)
124            .await
125            .map_err(|e| LLMError::Other(e.to_string()))
126    }
127
128    /// 获取使用统计
129    pub async fn get_usage_stats(&self) -> LLMResult<UsageStatistics> {
130        let filter = QueryFilter::new().session(self.session_id);
131        self.store
132            .get_statistics(&filter)
133            .await
134            .map_err(|e| LLMError::Other(e.to_string()))
135    }
136
137    /// 创建新会话
138    pub async fn new_session(&mut self) -> LLMResult<Uuid> {
139        let session = ChatSession::new(self.user_id, self.agent_id);
140        self.store
141            .create_session(&session)
142            .await
143            .map_err(|e| LLMError::Other(e.to_string()))?;
144
145        self.session_id = session.id;
146        Ok(session.id)
147    }
148
149    /// 获取存储引用
150    pub fn store(&self) -> Arc<S> {
151        self.store.clone()
152    }
153}
154
155// ============================================================================
156// PersistencePlugin - 实现 AgentPlugin trait
157// ============================================================================
158
159/// 持久化插件
160///
161/// 实现 AgentPlugin trait,提供完整的持久化能力:
162/// - 从数据库加载会话历史
163/// - 自动记录用户消息、助手消息、API 调用
164///
165/// # 示例
166///
167/// ```rust,ignore
168/// use mofa_foundation::persistence::{PersistencePlugin, PostgresStore};
169/// use mofa_sdk::llm::LLMAgentBuilder;
170/// use uuid::Uuid;
171///
172/// # async fn example() -> anyhow::Result<()> {
173/// let store = PostgresStore::connect("postgres://localhost/mofa").await?;
174/// let user_id = Uuid::now_v7();
175/// let tenant_id = Uuid::now_v7();
176/// let agent_id = Uuid::now_v7();
177/// let session_id = Uuid::now_v7();
178///
179/// let plugin = PersistencePlugin::from_store(
180///     "persistence-plugin",
181///     store,
182///     user_id,
183///     tenant_id,
184///     agent_id,
185///     session_id,
186/// );
187///
188/// let agent = LLMAgentBuilder::new()
189///     .with_plugin(plugin)
190///     .build_async()
191///     .await;
192/// # Ok(())
193/// # }
194/// ```
195pub struct PersistencePlugin {
196    metadata: PluginMetadata,
197    state: PluginState,
198    message_store: Arc<dyn MessageStore + Send + Sync>,
199    api_call_store: Arc<dyn ApiCallStore + Send + Sync>,
200    session_store: Option<Arc<dyn SessionStore + Send + Sync>>,
201    user_id: Uuid,
202    tenant_id: Uuid,
203    agent_id: Uuid,
204    session_id: Arc<RwLock<Uuid>>,
205    current_user_msg_id: Arc<RwLock<Option<Uuid>>>,
206    request_start_time: Arc<RwLock<Option<std::time::Instant>>>,
207    response_id: Arc<RwLock<Option<String>>>,
208    current_model: Arc<RwLock<Option<String>>>,
209}
210
211impl PersistencePlugin {
212    /// 创建持久化插件
213    ///
214    /// # 参数
215    /// - `plugin_id`: 插件唯一标识
216    /// - `message_store`: 消息存储后端
217    /// - `api_call_store`: API 调用存储后端
218    /// - `user_id`: 用户 ID
219    /// - `tenant_id`: 租户 ID
220    /// - `agent_id`: Agent ID
221    /// - `session_id`: 会话 ID
222    pub fn new(
223        plugin_id: &str,
224        message_store: Arc<dyn MessageStore + Send + Sync>,
225        api_call_store: Arc<dyn ApiCallStore + Send + Sync>,
226        user_id: Uuid,
227        tenant_id: Uuid,
228        agent_id: Uuid,
229        session_id: Uuid,
230    ) -> Self {
231        let metadata = PluginMetadata::new(plugin_id, "Persistence Plugin", PluginType::Storage)
232            .with_description("Message and API call persistence plugin")
233            .with_capability("message_persistence")
234            .with_capability("api_call_logging")
235            .with_capability("session_history");
236
237        Self {
238            metadata,
239            state: PluginState::Loaded,
240            message_store,
241            api_call_store,
242            session_store: None,
243            user_id,
244            tenant_id,
245            agent_id,
246            session_id: Arc::new(RwLock::new(session_id)),
247            current_user_msg_id: Arc::new(RwLock::new(None)),
248            request_start_time: Arc::new(RwLock::new(None)),
249            response_id: Arc::new(RwLock::new(None)),
250            current_model: Arc::new(RwLock::new(None)),
251        }
252    }
253
254    /// 创建持久化插件(便捷方法,使用单个存储后端)
255    ///
256    /// # 参数
257    /// - `plugin_id`: 插件唯一标识
258    /// - `store`: 持久化存储后端(需要同时实现 MessageStore、ApiCallStore、SessionStore)
259    /// - `user_id`: 用户 ID
260    /// - `tenant_id`: 租户 ID
261    /// - `agent_id`: Agent ID
262    /// - `session_id`: 会话 ID
263    pub fn from_store<S>(
264        plugin_id: &str,
265        store: S,
266        user_id: Uuid,
267        tenant_id: Uuid,
268        agent_id: Uuid,
269        session_id: Uuid,
270    ) -> Self
271    where
272        S: MessageStore + ApiCallStore + SessionStore + Send + Sync + 'static,
273    {
274        let store_arc = Arc::new(store);
275        let session_store: Arc<dyn SessionStore + Send + Sync> = store_arc.clone();
276        let mut plugin = Self::new(
277            plugin_id,
278            store_arc.clone(),
279            store_arc,
280            user_id,
281            tenant_id,
282            agent_id,
283            session_id,
284        );
285        plugin.session_store = Some(session_store);
286        plugin
287    }
288
289    /// 更新会话 ID
290    pub async fn with_session_id(&self, session_id: Uuid) {
291        *self.session_id.write().await = session_id;
292    }
293
294    /// 获取当前会话 ID
295    pub async fn session_id(&self) -> Uuid {
296        *self.session_id.read().await
297    }
298
299    /// 获取历史消息(用于 build_async)
300    pub async fn load_history(&self) -> PersistenceResult<Vec<LLMMessage>> {
301        self.message_store
302            .get_session_messages(*self.session_id.read().await)
303            .await
304    }
305
306    /// 获取消息存储引用
307    pub fn message_store(&self) -> Arc<dyn MessageStore + Send + Sync> {
308        self.message_store.clone()
309    }
310
311    /// 获取 API 调用存储引用
312    pub fn api_call_store(&self) -> Arc<dyn ApiCallStore + Send + Sync> {
313        self.api_call_store.clone()
314    }
315
316    /// 获取会话存储引用
317    pub fn session_store(&self) -> Option<Arc<dyn SessionStore + Send + Sync>> {
318        self.session_store.clone()
319    }
320
321    /// 获取用户 ID
322    pub fn user_id(&self) -> Uuid {
323        self.user_id
324    }
325
326    /// 获取租户 ID
327    pub fn tenant_id(&self) -> Uuid {
328        self.tenant_id
329    }
330
331    /// 获取 Agent ID
332    pub fn agent_id(&self) -> Uuid {
333        self.agent_id
334    }
335
336    /// 保存消息(内部方法)
337    async fn save_message_internal(&self, role: MessageRole, content: &str) -> LLMResult<Uuid> {
338        let session_id = *self.session_id.read().await;
339        let message = LLMMessage::new(
340            session_id,
341            self.agent_id,
342            self.user_id,
343            self.tenant_id,
344            role,
345            MessageContent::text(content),
346        );
347        let id = message.id;
348
349        self.message_store
350            .save_message(&message)
351            .await
352            .map_err(|e| LLMError::Other(e.to_string()))?;
353
354        Ok(id)
355    }
356
357    /// 保存用户消息
358    pub async fn save_user_message(&self, content: &str) -> LLMResult<Uuid> {
359        self.save_message_internal(MessageRole::User, content).await
360    }
361
362    /// 保存助手消息
363    pub async fn save_assistant_message(&self, content: &str) -> LLMResult<Uuid> {
364        self.save_message_internal(MessageRole::Assistant, content)
365            .await
366    }
367}
368
369impl Clone for PersistencePlugin {
370    fn clone(&self) -> Self {
371        Self {
372            metadata: self.metadata.clone(),
373            state: self.state.clone(),
374            message_store: self.message_store.clone(),
375            api_call_store: self.api_call_store.clone(),
376            session_store: self.session_store.clone(),
377            user_id: self.user_id,
378            tenant_id: self.tenant_id,
379            agent_id: self.agent_id,
380            session_id: self.session_id.clone(),
381            current_user_msg_id: self.current_user_msg_id.clone(),
382            request_start_time: self.request_start_time.clone(),
383            response_id: self.response_id.clone(),
384            current_model: self.current_model.clone(),
385        }
386    }
387}
388
389#[async_trait::async_trait]
390impl AgentPlugin for PersistencePlugin {
391    fn metadata(&self) -> &PluginMetadata {
392        &self.metadata
393    }
394
395    fn state(&self) -> PluginState {
396        self.state.clone()
397    }
398
399    async fn load(&mut self, _ctx: &PluginContext) -> PluginResult<()> {
400        self.state = PluginState::Loaded;
401        Ok(())
402    }
403
404    async fn init_plugin(&mut self) -> PluginResult<()> {
405        self.state = PluginState::Running;
406        Ok(())
407    }
408
409    async fn start(&mut self) -> PluginResult<()> {
410        self.state = PluginState::Running;
411        Ok(())
412    }
413
414    async fn stop(&mut self) -> PluginResult<()> {
415        self.state = PluginState::Unloaded;
416        Ok(())
417    }
418
419    async fn unload(&mut self) -> PluginResult<()> {
420        self.state = PluginState::Unloaded;
421        Ok(())
422    }
423
424    async fn execute(&mut self, _input: String) -> PluginResult<String> {
425        Ok("persistence plugin".to_string())
426    }
427
428    fn stats(&self) -> HashMap<String, serde_json::Value> {
429        let mut stats = HashMap::new();
430        stats.insert(
431            "plugin_type".to_string(),
432            serde_json::Value::String("persistence".to_string()),
433        );
434        stats.insert(
435            "user_id".to_string(),
436            serde_json::Value::String(self.user_id.to_string()),
437        );
438        stats.insert(
439            "tenant_id".to_string(),
440            serde_json::Value::String(self.tenant_id.to_string()),
441        );
442        stats.insert(
443            "agent_id".to_string(),
444            serde_json::Value::String(self.agent_id.to_string()),
445        );
446        stats
447    }
448
449    fn as_any(&self) -> &dyn std::any::Any {
450        self
451    }
452
453    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
454        self
455    }
456
457    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any> {
458        self
459    }
460}
461
462// 实现 LLMAgentEventHandler trait
463#[async_trait::async_trait]
464impl crate::llm::agent::LLMAgentEventHandler for PersistencePlugin {
465    fn clone_box(&self) -> Box<dyn crate::llm::agent::LLMAgentEventHandler> {
466        // 由于 PersistencePlugin 需要 Arc<S>,我们创建一个新的克隆实例
467        Box::new(self.clone())
468    }
469
470    fn as_any(&self) -> &dyn std::any::Any {
471        self
472    }
473
474    /// 在发送用户消息前调用 - 记录用户消息
475    async fn before_chat(&self, message: &str) -> LLMResult<Option<String>> {
476        // 记录请求开始时间
477        *self.request_start_time.write().await = Some(std::time::Instant::now());
478
479        // 保存用户消息
480        let user_msg_id = self.save_user_message(message).await?;
481        info!("✅ [持久化插件] 用户消息已保存: ID = {}", user_msg_id);
482
483        // 存储当前用户消息 ID,用于后续关联 API 调用
484        *self.current_user_msg_id.write().await = Some(user_msg_id);
485
486        Ok(Some(message.to_string()))
487    }
488
489    /// 在发送用户消息前调用(带模型名称)- 记录用户消息和模型
490    async fn before_chat_with_model(
491        &self,
492        message: &str,
493        model: &str,
494    ) -> LLMResult<Option<String>> {
495        // 存储模型名称,用于后续的 after_chat 和 on_error
496        *self.current_model.write().await = Some(model.to_string());
497
498        // 调用原有的 before_chat 逻辑
499        self.before_chat(message).await
500    }
501
502    /// 在收到 LLM 响应后调用 - 记录助手消息和 API 调用
503    async fn after_chat(&self, response: &str) -> LLMResult<Option<String>> {
504        // 保存助手消息
505        let assistant_msg_id = self.save_assistant_message(response).await?;
506        info!("✅ [持久化插件] 助手消息已保存: ID = {}", assistant_msg_id);
507
508        // 计算请求延迟
509        let latency = match *self.request_start_time.read().await {
510            Some(start) => start.elapsed().as_millis() as i32,
511            None => 0,
512        };
513
514        // 获取存储的模型名称,或使用默认值
515        let model = self.current_model.read().await;
516        let model_name = model.as_ref().map(|s| s.as_str()).unwrap_or("unknown");
517
518        // 记录 API 调用
519        if let Some(user_msg_id) = *self.current_user_msg_id.read().await {
520            let session_id = *self.session_id.read().await;
521            let now = chrono::Utc::now();
522            let request_time = now - chrono::Duration::milliseconds(latency as i64);
523
524            let api_call = LLMApiCall::success(
525                session_id,
526                self.agent_id,
527                self.user_id,
528                self.tenant_id,
529                user_msg_id,
530                assistant_msg_id,
531                model_name,
532                0,                         // 未知(没有元数据时无法获取真实值)
533                response.len() as i32 / 4, // 简单估算 completion_tokens (每4字符一个token)
534                request_time,
535                now,
536            );
537
538            let _ = self
539                .api_call_store
540                .save_api_call(&api_call)
541                .await
542                .map_err(|e| LLMError::Other(e.to_string()));
543            info!(
544                "✅ [持久化插件] API 调用记录已保存: 模型={}, 延迟={}ms",
545                model_name, latency
546            );
547        }
548
549        // 清理状态
550        *self.current_user_msg_id.write().await = None;
551        *self.request_start_time.write().await = None;
552        *self.current_model.write().await = None;
553
554        Ok(Some(response.to_string()))
555    }
556
557    /// 在收到 LLM 响应后调用 - 记录助手消息和 API 调用(带元数据)
558    async fn after_chat_with_metadata(
559        &self,
560        response: &str,
561        metadata: &LLMResponseMetadata,
562    ) -> LLMResult<Option<String>> {
563        // 保存 response_id
564        *self.response_id.write().await = Some(metadata.id.clone());
565
566        // 保存助手消息
567        let assistant_msg_id = self.save_assistant_message(response).await?;
568        info!("✅ [持久化插件] 助手消息已保存: ID = {}", assistant_msg_id);
569
570        // 计算请求延迟
571        let latency = match *self.request_start_time.read().await {
572            Some(start) => start.elapsed().as_millis() as i32,
573            None => 0,
574        };
575
576        // 记录 API 调用
577        if let Some(user_msg_id) = *self.current_user_msg_id.read().await {
578            let session_id = *self.session_id.read().await;
579            let now = chrono::Utc::now();
580            let request_time = now - chrono::Duration::milliseconds(latency as i64);
581
582            let mut api_call = LLMApiCall::success(
583                session_id,
584                self.agent_id,
585                self.user_id,
586                self.tenant_id,
587                user_msg_id,
588                assistant_msg_id,
589                &metadata.model,
590                metadata.prompt_tokens as i32,
591                metadata.completion_tokens as i32,
592                request_time,
593                now,
594            );
595
596            // 设置 response_id
597            api_call = api_call.with_api_response_id(&metadata.id);
598
599            let _ = self
600                .api_call_store
601                .save_api_call(&api_call)
602                .await
603                .map_err(|e| LLMError::Other(e.to_string()));
604            info!(
605                "✅ [持久化插件] API 调用记录已保存: 模型={}, tokens={}/{}, 延迟={}ms",
606                metadata.model, metadata.prompt_tokens, metadata.completion_tokens, latency
607            );
608        }
609
610        // 清理状态
611        *self.current_user_msg_id.write().await = None;
612        *self.request_start_time.write().await = None;
613        *self.response_id.write().await = None;
614
615        Ok(Some(response.to_string()))
616    }
617
618    /// 在发生错误时调用 - 记录 API 错误
619    async fn on_error(&self, error: &LLMError) -> LLMResult<Option<String>> {
620        info!("✅ [持久化插件] 记录 API 错误...");
621
622        // 获取存储的模型名称,或使用默认值
623        let model = self.current_model.read().await;
624        let model_name = model.as_ref().map(|s| s.as_str()).unwrap_or("unknown");
625
626        if let Some(user_msg_id) = *self.current_user_msg_id.read().await {
627            let session_id = *self.session_id.read().await;
628            let now = chrono::Utc::now();
629
630            let api_call = LLMApiCall::failed(
631                session_id,
632                self.agent_id,
633                self.user_id,
634                self.tenant_id,
635                user_msg_id,
636                model_name,
637                error.to_string(),
638                None,
639                now,
640            );
641
642            let _ = self
643                .api_call_store
644                .save_api_call(&api_call)
645                .await
646                .map_err(|e| LLMError::Other(e.to_string()));
647            info!("✅ [持久化插件] API 错误记录已保存");
648        }
649
650        // 清理状态
651        *self.current_user_msg_id.write().await = None;
652        *self.request_start_time.write().await = None;
653        *self.current_model.write().await = None;
654
655        Ok(None)
656    }
657}