mofa_foundation/persistence/
traits.rs1use super::entities::*;
6use async_trait::async_trait;
7use std::sync::Arc;
8use uuid::Uuid;
9
10#[derive(Debug, thiserror::Error)]
12pub enum PersistenceError {
13 #[error("Connection error: {0}")]
15 Connection(String),
16 #[error("Query error: {0}")]
18 Query(String),
19 #[error("Serialization error: {0}")]
21 Serialization(String),
22 #[error("Record not found: {0}")]
24 NotFound(String),
25 #[error("Constraint violation: {0}")]
27 Constraint(String),
28 #[error("Persistence error: {0}")]
30 Other(String),
31}
32
33pub type PersistenceResult<T> = Result<T, PersistenceError>;
35
36#[async_trait]
40pub trait MessageStore: Send + Sync {
41 async fn save_message(&self, message: &LLMMessage) -> PersistenceResult<()>;
43
44 async fn save_messages(&self, messages: &[LLMMessage]) -> PersistenceResult<()> {
46 for msg in messages {
47 self.save_message(msg).await?;
48 }
49 Ok(())
50 }
51
52 async fn get_message(&self, id: Uuid) -> PersistenceResult<Option<LLMMessage>>;
54
55 async fn get_session_messages(&self, session_id: Uuid) -> PersistenceResult<Vec<LLMMessage>>;
57
58 async fn get_session_messages_paginated(
60 &self,
61 session_id: Uuid,
62 offset: i64,
63 limit: i64,
64 ) -> PersistenceResult<Vec<LLMMessage>>;
65
66 async fn delete_message(&self, id: Uuid) -> PersistenceResult<bool>;
68
69 async fn delete_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64>;
71
72 async fn count_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64>;
74}
75
76#[async_trait]
80pub trait ApiCallStore: Send + Sync {
81 async fn save_api_call(&self, call: &LLMApiCall) -> PersistenceResult<()>;
83
84 async fn save_api_calls(&self, calls: &[LLMApiCall]) -> PersistenceResult<()> {
86 for call in calls {
87 self.save_api_call(call).await?;
88 }
89 Ok(())
90 }
91
92 async fn get_api_call(&self, id: Uuid) -> PersistenceResult<Option<LLMApiCall>>;
94
95 async fn query_api_calls(&self, filter: &QueryFilter) -> PersistenceResult<Vec<LLMApiCall>>;
97
98 async fn get_statistics(&self, filter: &QueryFilter) -> PersistenceResult<UsageStatistics>;
100
101 async fn delete_api_call(&self, id: Uuid) -> PersistenceResult<bool>;
103
104 async fn cleanup_old_records(
106 &self,
107 before: chrono::DateTime<chrono::Utc>,
108 ) -> PersistenceResult<i64>;
109}
110
111#[async_trait]
115pub trait SessionStore: Send + Sync {
116 async fn create_session(&self, session: &ChatSession) -> PersistenceResult<()>;
118
119 async fn get_session(&self, id: Uuid) -> PersistenceResult<Option<ChatSession>>;
121
122 async fn get_user_sessions(&self, user_id: Uuid) -> PersistenceResult<Vec<ChatSession>>;
124
125 async fn update_session(&self, session: &ChatSession) -> PersistenceResult<()>;
127
128 async fn delete_session(&self, id: Uuid) -> PersistenceResult<bool>;
130}
131
132#[async_trait]
136pub trait ProviderStore: Send + Sync {
137 async fn get_provider(&self, id: Uuid) -> PersistenceResult<Option<super::entities::Provider>>;
139
140 async fn get_provider_by_name(
142 &self,
143 tenant_id: Uuid,
144 name: &str,
145 ) -> PersistenceResult<Option<super::entities::Provider>>;
146
147 async fn list_providers(
149 &self,
150 tenant_id: Uuid,
151 ) -> PersistenceResult<Vec<super::entities::Provider>>;
152
153 async fn get_enabled_providers(
155 &self,
156 tenant_id: Uuid,
157 ) -> PersistenceResult<Vec<super::entities::Provider>>;
158}
159
160#[async_trait]
164pub trait AgentStore: Send + Sync {
165 async fn get_agent(&self, id: Uuid) -> PersistenceResult<Option<super::entities::Agent>>;
167
168 async fn get_agent_by_code(
170 &self,
171 code: &str,
172 ) -> PersistenceResult<Option<super::entities::Agent>>;
173
174 async fn get_agent_by_code_and_tenant(
176 &self,
177 tenant_id: Uuid,
178 code: &str,
179 ) -> PersistenceResult<Option<super::entities::Agent>>;
180
181 async fn list_agents(&self, tenant_id: Uuid) -> PersistenceResult<Vec<super::entities::Agent>>;
183
184 async fn get_active_agents(
186 &self,
187 tenant_id: Uuid,
188 ) -> PersistenceResult<Vec<super::entities::Agent>>;
189
190 async fn get_agent_with_provider(
192 &self,
193 id: Uuid,
194 ) -> PersistenceResult<Option<super::entities::AgentConfig>>;
195
196 async fn get_agent_by_code_with_provider(
198 &self,
199 code: &str,
200 ) -> PersistenceResult<Option<super::entities::AgentConfig>>;
201
202 async fn get_agent_by_code_and_tenant_with_provider(
204 &self,
205 tenant_id: Uuid,
206 code: &str,
207 ) -> PersistenceResult<Option<super::entities::AgentConfig>>;
208}
209
210pub trait PersistenceStore:
214 MessageStore + ApiCallStore + SessionStore + ProviderStore + AgentStore
215{
216 fn backend_name(&self) -> &str;
218
219 fn is_connected(&self) -> bool;
221
222 fn close(&self) -> impl std::future::Future<Output = PersistenceResult<()>> + Send;
224}
225
226#[async_trait]
230pub trait StoreFactory: Send + Sync {
231 type Store: PersistenceStore;
233
234 async fn create(&self, config: &str) -> PersistenceResult<Self::Store>;
236}
237
238#[async_trait]
240pub trait Transactional: Send + Sync {
241 type Transaction<'a>: Send + Sync
243 where
244 Self: 'a;
245
246 async fn begin_transaction(&self) -> PersistenceResult<Self::Transaction<'_>>;
248
249 async fn commit_transaction(&self, tx: Self::Transaction<'_>) -> PersistenceResult<()>;
251
252 async fn rollback_transaction(&self, tx: Self::Transaction<'_>) -> PersistenceResult<()>;
254}
255
256pub type SharedStore<S> = Arc<S>;
260
261pub type DynMessageStore = Arc<dyn MessageStore>;
263pub type DynApiCallStore = Arc<dyn ApiCallStore>;
264pub type DynSessionStore = Arc<dyn SessionStore>;
265
/// Groups a message store, an API-call store and a session store.
///
/// Each store is reachable directly through its public field. `Debug` and
/// `Clone` are derived; like every derive on a generic struct, they are only
/// implemented when `M`, `A` and `S` implement them as well.
#[derive(Debug, Clone)]
pub struct CompositeStore<M, A, S> {
    /// Store used for chat messages.
    pub message_store: M,
    /// Store used for LLM API-call records.
    pub api_call_store: A,
    /// Store used for chat sessions.
    pub session_store: S,
}
272
273impl<M, A, S> CompositeStore<M, A, S>
274where
275 M: MessageStore,
276 A: ApiCallStore,
277 S: SessionStore,
278{
279 pub fn new(message_store: M, api_call_store: A, session_store: S) -> Self {
280 Self {
281 message_store,
282 api_call_store,
283 session_store,
284 }
285 }
286}
287
288#[derive(Debug, Clone)]
290pub enum StoreEvent {
291 MessageSaved { message_id: Uuid, session_id: Uuid },
293 ApiCallRecorded { call_id: Uuid, session_id: Uuid },
295 SessionCreated { session_id: Uuid },
297 SessionDeleted { session_id: Uuid },
299}
300
301#[async_trait]
303pub trait StoreEventListener: Send + Sync {
304 async fn on_event(&self, event: StoreEvent);
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the complete rendered message of every variant, so that a change to
    /// any `#[error(...)]` format string is caught, not just `NotFound`.
    #[test]
    fn test_persistence_error_display() {
        let cases = [
            (PersistenceError::Connection("db".to_string()), "Connection error: db"),
            (PersistenceError::Query("sql".to_string()), "Query error: sql"),
            (PersistenceError::Serialization("json".to_string()), "Serialization error: json"),
            (PersistenceError::NotFound("user".to_string()), "Record not found: user"),
            (PersistenceError::Constraint("uniq".to_string()), "Constraint violation: uniq"),
            (PersistenceError::Other("boom".to_string()), "Persistence error: boom"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    /// A default `QueryFilter` has neither `user_id` nor `limit` set.
    #[test]
    fn test_query_filter_default() {
        let filter = QueryFilter::default();
        assert!(filter.user_id.is_none());
        assert!(filter.limit.is_none());
    }
}