// Source file: mofa_foundation/persistence/traits.rs

//! Core persistence traits.
//!
//! Defines the interfaces that every storage backend must implement.

5use super::entities::*;
6use async_trait::async_trait;
7use std::sync::Arc;
8use uuid::Uuid;
9
10/// 持久化错误
11#[derive(Debug, thiserror::Error)]
12pub enum PersistenceError {
13    /// 连接错误
14    #[error("Connection error: {0}")]
15    Connection(String),
16    /// 查询错误
17    #[error("Query error: {0}")]
18    Query(String),
19    /// 序列化错误
20    #[error("Serialization error: {0}")]
21    Serialization(String),
22    /// 记录未找到
23    #[error("Record not found: {0}")]
24    NotFound(String),
25    /// 约束冲突
26    #[error("Constraint violation: {0}")]
27    Constraint(String),
28    /// 其他错误
29    #[error("Persistence error: {0}")]
30    Other(String),
31}
32
33/// 持久化结果类型
34pub type PersistenceResult<T> = Result<T, PersistenceError>;
35
36/// 消息存储 trait
37///
38/// 提供 LLM 消息的 CRUD 操作
39#[async_trait]
40pub trait MessageStore: Send + Sync {
41    /// 保存消息
42    async fn save_message(&self, message: &LLMMessage) -> PersistenceResult<()>;
43
44    /// 批量保存消息
45    async fn save_messages(&self, messages: &[LLMMessage]) -> PersistenceResult<()> {
46        for msg in messages {
47            self.save_message(msg).await?;
48        }
49        Ok(())
50    }
51
52    /// 获取消息
53    async fn get_message(&self, id: Uuid) -> PersistenceResult<Option<LLMMessage>>;
54
55    /// 获取会话消息列表
56    async fn get_session_messages(&self, session_id: Uuid) -> PersistenceResult<Vec<LLMMessage>>;
57
58    /// 获取会话消息列表 (分页)
59    async fn get_session_messages_paginated(
60        &self,
61        session_id: Uuid,
62        offset: i64,
63        limit: i64,
64    ) -> PersistenceResult<Vec<LLMMessage>>;
65
66    /// 删除消息
67    async fn delete_message(&self, id: Uuid) -> PersistenceResult<bool>;
68
69    /// 删除会话所有消息
70    async fn delete_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64>;
71
72    /// 统计会话消息数
73    async fn count_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64>;
74}
75
76/// API 调用记录存储 trait
77///
78/// 提供 LLM API 调用记录的存储和查询
79#[async_trait]
80pub trait ApiCallStore: Send + Sync {
81    /// 保存 API 调用记录
82    async fn save_api_call(&self, call: &LLMApiCall) -> PersistenceResult<()>;
83
84    /// 批量保存 API 调用记录
85    async fn save_api_calls(&self, calls: &[LLMApiCall]) -> PersistenceResult<()> {
86        for call in calls {
87            self.save_api_call(call).await?;
88        }
89        Ok(())
90    }
91
92    /// 获取 API 调用记录
93    async fn get_api_call(&self, id: Uuid) -> PersistenceResult<Option<LLMApiCall>>;
94
95    /// 查询 API 调用记录
96    async fn query_api_calls(&self, filter: &QueryFilter) -> PersistenceResult<Vec<LLMApiCall>>;
97
98    /// 统计 API 调用
99    async fn get_statistics(&self, filter: &QueryFilter) -> PersistenceResult<UsageStatistics>;
100
101    /// 删除 API 调用记录
102    async fn delete_api_call(&self, id: Uuid) -> PersistenceResult<bool>;
103
104    /// 清理旧记录
105    async fn cleanup_old_records(
106        &self,
107        before: chrono::DateTime<chrono::Utc>,
108    ) -> PersistenceResult<i64>;
109}
110
111/// 会话存储 trait
112///
113/// 提供聊天会话的管理
114#[async_trait]
115pub trait SessionStore: Send + Sync {
116    /// 创建会话
117    async fn create_session(&self, session: &ChatSession) -> PersistenceResult<()>;
118
119    /// 获取会话
120    async fn get_session(&self, id: Uuid) -> PersistenceResult<Option<ChatSession>>;
121
122    /// 获取用户会话列表
123    async fn get_user_sessions(&self, user_id: Uuid) -> PersistenceResult<Vec<ChatSession>>;
124
125    /// 更新会话
126    async fn update_session(&self, session: &ChatSession) -> PersistenceResult<()>;
127
128    /// 删除会话
129    async fn delete_session(&self, id: Uuid) -> PersistenceResult<bool>;
130}
131
132/// Provider 存储 trait
133///
134/// 提供 LLM Provider 的数据库操作
135#[async_trait]
136pub trait ProviderStore: Send + Sync {
137    /// 根据 ID 获取 provider
138    async fn get_provider(&self, id: Uuid) -> PersistenceResult<Option<super::entities::Provider>>;
139
140    /// 根据名称和租户 ID 获取 provider
141    async fn get_provider_by_name(
142        &self,
143        tenant_id: Uuid,
144        name: &str,
145    ) -> PersistenceResult<Option<super::entities::Provider>>;
146
147    /// 列出租户的所有 providers
148    async fn list_providers(
149        &self,
150        tenant_id: Uuid,
151    ) -> PersistenceResult<Vec<super::entities::Provider>>;
152
153    /// 获取租户所有启用的 providers
154    async fn get_enabled_providers(
155        &self,
156        tenant_id: Uuid,
157    ) -> PersistenceResult<Vec<super::entities::Provider>>;
158}
159
160/// Agent 存储 trait
161///
162/// 提供 LLM Agent 配置的数据库操作
163#[async_trait]
164pub trait AgentStore: Send + Sync {
165    /// 根据 ID 获取 agent
166    async fn get_agent(&self, id: Uuid) -> PersistenceResult<Option<super::entities::Agent>>;
167
168    /// 根据 code 获取 agent(全局查找)
169    async fn get_agent_by_code(
170        &self,
171        code: &str,
172    ) -> PersistenceResult<Option<super::entities::Agent>>;
173
174    /// 根据 code 和租户 ID 获取 agent
175    async fn get_agent_by_code_and_tenant(
176        &self,
177        tenant_id: Uuid,
178        code: &str,
179    ) -> PersistenceResult<Option<super::entities::Agent>>;
180
181    /// 列出租户的所有 agents
182    async fn list_agents(&self, tenant_id: Uuid) -> PersistenceResult<Vec<super::entities::Agent>>;
183
184    /// 获取租户所有启用的 agents
185    async fn get_active_agents(
186        &self,
187        tenant_id: Uuid,
188    ) -> PersistenceResult<Vec<super::entities::Agent>>;
189
190    /// 根据 ID 获取 agent 及其 provider 配置
191    async fn get_agent_with_provider(
192        &self,
193        id: Uuid,
194    ) -> PersistenceResult<Option<super::entities::AgentConfig>>;
195
196    /// 根据 code 获取 agent 及其 provider 配置(全局查找)
197    async fn get_agent_by_code_with_provider(
198        &self,
199        code: &str,
200    ) -> PersistenceResult<Option<super::entities::AgentConfig>>;
201
202    /// 根据 code 和租户 ID 获取 agent 及其 provider 配置
203    async fn get_agent_by_code_and_tenant_with_provider(
204        &self,
205        tenant_id: Uuid,
206        code: &str,
207    ) -> PersistenceResult<Option<super::entities::AgentConfig>>;
208}
209
210/// 完整的持久化存储 trait
211///
212/// 组合所有存储能力
213pub trait PersistenceStore:
214    MessageStore + ApiCallStore + SessionStore + ProviderStore + AgentStore
215{
216    /// 获取存储后端名称
217    fn backend_name(&self) -> &str;
218
219    /// 检查连接状态
220    fn is_connected(&self) -> bool;
221
222    /// 关闭连接
223    fn close(&self) -> impl std::future::Future<Output = PersistenceResult<()>> + Send;
224}
225
226/// 存储工厂 trait
227///
228/// 用于创建存储实例
229#[async_trait]
230pub trait StoreFactory: Send + Sync {
231    /// 存储类型
232    type Store: PersistenceStore;
233
234    /// 创建存储实例
235    async fn create(&self, config: &str) -> PersistenceResult<Self::Store>;
236}
237
238/// 事务支持 trait (可选)
239#[async_trait]
240pub trait Transactional: Send + Sync {
241    /// 事务类型
242    type Transaction<'a>: Send + Sync
243    where
244        Self: 'a;
245
246    /// 开始事务
247    async fn begin_transaction(&self) -> PersistenceResult<Self::Transaction<'_>>;
248
249    /// 提交事务
250    async fn commit_transaction(&self, tx: Self::Transaction<'_>) -> PersistenceResult<()>;
251
252    /// 回滚事务
253    async fn rollback_transaction(&self, tx: Self::Transaction<'_>) -> PersistenceResult<()>;
254}
255
/// Reference-counted handle for one store instance.
///
/// Cloning the `Arc` lets several components share the same store without
/// copying it.
pub type SharedStore<S> = Arc<S>;

261/// 动态分发的存储类型
262pub type DynMessageStore = Arc<dyn MessageStore>;
263pub type DynApiCallStore = Arc<dyn ApiCallStore>;
264pub type DynSessionStore = Arc<dyn SessionStore>;
265
266/// 组合存储包装器
267pub struct CompositeStore<M, A, S> {
268    pub message_store: M,
269    pub api_call_store: A,
270    pub session_store: S,
271}
272
273impl<M, A, S> CompositeStore<M, A, S>
274where
275    M: MessageStore,
276    A: ApiCallStore,
277    S: SessionStore,
278{
279    pub fn new(message_store: M, api_call_store: A, session_store: S) -> Self {
280        Self {
281            message_store,
282            api_call_store,
283            session_store,
284        }
285    }
286}
287
288/// 存储事件
289#[derive(Debug, Clone)]
290pub enum StoreEvent {
291    /// 消息已保存
292    MessageSaved { message_id: Uuid, session_id: Uuid },
293    /// API 调用已记录
294    ApiCallRecorded { call_id: Uuid, session_id: Uuid },
295    /// 会话已创建
296    SessionCreated { session_id: Uuid },
297    /// 会话已删除
298    SessionDeleted { session_id: Uuid },
299}
300
301/// 存储事件监听器
302#[async_trait]
303pub trait StoreEventListener: Send + Sync {
304    /// 处理事件
305    async fn on_event(&self, event: StoreEvent);
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact `Display` text of every `PersistenceError` variant:
    /// the variant's fixed prefix followed by the wrapped detail message.
    #[test]
    fn test_persistence_error_display() {
        let cases = [
            (PersistenceError::Connection("db".to_string()), "Connection error: db"),
            (PersistenceError::Query("bad sql".to_string()), "Query error: bad sql"),
            (PersistenceError::Serialization("json".to_string()), "Serialization error: json"),
            (PersistenceError::NotFound("user".to_string()), "Record not found: user"),
            (PersistenceError::Constraint("unique".to_string()), "Constraint violation: unique"),
            (PersistenceError::Other("boom".to_string()), "Persistence error: boom"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    /// A default `QueryFilter` applies no user filter and no limit.
    #[test]
    fn test_query_filter_default() {
        let filter = QueryFilter::default();
        assert!(filter.user_id.is_none());
        assert!(filter.limit.is_none());
    }
}