codex_memory/mcp/
handlers.rs

1use crate::memory::{
2    models::{CreateMemoryRequest, SearchRequest, SearchResult, UpdateMemoryRequest},
3    Memory, MemoryError, MemoryRepository, MemoryTier,
4};
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::{debug, error, info, instrument, warn};
9use uuid::Uuid;
10
11#[async_trait]
12pub trait MemoryHandler: Send + Sync {
13    async fn create_memory(&self, request: CreateMemoryRequest) -> Result<Memory, MemoryError>;
14    async fn get_memory(&self, id: Uuid) -> Result<Memory, MemoryError>;
15    async fn update_memory(
16        &self,
17        id: Uuid,
18        request: UpdateMemoryRequest,
19    ) -> Result<Memory, MemoryError>;
20    async fn delete_memory(&self, id: Uuid) -> Result<(), MemoryError>;
21    async fn search_memories(
22        &self,
23        request: SearchRequest,
24    ) -> Result<Vec<SearchResult>, MemoryError>;
25    async fn migrate_memory(
26        &self,
27        id: Uuid,
28        to_tier: MemoryTier,
29        reason: Option<String>,
30    ) -> Result<Memory, MemoryError>;
31}
32
33pub struct MemoryHandlerImpl {
34    repository: Arc<MemoryRepository>,
35    #[allow(dead_code)]
36    circuit_breaker: Arc<crate::mcp::circuit_breaker::CircuitBreaker>,
37    retry_policy: Arc<crate::mcp::retry::RetryPolicy>,
38}
39
40impl MemoryHandlerImpl {
41    pub fn new(
42        repository: Arc<MemoryRepository>,
43        circuit_breaker: Arc<crate::mcp::circuit_breaker::CircuitBreaker>,
44        retry_policy: Arc<crate::mcp::retry::RetryPolicy>,
45    ) -> Self {
46        Self {
47            repository,
48            circuit_breaker,
49            retry_policy,
50        }
51    }
52}
53
54#[async_trait]
55impl MemoryHandler for MemoryHandlerImpl {
56    #[instrument(skip(self))]
57    async fn create_memory(&self, request: CreateMemoryRequest) -> Result<Memory, MemoryError> {
58        debug!(
59            "Creating memory with content length: {}",
60            request.content.len()
61        );
62
63        let repo = self.repository.clone();
64        let result = self
65            .retry_policy
66            .execute(|| async { repo.create_memory(request.clone()).await })
67            .await;
68
69        match &result {
70            Ok(memory) => info!("Created memory {} in tier {:?}", memory.id, memory.tier),
71            Err(e) => error!("Failed to create memory: {}", e),
72        }
73
74        result
75    }
76
77    #[instrument(skip(self))]
78    async fn get_memory(&self, id: Uuid) -> Result<Memory, MemoryError> {
79        debug!("Getting memory {}", id);
80
81        let repo = self.repository.clone();
82        let result = self
83            .retry_policy
84            .execute(|| async { repo.get_memory(id).await })
85            .await;
86
87        match &result {
88            Ok(memory) => debug!("Retrieved memory {} from tier {:?}", id, memory.tier),
89            Err(e) => warn!("Failed to get memory {}: {}", id, e),
90        }
91
92        result
93    }
94
95    #[instrument(skip(self))]
96    async fn update_memory(
97        &self,
98        id: Uuid,
99        request: UpdateMemoryRequest,
100    ) -> Result<Memory, MemoryError> {
101        debug!("Updating memory {}", id);
102
103        let repo = self.repository.clone();
104        let result = self
105            .retry_policy
106            .execute(|| async { repo.update_memory(id, request.clone()).await })
107            .await;
108
109        match &result {
110            Ok(memory) => info!("Updated memory {}", memory.id),
111            Err(e) => error!("Failed to update memory {}: {}", id, e),
112        }
113
114        result
115    }
116
117    #[instrument(skip(self))]
118    async fn delete_memory(&self, id: Uuid) -> Result<(), MemoryError> {
119        debug!("Deleting memory {}", id);
120
121        let repo = self.repository.clone();
122        let result = self
123            .retry_policy
124            .execute(|| async { repo.delete_memory(id).await })
125            .await;
126
127        match &result {
128            Ok(_) => info!("Deleted memory {}", id),
129            Err(e) => error!("Failed to delete memory {}: {}", id, e),
130        }
131
132        result
133    }
134
135    #[instrument(skip(self, request))]
136    async fn search_memories(
137        &self,
138        request: SearchRequest,
139    ) -> Result<Vec<SearchResult>, MemoryError> {
140        debug!("Searching memories with limit {:?}", request.limit);
141
142        let repo = self.repository.clone();
143        let result = self
144            .retry_policy
145            .execute(|| async { repo.search_memories(request.clone()).await })
146            .await;
147
148        match &result {
149            Ok(results) => info!("Found {} memories matching search", results.results.len()),
150            Err(e) => error!("Failed to search memories: {}", e),
151        }
152
153        result.map(|response| response.results)
154    }
155
156    #[instrument(skip(self))]
157    async fn migrate_memory(
158        &self,
159        id: Uuid,
160        to_tier: MemoryTier,
161        reason: Option<String>,
162    ) -> Result<Memory, MemoryError> {
163        debug!("Migrating memory {} to tier {:?}", id, to_tier);
164
165        let repo = self.repository.clone();
166        let result = self
167            .retry_policy
168            .execute(|| async { repo.migrate_memory(id, to_tier, reason.clone()).await })
169            .await;
170
171        match &result {
172            Ok(memory) => info!("Migrated memory {} to tier {:?}", id, memory.tier),
173            Err(e) => error!("Failed to migrate memory {}: {}", id, e),
174        }
175
176        result
177    }
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct HealthCheckResponse {
182    pub status: String,
183    pub timestamp: chrono::DateTime<chrono::Utc>,
184    pub version: String,
185    pub uptime_seconds: u64,
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct MetricsResponse {
190    pub total_requests: u64,
191    pub active_connections: u32,
192    pub memory_count: u64,
193    pub error_rate: f64,
194    pub avg_response_time_ms: f64,
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a health response, checks key strings appear, and round-trips every field.
    #[test]
    fn test_health_check_response_serialization() {
        let response = HealthCheckResponse {
            status: "healthy".to_string(),
            timestamp: chrono::Utc::now(),
            version: "1.0.0".to_string(),
            uptime_seconds: 3600,
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("healthy"));
        assert!(json.contains("1.0.0"));

        // Round-trip so every field, including the timestamp, is verified, not just substrings.
        let deserialized: HealthCheckResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.status, response.status);
        assert_eq!(deserialized.timestamp, response.timestamp);
        assert_eq!(deserialized.version, response.version);
        assert_eq!(deserialized.uptime_seconds, response.uptime_seconds);
    }

    /// Round-trips a metrics response through JSON and checks every field.
    #[test]
    fn test_metrics_response_serialization() {
        let response = MetricsResponse {
            total_requests: 1000,
            active_connections: 10,
            memory_count: 500,
            error_rate: 0.01,
            avg_response_time_ms: 25.5,
        };

        let json = serde_json::to_string(&response).unwrap();
        let deserialized: MetricsResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.total_requests, 1000);
        assert_eq!(deserialized.active_connections, 10);
        assert_eq!(deserialized.memory_count, 500);
        // Floats are compared with a tolerance rather than `==` to stay robust to
        // float-parsing differences.
        assert!((deserialized.error_rate - 0.01).abs() < f64::EPSILON);
        assert!((deserialized.avg_response_time_ms - 25.5).abs() < f64::EPSILON);
    }
}