1use axum::{
4 extract::{Path, Query, State},
5 http::StatusCode,
6 Json,
7};
8use chrono::Utc;
9use nexus_storage::repository::StoreMemoryParams;
10use serde::Deserialize;
11use serde_json::json;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::info;
15
16use crate::{
17 error::{Result, WebError},
18 models::{
19 CreateMemoryRequest, MemoryCreateResponse, MemoryListResponse, MemoryResponse,
20 SearchRequest, SearchResponse, UpdateMemoryRequest, WebSocketMessage,
21 },
22 state::AppState,
23};
24
25#[derive(Debug, Deserialize)]
27pub struct ListMemoriesQuery {
28 #[serde(default = "default_agent_type")]
29 pub agent_type: String,
30 pub query: Option<String>,
31 pub category: Option<String>,
32 pub memory_lane_type: Option<String>,
33 #[serde(default = "default_limit")]
34 pub limit: usize,
35 #[serde(default)]
36 pub offset: usize,
37}
38
/// Serde default for `ListMemoriesQuery::agent_type`: the `"general"` agent.
fn default_agent_type() -> String {
    String::from("general")
}
42
/// Serde default for `ListMemoriesQuery::limit`: page size used when the
/// client does not supply one.
fn default_limit() -> usize {
    const DEFAULT_PAGE_SIZE: usize = 20;
    DEFAULT_PAGE_SIZE
}
46
47pub async fn list_memories(
49 State(state): State<Arc<RwLock<AppState>>>,
50 Query(params): Query<ListMemoriesQuery>,
51) -> Result<Json<MemoryListResponse>> {
52 let state = state.read().await;
53
54 let namespace = state
56 .namespace_repo
57 .get_or_create(¶ms.agent_type, ¶ms.agent_type)
58 .await?;
59
60 let memories = state
62 .memory_repo
63 .search_by_namespace(namespace.id, params.limit, params.offset)
64 .await?;
65
66 let total = state.memory_repo.count_by_namespace(namespace.id).await?;
67
68 let results: Vec<MemoryResponse> = memories.into_iter().map(MemoryResponse::from).collect();
70
71 let filters = json!({
72 "category": params.category,
73 "memory_lane_type": params.memory_lane_type,
74 });
75
76 Ok(Json(MemoryListResponse {
77 success: true,
78 total,
79 results,
80 query: params.query,
81 agent_type: params.agent_type,
82 filters,
83 }))
84}
85
86pub async fn create_memory(
88 State(state): State<Arc<RwLock<AppState>>>,
89 Json(request): Json<CreateMemoryRequest>,
90) -> Result<(StatusCode, Json<MemoryCreateResponse>)> {
91 let state = state.read().await;
92
93 if request.content.trim().is_empty() {
95 return Err(WebError::InvalidRequest(
96 "Content cannot be empty".to_string(),
97 ));
98 }
99
100 let namespace = state
102 .namespace_repo
103 .get_or_create(&request.agent_type, &request.agent_type)
104 .await?;
105
106 let memory = state
108 .memory_repo
109 .store(StoreMemoryParams {
110 namespace_id: namespace.id,
111 content: &request.content,
112 category: &request.category,
113 memory_lane_type: request.memory_lane_type.as_ref(),
114 labels: &request.labels,
115 metadata: &request.metadata,
116 embedding: None,
117 embedding_model: None,
118 })
119 .await?;
120
121 let memory_response = MemoryResponse::from(memory.clone());
123 let ws_msg = WebSocketMessage::memory_stored(&memory_response, &request.agent_type);
124 let _ = state.broadcast_ws(ws_msg);
125
126 info!(
127 "Memory created: id={}, agent_type={}",
128 memory.id, request.agent_type
129 );
130
131 Ok((
132 StatusCode::CREATED,
133 Json(MemoryCreateResponse {
134 success: true,
135 memory_id: Some(memory.id),
136 agent_type: request.agent_type,
137 category: request.category.to_string(),
138 error: None,
139 }),
140 ))
141}
142
143pub async fn get_memory(
145 State(state): State<Arc<RwLock<AppState>>>,
146 Path(id): Path<i64>,
147) -> Result<Json<MemoryResponse>> {
148 let state = state.read().await;
149
150 let memory = state
151 .memory_repo
152 .get_by_id(id)
153 .await?
154 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
155
156 let _ = state.memory_repo.touch(id).await;
158
159 Ok(Json(MemoryResponse::from(memory)))
160}
161
162pub async fn update_memory(
164 State(state): State<Arc<RwLock<AppState>>>,
165 Path(id): Path<i64>,
166 Json(request): Json<UpdateMemoryRequest>,
167) -> Result<Json<MemoryResponse>> {
168 let state = state.read().await;
169
170 let existing = state
172 .memory_repo
173 .get_by_id(id)
174 .await?
175 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
176
177 let mut updates: Vec<String> = Vec::new();
179
180 if let Some(content) = request.content {
181 if !content.trim().is_empty() {
182 updates.push(format!("content = '{}'", content.replace("'", "''")));
183 }
184 }
185
186 if let Some(category) = request.category {
187 updates.push(format!("category = '{}'", category));
188 }
189
190 if let Some(memory_lane_type) = request.memory_lane_type {
191 updates.push(format!("memory_lane_type = '{}'", memory_lane_type));
192 }
193
194 if let Some(labels) = request.labels {
195 let labels_json = serde_json::to_string(&labels).unwrap_or_default();
196 updates.push(format!("labels = '{}'", labels_json.replace("'", "''")));
197 }
198
199 if let Some(metadata) = request.metadata {
200 let metadata_json = serde_json::to_string(&metadata).unwrap_or_default();
201 updates.push(format!("metadata = '{}'", metadata_json.replace("'", "''")));
202 }
203
204 if let Some(is_active) = request.is_active {
205 updates.push(format!("is_active = {}", if is_active { 1 } else { 0 }));
206 }
207
208 if let Some(is_archived) = request.is_archived {
209 updates.push(format!("is_archived = {}", if is_archived { 1 } else { 0 }));
210 }
211
212 if updates.is_empty() {
213 return Ok(Json(MemoryResponse::from(existing)));
214 }
215
216 updates.push(format!("updated_at = '{}'", Utc::now().to_rfc3339()));
217
218 let query = format!("UPDATE memories SET {} WHERE id = ?", updates.join(", "));
219
220 sqlx::query(&query)
221 .bind(id)
222 .execute(state.pool())
223 .await
224 .map_err(|e| WebError::Storage(e.to_string()))?;
225
226 let updated = state
228 .memory_repo
229 .get_by_id(id)
230 .await?
231 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found after update", id)))?;
232
233 let ws_msg = WebSocketMessage::memory_updated(id);
235 let _ = state.broadcast_ws(ws_msg);
236
237 info!("Memory updated: id={}", id);
238
239 Ok(Json(MemoryResponse::from(updated)))
240}
241
242pub async fn delete_memory(
244 State(state): State<Arc<RwLock<AppState>>>,
245 Path(id): Path<i64>,
246) -> Result<StatusCode> {
247 let state = state.read().await;
248
249 let _ = state
251 .memory_repo
252 .get_by_id(id)
253 .await?
254 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
255
256 sqlx::query("UPDATE memories SET is_active = 0, is_archived = 1, updated_at = ? WHERE id = ?")
258 .bind(Utc::now())
259 .bind(id)
260 .execute(state.pool())
261 .await
262 .map_err(|e| WebError::Storage(e.to_string()))?;
263
264 let ws_msg = WebSocketMessage::memory_deleted(id);
266 let _ = state.broadcast_ws(ws_msg);
267
268 info!("Memory deleted: id={}", id);
269
270 Ok(StatusCode::NO_CONTENT)
271}
272
273pub async fn search_memories(
275 State(state): State<Arc<RwLock<AppState>>>,
276 Json(request): Json<SearchRequest>,
277) -> Result<Json<SearchResponse>> {
278 let state = state.read().await;
279
280 if request.query.trim().is_empty() {
282 return Err(WebError::InvalidRequest(
283 "Query cannot be empty".to_string(),
284 ));
285 }
286
287 let namespace = state
289 .namespace_repo
290 .get_or_create(&request.agent_type, &request.agent_type)
291 .await?;
292
293 let search_pattern = format!(
296 "%{}%",
297 request.query.replace("%", "\\%").replace("_", "\\_")
298 );
299
300 let query_str = "SELECT * FROM memories WHERE namespace_id = ? AND is_active = 1 AND content LIKE ? ORDER BY created_at DESC LIMIT ? OFFSET ?".to_string();
301
302 let rows: Vec<nexus_storage::models::MemoryRow> = sqlx::query_as(&query_str)
303 .bind(namespace.id)
304 .bind(&search_pattern)
305 .bind(request.limit as i64)
306 .bind(request.offset as i64)
307 .fetch_all(state.pool())
308 .await
309 .map_err(|e| WebError::Storage(e.to_string()))?;
310
311 let memories: Vec<nexus_core::Memory> = rows.into_iter().map(row_to_memory).collect();
313
314 let results: Vec<MemoryResponse> = memories.into_iter().map(MemoryResponse::from).collect();
315
316 let total = results.len() as i64;
317
318 let filters = json!({
319 "category": request.category.map(|c| c.to_string()),
320 "memory_lane_type": request.memory_lane_type.map(|t| t.to_string()),
321 "threshold": request.threshold,
322 });
323
324 Ok(Json(SearchResponse {
325 success: true,
326 results,
327 total,
328 query: request.query,
329 agent_type: request.agent_type,
330 filters,
331 error: None,
332 }))
333}
334
335fn row_to_memory(row: nexus_storage::models::MemoryRow) -> nexus_core::Memory {
337 use nexus_core::{Memory, MemoryCategory, MemoryLaneType};
338
339 let labels: Vec<String> = serde_json::from_str(&row.labels).unwrap_or_default();
340 let metadata: serde_json::Value =
341 serde_json::from_str(&row.metadata).unwrap_or(serde_json::Value::Null);
342 let embedding: Option<Vec<f32>> = row
343 .content_embedding
344 .and_then(|e| serde_json::from_str(&e).ok());
345
346 Memory {
347 id: row.id,
348 namespace_id: row.namespace_id,
349 content: row.content,
350 category: MemoryCategory::parse(&row.category).unwrap_or(MemoryCategory::General),
351 memory_lane_type: row
352 .memory_lane_type
353 .as_deref()
354 .and_then(MemoryLaneType::parse),
355 labels,
356 metadata,
357 similarity_score: row.similarity_score,
358 relevance_score: row.relevance_score,
359 content_embedding: embedding,
360 embedding_model: row.embedding_model,
361 created_at: row.created_at,
362 updated_at: row.updated_at,
363 last_accessed: row.last_accessed,
364 is_active: row.is_active,
365 is_archived: row.is_archived,
366 access_count: row.access_count,
367 }
368}