1use axum::{
4 extract::{Path, Query, State},
5 http::StatusCode,
6 Json,
7};
8use chrono::Utc;
9use nexus_storage::repository::StoreMemoryParams;
10use serde::Deserialize;
11use serde_json::json;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{info, warn};
15
16use crate::{
17 error::{Result, WebError},
18 models::{
19 CreateMemoryRequest, MemoryCreateResponse, MemoryListResponse, MemoryResponse,
20 SearchRequest, SearchResponse, UpdateMemoryRequest, WebSocketMessage,
21 },
22 state::AppState,
23};
24
/// Query-string parameters accepted by `list_memories`.
///
/// Omitted fields fall back to serde defaults: `agent_type` -> `default_agent_type()`,
/// `limit` -> `default_limit()`, `offset` -> `usize::default()` (0).
#[derive(Debug, Deserialize)]
pub struct ListMemoriesQuery {
    /// Agent type; `list_memories` passes it as both arguments of
    /// `namespace_repo.get_or_create` to resolve the namespace.
    #[serde(default = "default_agent_type")]
    pub agent_type: String,
    /// Free-text query. NOTE(review): `list_memories` only echoes this back in the
    /// response; it is not used to filter results.
    pub query: Option<String>,
    /// Category filter. NOTE(review): only echoed in the response `filters` object.
    pub category: Option<String>,
    /// Memory-lane filter. NOTE(review): only echoed in the response `filters` object.
    pub memory_lane_type: Option<String>,
    /// Maximum number of memories to return in one page.
    #[serde(default = "default_limit")]
    pub limit: usize,
    /// Number of memories to skip before the page starts.
    #[serde(default)]
    pub offset: usize,
}
38
/// Serde default for `ListMemoriesQuery::agent_type` when the query string omits it.
fn default_agent_type() -> String {
    String::from("general")
}
42
/// Serde default for `ListMemoriesQuery::limit`: the page size used when the
/// caller does not specify one.
fn default_limit() -> usize {
    const DEFAULT_PAGE_SIZE: usize = 20;
    DEFAULT_PAGE_SIZE
}
46
47pub async fn list_memories(
49 State(state): State<Arc<RwLock<AppState>>>,
50 Query(params): Query<ListMemoriesQuery>,
51) -> Result<Json<MemoryListResponse>> {
52 let state = state.read().await;
53
54 let namespace = state
56 .namespace_repo
57 .get_or_create(¶ms.agent_type, ¶ms.agent_type)
58 .await?;
59
60 let memories = state
62 .memory_repo
63 .search_by_namespace(namespace.id, params.limit, params.offset)
64 .await?;
65
66 let total = state.memory_repo.count_by_namespace(namespace.id).await?;
67
68 let results: Vec<MemoryResponse> = memories.into_iter().map(MemoryResponse::from).collect();
70
71 let filters = json!({
72 "category": params.category,
73 "memory_lane_type": params.memory_lane_type,
74 });
75
76 Ok(Json(MemoryListResponse {
77 success: true,
78 total,
79 results,
80 query: params.query,
81 agent_type: params.agent_type,
82 filters,
83 }))
84}
85
86pub async fn create_memory(
88 State(state): State<Arc<RwLock<AppState>>>,
89 Json(request): Json<CreateMemoryRequest>,
90) -> Result<(StatusCode, Json<MemoryCreateResponse>)> {
91 let state = state.read().await;
92
93 if request.content.trim().is_empty() {
95 return Err(WebError::InvalidRequest(
96 "Content cannot be empty".to_string(),
97 ));
98 }
99
100 let namespace = state
102 .namespace_repo
103 .get_or_create(&request.agent_type, &request.agent_type)
104 .await?;
105
106 let memory = state
108 .memory_repo
109 .store(StoreMemoryParams {
110 namespace_id: namespace.id,
111 content: &request.content,
112 category: &request.category,
113 memory_lane_type: request.memory_lane_type.as_ref(),
114 labels: &request.labels,
115 metadata: &request.metadata,
116 embedding: None,
117 embedding_model: None,
118 })
119 .await?;
120
121 let memory_response = MemoryResponse::from(memory.clone());
123 let ws_msg = WebSocketMessage::memory_stored(&memory_response, &request.agent_type);
124 let _ = state.broadcast_ws(ws_msg);
125
126 info!(
127 "Memory created: id={}, agent_type={}",
128 memory.id, request.agent_type
129 );
130
131 Ok((
132 StatusCode::CREATED,
133 Json(MemoryCreateResponse {
134 success: true,
135 memory_id: Some(memory.id),
136 agent_type: request.agent_type,
137 category: request.category.to_string(),
138 error: None,
139 }),
140 ))
141}
142
143pub async fn get_memory(
145 State(state): State<Arc<RwLock<AppState>>>,
146 Path(id): Path<i64>,
147) -> Result<Json<MemoryResponse>> {
148 let state = state.read().await;
149
150 let memory = state
151 .memory_repo
152 .get_by_id(id)
153 .await?
154 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
155
156 let _ = state.memory_repo.touch(id).await;
158
159 Ok(Json(MemoryResponse::from(memory)))
160}
161
162pub async fn update_memory(
164 State(state): State<Arc<RwLock<AppState>>>,
165 Path(id): Path<i64>,
166 Json(request): Json<UpdateMemoryRequest>,
167) -> Result<Json<MemoryResponse>> {
168 let state = state.read().await;
169
170 let existing = state
172 .memory_repo
173 .get_by_id(id)
174 .await?
175 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
176
177 enum UpdateBindValue {
178 Text(String),
179 Bool(bool),
180 }
181
182 let mut set_clauses: Vec<String> = Vec::new();
183 let mut bind_values: Vec<UpdateBindValue> = Vec::new();
184
185 if let Some(content) = request.content {
186 if !content.trim().is_empty() {
187 set_clauses.push("content = ?".to_string());
188 bind_values.push(UpdateBindValue::Text(content));
189 }
190 }
191
192 if let Some(category) = request.category {
193 set_clauses.push("category = ?".to_string());
194 bind_values.push(UpdateBindValue::Text(category.to_string()));
195 }
196
197 if let Some(memory_lane_type) = request.memory_lane_type {
198 set_clauses.push("memory_lane_type = ?".to_string());
199 bind_values.push(UpdateBindValue::Text(memory_lane_type.to_string()));
200 }
201
202 if let Some(labels) = request.labels {
203 match serde_json::to_string(&labels) {
204 Ok(labels_json) => {
205 set_clauses.push("labels = ?".to_string());
206 bind_values.push(UpdateBindValue::Text(labels_json));
207 }
208 Err(e) => {
209 warn!(error = %e, "Failed to serialize labels for memory update; labels omitted from SQL update");
210 }
211 }
212 }
213
214 if let Some(metadata) = request.metadata {
215 match serde_json::to_string(&metadata) {
216 Ok(metadata_json) => {
217 set_clauses.push("metadata = ?".to_string());
218 bind_values.push(UpdateBindValue::Text(metadata_json));
219 }
220 Err(e) => {
221 warn!(error = %e, "Failed to serialize metadata for memory update; metadata omitted from SQL update");
222 }
223 }
224 }
225
226 if let Some(is_active) = request.is_active {
227 set_clauses.push("is_active = ?".to_string());
228 bind_values.push(UpdateBindValue::Bool(is_active));
229 }
230
231 if let Some(is_archived) = request.is_archived {
232 set_clauses.push("is_archived = ?".to_string());
233 bind_values.push(UpdateBindValue::Bool(is_archived));
234 }
235
236 if set_clauses.is_empty() {
237 return Ok(Json(MemoryResponse::from(existing)));
238 }
239
240 set_clauses.push("updated_at = ?".to_string());
241 bind_values.push(UpdateBindValue::Text(
242 Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(),
243 ));
244
245 let query = format!(
246 "UPDATE memories SET {} WHERE id = ?",
247 set_clauses.join(", ")
248 );
249
250 let mut query = sqlx::query(&query);
251 for bind_value in bind_values {
252 query = match bind_value {
253 UpdateBindValue::Text(value) => query.bind(value),
254 UpdateBindValue::Bool(value) => query.bind(value),
255 };
256 }
257
258 query
259 .bind(id)
260 .execute(state.pool())
261 .await
262 .map_err(|e| WebError::Storage(e.to_string()))?;
263
264 let updated = state
266 .memory_repo
267 .get_by_id(id)
268 .await?
269 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found after update", id)))?;
270
271 let ws_msg = WebSocketMessage::memory_updated(id);
273 let _ = state.broadcast_ws(ws_msg);
274
275 info!("Memory updated: id={}", id);
276
277 Ok(Json(MemoryResponse::from(updated)))
278}
279
280pub async fn delete_memory(
282 State(state): State<Arc<RwLock<AppState>>>,
283 Path(id): Path<i64>,
284) -> Result<StatusCode> {
285 let state = state.read().await;
286
287 let _ = state
289 .memory_repo
290 .get_by_id(id)
291 .await?
292 .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
293
294 sqlx::query("UPDATE memories SET is_active = 0, is_archived = 1, updated_at = ? WHERE id = ?")
296 .bind(Utc::now())
297 .bind(id)
298 .execute(state.pool())
299 .await
300 .map_err(|e| WebError::Storage(e.to_string()))?;
301
302 let ws_msg = WebSocketMessage::memory_deleted(id);
304 let _ = state.broadcast_ws(ws_msg);
305
306 info!("Memory deleted: id={}", id);
307
308 Ok(StatusCode::NO_CONTENT)
309}
310
311pub async fn search_memories(
313 State(state): State<Arc<RwLock<AppState>>>,
314 Json(request): Json<SearchRequest>,
315) -> Result<Json<SearchResponse>> {
316 let state = state.read().await;
317
318 if request.query.trim().is_empty() {
320 return Err(WebError::InvalidRequest(
321 "Query cannot be empty".to_string(),
322 ));
323 }
324
325 let namespace = state
327 .namespace_repo
328 .get_or_create(&request.agent_type, &request.agent_type)
329 .await?;
330
331 let search_pattern = format!(
334 "%{}%",
335 request.query.replace("%", "\\%").replace("_", "\\_")
336 );
337
338 let query_str = "SELECT * FROM memories WHERE namespace_id = ? AND is_active = 1 AND content LIKE ? ORDER BY created_at DESC LIMIT ? OFFSET ?".to_string();
339
340 let rows: Vec<nexus_storage::models::MemoryRow> = sqlx::query_as(&query_str)
341 .bind(namespace.id)
342 .bind(&search_pattern)
343 .bind(request.limit as i64)
344 .bind(request.offset as i64)
345 .fetch_all(state.pool())
346 .await
347 .map_err(|e| WebError::Storage(e.to_string()))?;
348
349 let memories: Vec<nexus_core::Memory> = rows
351 .into_iter()
352 .map(row_to_memory)
353 .collect::<crate::error::Result<Vec<_>>>()?;
354
355 let results: Vec<MemoryResponse> = memories.into_iter().map(MemoryResponse::from).collect();
356
357 let total = results.len() as i64;
358
359 let filters = json!({
360 "category": request.category.map(|c| c.to_string()),
361 "memory_lane_type": request.memory_lane_type.map(|t| t.to_string()),
362 "threshold": request.threshold,
363 });
364
365 Ok(Json(SearchResponse {
366 success: true,
367 results,
368 total,
369 query: request.query,
370 agent_type: request.agent_type,
371 filters,
372 error: None,
373 }))
374}
375
376fn row_to_memory(
378 row: nexus_storage::models::MemoryRow,
379) -> crate::error::Result<nexus_core::Memory> {
380 use nexus_core::{Memory, MemoryCategory, MemoryLaneType};
381
382 let labels: Vec<String> = serde_json::from_str(&row.labels).map_err(|e| {
383 crate::error::WebError::Storage(format!("corrupted labels JSON for memory {}: {e}", row.id))
384 })?;
385 let metadata: serde_json::Value = serde_json::from_str(&row.metadata).map_err(|e| {
386 crate::error::WebError::Storage(format!(
387 "corrupted metadata JSON for memory {}: {e}",
388 row.id
389 ))
390 })?;
391 let embedding: Option<Vec<f32>> = row
392 .content_embedding
393 .map(|e| {
394 serde_json::from_str(&e).map_err(|err| {
395 crate::error::WebError::Storage(format!(
396 "corrupted embedding JSON for memory {}: {err}",
397 row.id
398 ))
399 })
400 })
401 .transpose()?;
402
403 Ok(Memory {
404 id: row.id,
405 namespace_id: row.namespace_id,
406 content: row.content,
407 category: MemoryCategory::parse(&row.category).ok_or_else(|| {
408 WebError::Storage(format!(
409 "Unknown memory category '{}' persisted in database; row may be corrupted",
410 row.category
411 ))
412 })?,
413 memory_lane_type: match &row.memory_lane_type {
414 Some(s) => Some(MemoryLaneType::parse(s).ok_or_else(|| {
415 WebError::Storage(format!(
416 "Unknown memory_lane_type '{}' persisted in database; row may be corrupted",
417 s
418 ))
419 })?),
420 None => None,
421 },
422 labels,
423 metadata,
424 similarity_score: row.similarity_score,
425 relevance_score: row.relevance_score,
426 content_embedding: embedding,
427 embedding_model: row.embedding_model,
428 created_at: row.created_at,
429 updated_at: row.updated_at,
430 last_accessed: row.last_accessed,
431 is_active: row.is_active,
432 is_archived: row.is_archived,
433 access_count: row.access_count,
434 })
435}