mofa_foundation/persistence/
memory.rs1use super::entities::*;
6use super::traits::*;
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicBool, Ordering};
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14pub struct InMemoryStore {
36 messages: Arc<RwLock<HashMap<Uuid, LLMMessage>>>,
38 session_messages: Arc<RwLock<HashMap<Uuid, Vec<Uuid>>>>,
40 api_calls: Arc<RwLock<HashMap<Uuid, LLMApiCall>>>,
42 sessions: Arc<RwLock<HashMap<Uuid, ChatSession>>>,
44 user_sessions: Arc<RwLock<HashMap<Uuid, Vec<Uuid>>>>,
46 connected: AtomicBool,
48}
49
50impl InMemoryStore {
51 pub fn new() -> Self {
53 Self {
54 messages: Arc::new(RwLock::new(HashMap::new())),
55 session_messages: Arc::new(RwLock::new(HashMap::new())),
56 api_calls: Arc::new(RwLock::new(HashMap::new())),
57 sessions: Arc::new(RwLock::new(HashMap::new())),
58 user_sessions: Arc::new(RwLock::new(HashMap::new())),
59 connected: AtomicBool::new(true),
60 }
61 }
62
63 pub fn shared() -> Arc<Self> {
65 Arc::new(Self::new())
66 }
67
68 pub async fn clear(&self) {
70 self.messages.write().await.clear();
71 self.session_messages.write().await.clear();
72 self.api_calls.write().await.clear();
73 self.sessions.write().await.clear();
74 self.user_sessions.write().await.clear();
75 }
76
77 pub async fn message_count(&self) -> usize {
79 self.messages.read().await.len()
80 }
81
82 pub async fn api_call_count(&self) -> usize {
84 self.api_calls.read().await.len()
85 }
86
87 pub async fn session_count(&self) -> usize {
89 self.sessions.read().await.len()
90 }
91}
92
93impl Default for InMemoryStore {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99#[async_trait]
100impl MessageStore for InMemoryStore {
101 async fn save_message(&self, message: &LLMMessage) -> PersistenceResult<()> {
102 let mut messages = self.messages.write().await;
103 let mut session_messages = self.session_messages.write().await;
104
105 messages.insert(message.id, message.clone());
106
107 session_messages
108 .entry(message.chat_session_id)
109 .or_insert_with(Vec::new)
110 .push(message.id);
111
112 Ok(())
113 }
114
115 async fn get_message(&self, id: Uuid) -> PersistenceResult<Option<LLMMessage>> {
116 let messages = self.messages.read().await;
117 Ok(messages.get(&id).cloned())
118 }
119
120 async fn get_session_messages(&self, session_id: Uuid) -> PersistenceResult<Vec<LLMMessage>> {
121 let messages = self.messages.read().await;
122 let session_messages = self.session_messages.read().await;
123
124 let msg_ids = session_messages.get(&session_id);
125
126 let mut result = Vec::new();
127 if let Some(ids) = msg_ids {
128 for id in ids {
129 if let Some(msg) = messages.get(id) {
130 result.push(msg.clone());
131 }
132 }
133 }
134
135 result.sort_by(|a, b| a.create_time.cmp(&b.create_time));
137
138 Ok(result)
139 }
140
141 async fn get_session_messages_paginated(
142 &self,
143 session_id: Uuid,
144 offset: i64,
145 limit: i64,
146 ) -> PersistenceResult<Vec<LLMMessage>> {
147 let all_messages = self.get_session_messages(session_id).await?;
148
149 let start = offset as usize;
150 let end = (offset + limit) as usize;
151
152 Ok(all_messages
153 .into_iter()
154 .skip(start)
155 .take(end - start)
156 .collect())
157 }
158
159 async fn delete_message(&self, id: Uuid) -> PersistenceResult<bool> {
160 let mut messages = self.messages.write().await;
161
162 if let Some(msg) = messages.remove(&id) {
163 let mut session_messages = self.session_messages.write().await;
164 if let Some(ids) = session_messages.get_mut(&msg.chat_session_id) {
165 ids.retain(|&x| x != id);
166 }
167 return Ok(true);
168 }
169
170 Ok(false)
171 }
172
173 async fn delete_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64> {
174 let mut messages = self.messages.write().await;
175 let mut session_messages = self.session_messages.write().await;
176
177 let count = if let Some(ids) = session_messages.remove(&session_id) {
178 let len = ids.len();
179 for id in ids {
180 messages.remove(&id);
181 }
182 len as i64
183 } else {
184 0
185 };
186
187 Ok(count)
188 }
189
190 async fn count_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64> {
191 let session_messages = self.session_messages.read().await;
192 Ok(session_messages
193 .get(&session_id)
194 .map(|ids| ids.len() as i64)
195 .unwrap_or(0))
196 }
197}
198
199#[async_trait]
200impl ApiCallStore for InMemoryStore {
201 async fn save_api_call(&self, call: &LLMApiCall) -> PersistenceResult<()> {
202 let mut api_calls = self.api_calls.write().await;
203 api_calls.insert(call.id, call.clone());
204 Ok(())
205 }
206
207 async fn get_api_call(&self, id: Uuid) -> PersistenceResult<Option<LLMApiCall>> {
208 let api_calls = self.api_calls.read().await;
209 Ok(api_calls.get(&id).cloned())
210 }
211
212 async fn query_api_calls(&self, filter: &QueryFilter) -> PersistenceResult<Vec<LLMApiCall>> {
213 let api_calls = self.api_calls.read().await;
214
215 let mut result: Vec<LLMApiCall> = api_calls
216 .values()
217 .filter(|call| {
218 if let Some(user_id) = filter.user_id
220 && call.user_id != user_id
221 {
222 return false;
223 }
224
225 if let Some(session_id) = filter.session_id
227 && call.chat_session_id != session_id
228 {
229 return false;
230 }
231
232 if let Some(agent_id) = filter.agent_id
234 && call.agent_id != agent_id
235 {
236 return false;
237 }
238
239 if let Some(start) = filter.start_time
241 && call.create_time < start
242 {
243 return false;
244 }
245 if let Some(end) = filter.end_time
246 && call.create_time > end
247 {
248 return false;
249 }
250
251 if let Some(status) = filter.status
253 && call.status != status
254 {
255 return false;
256 }
257
258 if let Some(ref model) = filter.model_name
260 && &call.model_name != model
261 {
262 return false;
263 }
264
265 true
266 })
267 .cloned()
268 .collect();
269
270 result.sort_by(|a, b| b.create_time.cmp(&a.create_time));
272
273 let offset = filter.offset.unwrap_or(0) as usize;
275 let limit = filter.limit.unwrap_or(100) as usize;
276
277 Ok(result.into_iter().skip(offset).take(limit).collect())
278 }
279
280 async fn get_statistics(&self, filter: &QueryFilter) -> PersistenceResult<UsageStatistics> {
281 let calls = self.query_api_calls(filter).await?;
282
283 let total_calls = calls.len() as i64;
284 let success_count = calls
285 .iter()
286 .filter(|c| c.status == ApiCallStatus::Success)
287 .count() as i64;
288 let failed_count = total_calls - success_count;
289
290 let total_prompt_tokens: i64 = calls.iter().map(|c| c.prompt_tokens as i64).sum();
291 let total_completion_tokens: i64 = calls.iter().map(|c| c.completion_tokens as i64).sum();
292 let total_tokens = total_prompt_tokens + total_completion_tokens;
293
294 let total_cost: Option<f64> = {
295 let costs: Vec<f64> = calls.iter().filter_map(|c| c.total_price).collect();
296 if costs.is_empty() {
297 None
298 } else {
299 Some(costs.iter().sum())
300 }
301 };
302
303 let avg_latency_ms: Option<f64> = {
304 let latencies: Vec<i32> = calls.iter().filter_map(|c| c.latency_ms).collect();
305 if latencies.is_empty() {
306 None
307 } else {
308 Some(latencies.iter().sum::<i32>() as f64 / latencies.len() as f64)
309 }
310 };
311
312 let avg_tokens_per_second: Option<f64> = {
313 let tps: Vec<f64> = calls.iter().filter_map(|c| c.tokens_per_second).collect();
314 if tps.is_empty() {
315 None
316 } else {
317 Some(tps.iter().sum::<f64>() / tps.len() as f64)
318 }
319 };
320
321 Ok(UsageStatistics {
322 total_calls,
323 success_count,
324 failed_count,
325 total_tokens,
326 total_prompt_tokens,
327 total_completion_tokens,
328 total_cost,
329 avg_latency_ms,
330 avg_tokens_per_second,
331 })
332 }
333
334 async fn delete_api_call(&self, id: Uuid) -> PersistenceResult<bool> {
335 let mut api_calls = self.api_calls.write().await;
336 Ok(api_calls.remove(&id).is_some())
337 }
338
339 async fn cleanup_old_records(
340 &self,
341 before: chrono::DateTime<chrono::Utc>,
342 ) -> PersistenceResult<i64> {
343 let mut api_calls = self.api_calls.write().await;
344 let old_len = api_calls.len();
345
346 api_calls.retain(|_, call| call.create_time >= before);
347
348 Ok((old_len - api_calls.len()) as i64)
349 }
350}
351
352#[async_trait]
353impl SessionStore for InMemoryStore {
354 async fn create_session(&self, session: &ChatSession) -> PersistenceResult<()> {
355 let mut sessions = self.sessions.write().await;
356 let mut user_sessions = self.user_sessions.write().await;
357
358 sessions.insert(session.id, session.clone());
359
360 user_sessions
361 .entry(session.user_id)
362 .or_insert_with(Vec::new)
363 .push(session.id);
364
365 Ok(())
366 }
367
368 async fn get_session(&self, id: Uuid) -> PersistenceResult<Option<ChatSession>> {
369 let sessions = self.sessions.read().await;
370 Ok(sessions.get(&id).cloned())
371 }
372
373 async fn get_user_sessions(&self, user_id: Uuid) -> PersistenceResult<Vec<ChatSession>> {
374 let sessions = self.sessions.read().await;
375 let user_sessions = self.user_sessions.read().await;
376
377 let session_ids = user_sessions.get(&user_id);
378
379 let mut result = Vec::new();
380 if let Some(ids) = session_ids {
381 for id in ids {
382 if let Some(session) = sessions.get(id) {
383 result.push(session.clone());
384 }
385 }
386 }
387
388 result.sort_by(|a, b| b.update_time.cmp(&a.update_time));
390
391 Ok(result)
392 }
393
394 async fn update_session(&self, session: &ChatSession) -> PersistenceResult<()> {
395 let mut sessions = self.sessions.write().await;
396
397 if let std::collections::hash_map::Entry::Occupied(mut e) = sessions.entry(session.id) {
398 e.insert(session.clone());
399 Ok(())
400 } else {
401 Err(PersistenceError::NotFound(format!(
402 "Session {} not found",
403 session.id
404 )))
405 }
406 }
407
408 async fn delete_session(&self, id: Uuid) -> PersistenceResult<bool> {
409 let mut sessions = self.sessions.write().await;
410 let mut user_sessions = self.user_sessions.write().await;
411
412 if let Some(session) = sessions.remove(&id) {
413 if let Some(ids) = user_sessions.get_mut(&session.user_id) {
414 ids.retain(|&x| x != id);
415 }
416 return Ok(true);
417 }
418
419 Ok(false)
420 }
421}
422
423#[async_trait]
424impl ProviderStore for InMemoryStore {
425 async fn get_provider(
426 &self,
427 _id: Uuid,
428 ) -> PersistenceResult<Option<crate::persistence::entities::Provider>> {
429 Ok(None)
431 }
432
433 async fn get_provider_by_name(
434 &self,
435 _tenant_id: Uuid,
436 _name: &str,
437 ) -> PersistenceResult<Option<crate::persistence::entities::Provider>> {
438 Ok(None)
440 }
441
442 async fn list_providers(
443 &self,
444 _tenant_id: Uuid,
445 ) -> PersistenceResult<Vec<crate::persistence::entities::Provider>> {
446 Ok(Vec::new())
448 }
449
450 async fn get_enabled_providers(
451 &self,
452 _tenant_id: Uuid,
453 ) -> PersistenceResult<Vec<crate::persistence::entities::Provider>> {
454 Ok(Vec::new())
456 }
457}
458
459#[async_trait]
460impl AgentStore for InMemoryStore {
461 async fn get_agent(
462 &self,
463 _id: Uuid,
464 ) -> PersistenceResult<Option<crate::persistence::entities::Agent>> {
465 Ok(None)
467 }
468
469 async fn get_agent_by_code(
470 &self,
471 _code: &str,
472 ) -> PersistenceResult<Option<crate::persistence::entities::Agent>> {
473 Ok(None)
475 }
476
477 async fn get_agent_by_code_and_tenant(
478 &self,
479 _tenant_id: Uuid,
480 _code: &str,
481 ) -> PersistenceResult<Option<crate::persistence::entities::Agent>> {
482 Ok(None)
484 }
485
486 async fn list_agents(
487 &self,
488 _tenant_id: Uuid,
489 ) -> PersistenceResult<Vec<crate::persistence::entities::Agent>> {
490 Ok(Vec::new())
492 }
493
494 async fn get_active_agents(
495 &self,
496 _tenant_id: Uuid,
497 ) -> PersistenceResult<Vec<crate::persistence::entities::Agent>> {
498 Ok(Vec::new())
500 }
501
502 async fn get_agent_with_provider(
503 &self,
504 _id: Uuid,
505 ) -> PersistenceResult<Option<crate::persistence::entities::AgentConfig>> {
506 Ok(None)
508 }
509
510 async fn get_agent_by_code_with_provider(
511 &self,
512 _code: &str,
513 ) -> PersistenceResult<Option<crate::persistence::entities::AgentConfig>> {
514 Ok(None)
516 }
517
518 async fn get_agent_by_code_and_tenant_with_provider(
519 &self,
520 _tenant_id: Uuid,
521 _code: &str,
522 ) -> PersistenceResult<Option<crate::persistence::entities::AgentConfig>> {
523 Ok(None)
525 }
526}
527
528impl PersistenceStore for InMemoryStore {
529 fn backend_name(&self) -> &str {
530 "memory"
531 }
532
533 fn is_connected(&self) -> bool {
534 self.connected.load(Ordering::SeqCst)
535 }
536
537 async fn close(&self) -> PersistenceResult<()> {
538 self.connected.store(false, Ordering::SeqCst);
539 Ok(())
540 }
541}
542
543pub struct BoundedInMemoryStore {
547 inner: InMemoryStore,
549 max_messages: usize,
551 max_api_calls: usize,
553}
554
555impl BoundedInMemoryStore {
556 pub fn new(max_messages: usize, max_api_calls: usize) -> Self {
558 Self {
559 inner: InMemoryStore::new(),
560 max_messages,
561 max_api_calls,
562 }
563 }
564
565 pub fn shared(max_messages: usize, max_api_calls: usize) -> Arc<Self> {
567 Arc::new(Self::new(max_messages, max_api_calls))
568 }
569
570 async fn cleanup_messages_if_needed(&self) {
572 let mut messages = self.inner.messages.write().await;
573
574 if messages.len() > self.max_messages {
575 let mut sorted: Vec<_> = messages
577 .iter()
578 .map(|(id, msg)| (*id, msg.create_time))
579 .collect();
580 sorted.sort_by(|a, b| a.1.cmp(&b.1));
581
582 let to_remove: Vec<Uuid> = sorted
583 .into_iter()
584 .take(messages.len() - self.max_messages)
585 .map(|(id, _)| id)
586 .collect();
587
588 for id in to_remove {
589 messages.remove(&id);
590 }
591 }
592 }
593
594 async fn cleanup_api_calls_if_needed(&self) {
596 let mut api_calls = self.inner.api_calls.write().await;
597
598 if api_calls.len() > self.max_api_calls {
599 let mut sorted: Vec<_> = api_calls
601 .iter()
602 .map(|(id, call)| (*id, call.create_time))
603 .collect();
604 sorted.sort_by(|a, b| a.1.cmp(&b.1));
605
606 let to_remove: Vec<Uuid> = sorted
607 .into_iter()
608 .take(api_calls.len() - self.max_api_calls)
609 .map(|(id, _)| id)
610 .collect();
611
612 for id in to_remove {
613 api_calls.remove(&id);
614 }
615 }
616 }
617}
618
619#[async_trait]
620impl MessageStore for BoundedInMemoryStore {
621 async fn save_message(&self, message: &LLMMessage) -> PersistenceResult<()> {
622 self.inner.save_message(message).await?;
623 self.cleanup_messages_if_needed().await;
624 Ok(())
625 }
626
627 async fn get_message(&self, id: Uuid) -> PersistenceResult<Option<LLMMessage>> {
628 self.inner.get_message(id).await
629 }
630
631 async fn get_session_messages(&self, session_id: Uuid) -> PersistenceResult<Vec<LLMMessage>> {
632 self.inner.get_session_messages(session_id).await
633 }
634
635 async fn get_session_messages_paginated(
636 &self,
637 session_id: Uuid,
638 offset: i64,
639 limit: i64,
640 ) -> PersistenceResult<Vec<LLMMessage>> {
641 self.inner
642 .get_session_messages_paginated(session_id, offset, limit)
643 .await
644 }
645
646 async fn delete_message(&self, id: Uuid) -> PersistenceResult<bool> {
647 self.inner.delete_message(id).await
648 }
649
650 async fn delete_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64> {
651 self.inner.delete_session_messages(session_id).await
652 }
653
654 async fn count_session_messages(&self, session_id: Uuid) -> PersistenceResult<i64> {
655 self.inner.count_session_messages(session_id).await
656 }
657}
658
659#[async_trait]
660impl ApiCallStore for BoundedInMemoryStore {
661 async fn save_api_call(&self, call: &LLMApiCall) -> PersistenceResult<()> {
662 self.inner.save_api_call(call).await?;
663 self.cleanup_api_calls_if_needed().await;
664 Ok(())
665 }
666
667 async fn get_api_call(&self, id: Uuid) -> PersistenceResult<Option<LLMApiCall>> {
668 self.inner.get_api_call(id).await
669 }
670
671 async fn query_api_calls(&self, filter: &QueryFilter) -> PersistenceResult<Vec<LLMApiCall>> {
672 self.inner.query_api_calls(filter).await
673 }
674
675 async fn get_statistics(&self, filter: &QueryFilter) -> PersistenceResult<UsageStatistics> {
676 self.inner.get_statistics(filter).await
677 }
678
679 async fn delete_api_call(&self, id: Uuid) -> PersistenceResult<bool> {
680 self.inner.delete_api_call(id).await
681 }
682
683 async fn cleanup_old_records(
684 &self,
685 before: chrono::DateTime<chrono::Utc>,
686 ) -> PersistenceResult<i64> {
687 self.inner.cleanup_old_records(before).await
688 }
689}
690
691#[async_trait]
692impl SessionStore for BoundedInMemoryStore {
693 async fn create_session(&self, session: &ChatSession) -> PersistenceResult<()> {
694 self.inner.create_session(session).await
695 }
696
697 async fn get_session(&self, id: Uuid) -> PersistenceResult<Option<ChatSession>> {
698 self.inner.get_session(id).await
699 }
700
701 async fn get_user_sessions(&self, user_id: Uuid) -> PersistenceResult<Vec<ChatSession>> {
702 self.inner.get_user_sessions(user_id).await
703 }
704
705 async fn update_session(&self, session: &ChatSession) -> PersistenceResult<()> {
706 self.inner.update_session(session).await
707 }
708
709 async fn delete_session(&self, id: Uuid) -> PersistenceResult<bool> {
710 self.inner.delete_session(id).await
711 }
712}
713
714#[async_trait]
715impl ProviderStore for BoundedInMemoryStore {
716 async fn get_provider(
717 &self,
718 id: Uuid,
719 ) -> PersistenceResult<Option<crate::persistence::entities::Provider>> {
720 self.inner.get_provider(id).await
721 }
722
723 async fn get_provider_by_name(
724 &self,
725 tenant_id: Uuid,
726 name: &str,
727 ) -> PersistenceResult<Option<crate::persistence::entities::Provider>> {
728 self.inner.get_provider_by_name(tenant_id, name).await
729 }
730
731 async fn list_providers(
732 &self,
733 tenant_id: Uuid,
734 ) -> PersistenceResult<Vec<crate::persistence::entities::Provider>> {
735 self.inner.list_providers(tenant_id).await
736 }
737
738 async fn get_enabled_providers(
739 &self,
740 tenant_id: Uuid,
741 ) -> PersistenceResult<Vec<crate::persistence::entities::Provider>> {
742 self.inner.get_enabled_providers(tenant_id).await
743 }
744}
745
746#[async_trait]
747impl AgentStore for BoundedInMemoryStore {
748 async fn get_agent(
749 &self,
750 id: Uuid,
751 ) -> PersistenceResult<Option<crate::persistence::entities::Agent>> {
752 self.inner.get_agent(id).await
753 }
754
755 async fn get_agent_by_code(
756 &self,
757 code: &str,
758 ) -> PersistenceResult<Option<crate::persistence::entities::Agent>> {
759 self.inner.get_agent_by_code(code).await
760 }
761
762 async fn get_agent_by_code_and_tenant(
763 &self,
764 tenant_id: Uuid,
765 code: &str,
766 ) -> PersistenceResult<Option<crate::persistence::entities::Agent>> {
767 self.inner
768 .get_agent_by_code_and_tenant(tenant_id, code)
769 .await
770 }
771
772 async fn list_agents(
773 &self,
774 tenant_id: Uuid,
775 ) -> PersistenceResult<Vec<crate::persistence::entities::Agent>> {
776 self.inner.list_agents(tenant_id).await
777 }
778
779 async fn get_active_agents(
780 &self,
781 tenant_id: Uuid,
782 ) -> PersistenceResult<Vec<crate::persistence::entities::Agent>> {
783 self.inner.get_active_agents(tenant_id).await
784 }
785
786 async fn get_agent_with_provider(
787 &self,
788 id: Uuid,
789 ) -> PersistenceResult<Option<crate::persistence::entities::AgentConfig>> {
790 self.inner.get_agent_with_provider(id).await
791 }
792
793 async fn get_agent_by_code_with_provider(
794 &self,
795 code: &str,
796 ) -> PersistenceResult<Option<crate::persistence::entities::AgentConfig>> {
797 self.inner.get_agent_by_code_with_provider(code).await
798 }
799
800 async fn get_agent_by_code_and_tenant_with_provider(
801 &self,
802 tenant_id: Uuid,
803 code: &str,
804 ) -> PersistenceResult<Option<crate::persistence::entities::AgentConfig>> {
805 self.inner
806 .get_agent_by_code_and_tenant_with_provider(tenant_id, code)
807 .await
808 }
809}
810
811impl PersistenceStore for BoundedInMemoryStore {
812 fn backend_name(&self) -> &str {
813 "bounded-memory"
814 }
815
816 fn is_connected(&self) -> bool {
817 self.inner.is_connected()
818 }
819
820 async fn close(&self) -> PersistenceResult<()> {
821 self.inner.close().await
822 }
823}