Skip to main content

nexus_memory_web/api/
memories.rs

1//! Memory CRUD API endpoints
2
3use axum::{
4    extract::{Path, Query, State},
5    http::StatusCode,
6    Json,
7};
8use chrono::Utc;
9use nexus_storage::repository::StoreMemoryParams;
10use serde::Deserialize;
11use serde_json::json;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{info, warn};
15
16use crate::{
17    error::{Result, WebError},
18    models::{
19        CreateMemoryRequest, MemoryCreateResponse, MemoryListResponse, MemoryResponse,
20        SearchRequest, SearchResponse, UpdateMemoryRequest, WebSocketMessage,
21    },
22    state::AppState,
23};
24
25/// Query parameters for listing memories
26#[derive(Debug, Deserialize)]
27pub struct ListMemoriesQuery {
28    #[serde(default = "default_agent_type")]
29    pub agent_type: String,
30    pub query: Option<String>,
31    pub category: Option<String>,
32    pub memory_lane_type: Option<String>,
33    #[serde(default = "default_limit")]
34    pub limit: usize,
35    #[serde(default)]
36    pub offset: usize,
37}
38
/// Serde default for `ListMemoriesQuery::agent_type`: the `"general"` namespace.
fn default_agent_type() -> String {
    String::from("general")
}
42
/// Serde default for `ListMemoriesQuery::limit`: the page size used when none is given.
fn default_limit() -> usize {
    const DEFAULT_PAGE_SIZE: usize = 20;
    DEFAULT_PAGE_SIZE
}
46
47/// List memories with optional filtering
48pub async fn list_memories(
49    State(state): State<Arc<RwLock<AppState>>>,
50    Query(params): Query<ListMemoriesQuery>,
51) -> Result<Json<MemoryListResponse>> {
52    let state = state.read().await;
53
54    // Get or create namespace
55    let namespace = state
56        .namespace_repo
57        .get_or_create(&params.agent_type, &params.agent_type)
58        .await?;
59
60    // Search memories
61    let memories = state
62        .memory_repo
63        .search_by_namespace(namespace.id, params.limit, params.offset)
64        .await?;
65
66    let total = state.memory_repo.count_by_namespace(namespace.id).await?;
67
68    // Convert to response models
69    let results: Vec<MemoryResponse> = memories.into_iter().map(MemoryResponse::from).collect();
70
71    let filters = json!({
72        "category": params.category,
73        "memory_lane_type": params.memory_lane_type,
74    });
75
76    Ok(Json(MemoryListResponse {
77        success: true,
78        total,
79        results,
80        query: params.query,
81        agent_type: params.agent_type,
82        filters,
83    }))
84}
85
86/// Create a new memory
87pub async fn create_memory(
88    State(state): State<Arc<RwLock<AppState>>>,
89    Json(request): Json<CreateMemoryRequest>,
90) -> Result<(StatusCode, Json<MemoryCreateResponse>)> {
91    let state = state.read().await;
92
93    // Validate content
94    if request.content.trim().is_empty() {
95        return Err(WebError::InvalidRequest(
96            "Content cannot be empty".to_string(),
97        ));
98    }
99
100    // Get or create namespace
101    let namespace = state
102        .namespace_repo
103        .get_or_create(&request.agent_type, &request.agent_type)
104        .await?;
105
106    // Store memory
107    let memory = state
108        .memory_repo
109        .store(StoreMemoryParams {
110            namespace_id: namespace.id,
111            content: &request.content,
112            category: &request.category,
113            memory_lane_type: request.memory_lane_type.as_ref(),
114            labels: &request.labels,
115            metadata: &request.metadata,
116            embedding: None,
117            embedding_model: None,
118        })
119        .await?;
120
121    // Broadcast to WebSocket clients
122    let memory_response = MemoryResponse::from(memory.clone());
123    let ws_msg = WebSocketMessage::memory_stored(&memory_response, &request.agent_type);
124    let _ = state.broadcast_ws(ws_msg);
125
126    info!(
127        "Memory created: id={}, agent_type={}",
128        memory.id, request.agent_type
129    );
130
131    Ok((
132        StatusCode::CREATED,
133        Json(MemoryCreateResponse {
134            success: true,
135            memory_id: Some(memory.id),
136            agent_type: request.agent_type,
137            category: request.category.to_string(),
138            error: None,
139        }),
140    ))
141}
142
143/// Get a specific memory by ID
144pub async fn get_memory(
145    State(state): State<Arc<RwLock<AppState>>>,
146    Path(id): Path<i64>,
147) -> Result<Json<MemoryResponse>> {
148    let state = state.read().await;
149
150    let memory = state
151        .memory_repo
152        .get_by_id(id)
153        .await?
154        .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
155
156    // Update access count
157    let _ = state.memory_repo.touch(id).await;
158
159    Ok(Json(MemoryResponse::from(memory)))
160}
161
162/// Update an existing memory
163pub async fn update_memory(
164    State(state): State<Arc<RwLock<AppState>>>,
165    Path(id): Path<i64>,
166    Json(request): Json<UpdateMemoryRequest>,
167) -> Result<Json<MemoryResponse>> {
168    let state = state.read().await;
169
170    // Check if memory exists
171    let existing = state
172        .memory_repo
173        .get_by_id(id)
174        .await?
175        .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
176
177    enum UpdateBindValue {
178        Text(String),
179        Bool(bool),
180    }
181
182    let mut set_clauses: Vec<String> = Vec::new();
183    let mut bind_values: Vec<UpdateBindValue> = Vec::new();
184
185    if let Some(content) = request.content {
186        if !content.trim().is_empty() {
187            set_clauses.push("content = ?".to_string());
188            bind_values.push(UpdateBindValue::Text(content));
189        }
190    }
191
192    if let Some(category) = request.category {
193        set_clauses.push("category = ?".to_string());
194        bind_values.push(UpdateBindValue::Text(category.to_string()));
195    }
196
197    if let Some(memory_lane_type) = request.memory_lane_type {
198        set_clauses.push("memory_lane_type = ?".to_string());
199        bind_values.push(UpdateBindValue::Text(memory_lane_type.to_string()));
200    }
201
202    if let Some(labels) = request.labels {
203        match serde_json::to_string(&labels) {
204            Ok(labels_json) => {
205                set_clauses.push("labels = ?".to_string());
206                bind_values.push(UpdateBindValue::Text(labels_json));
207            }
208            Err(e) => {
209                warn!(error = %e, "Failed to serialize labels for memory update; labels omitted from SQL update");
210            }
211        }
212    }
213
214    if let Some(metadata) = request.metadata {
215        match serde_json::to_string(&metadata) {
216            Ok(metadata_json) => {
217                set_clauses.push("metadata = ?".to_string());
218                bind_values.push(UpdateBindValue::Text(metadata_json));
219            }
220            Err(e) => {
221                warn!(error = %e, "Failed to serialize metadata for memory update; metadata omitted from SQL update");
222            }
223        }
224    }
225
226    if let Some(is_active) = request.is_active {
227        set_clauses.push("is_active = ?".to_string());
228        bind_values.push(UpdateBindValue::Bool(is_active));
229    }
230
231    if let Some(is_archived) = request.is_archived {
232        set_clauses.push("is_archived = ?".to_string());
233        bind_values.push(UpdateBindValue::Bool(is_archived));
234    }
235
236    if set_clauses.is_empty() {
237        return Ok(Json(MemoryResponse::from(existing)));
238    }
239
240    set_clauses.push("updated_at = ?".to_string());
241    bind_values.push(UpdateBindValue::Text(
242        Utc::now().format("%Y-%m-%d %H:%M:%S").to_string(),
243    ));
244
245    let query = format!(
246        "UPDATE memories SET {} WHERE id = ?",
247        set_clauses.join(", ")
248    );
249
250    let mut query = sqlx::query(&query);
251    for bind_value in bind_values {
252        query = match bind_value {
253            UpdateBindValue::Text(value) => query.bind(value),
254            UpdateBindValue::Bool(value) => query.bind(value),
255        };
256    }
257
258    query
259        .bind(id)
260        .execute(state.pool())
261        .await
262        .map_err(|e| WebError::Storage(e.to_string()))?;
263
264    // Fetch updated memory
265    let updated = state
266        .memory_repo
267        .get_by_id(id)
268        .await?
269        .ok_or_else(|| WebError::NotFound(format!("Memory {} not found after update", id)))?;
270
271    // Broadcast update
272    let ws_msg = WebSocketMessage::memory_updated(id);
273    let _ = state.broadcast_ws(ws_msg);
274
275    info!("Memory updated: id={}", id);
276
277    Ok(Json(MemoryResponse::from(updated)))
278}
279
280/// Delete a memory (soft delete)
281pub async fn delete_memory(
282    State(state): State<Arc<RwLock<AppState>>>,
283    Path(id): Path<i64>,
284) -> Result<StatusCode> {
285    let state = state.read().await;
286
287    // Check if memory exists
288    let _ = state
289        .memory_repo
290        .get_by_id(id)
291        .await?
292        .ok_or_else(|| WebError::NotFound(format!("Memory {} not found", id)))?;
293
294    // Soft delete: mark as inactive and archived
295    sqlx::query("UPDATE memories SET is_active = 0, is_archived = 1, updated_at = ? WHERE id = ?")
296        .bind(Utc::now())
297        .bind(id)
298        .execute(state.pool())
299        .await
300        .map_err(|e| WebError::Storage(e.to_string()))?;
301
302    // Broadcast deletion
303    let ws_msg = WebSocketMessage::memory_deleted(id);
304    let _ = state.broadcast_ws(ws_msg);
305
306    info!("Memory deleted: id={}", id);
307
308    Ok(StatusCode::NO_CONTENT)
309}
310
311/// Search memories using semantic or text search
312pub async fn search_memories(
313    State(state): State<Arc<RwLock<AppState>>>,
314    Json(request): Json<SearchRequest>,
315) -> Result<Json<SearchResponse>> {
316    let state = state.read().await;
317
318    // Validate query
319    if request.query.trim().is_empty() {
320        return Err(WebError::InvalidRequest(
321            "Query cannot be empty".to_string(),
322        ));
323    }
324
325    // Get namespace
326    let namespace = state
327        .namespace_repo
328        .get_or_create(&request.agent_type, &request.agent_type)
329        .await?;
330
331    // For now, use text-based search (semantic search would require embeddings)
332    // Search in content using LIKE
333    let search_pattern = format!(
334        "%{}%",
335        request.query.replace("%", "\\%").replace("_", "\\_")
336    );
337
338    let query_str = "SELECT * FROM memories WHERE namespace_id = ? AND is_active = 1 AND content LIKE ? ORDER BY created_at DESC LIMIT ? OFFSET ?".to_string();
339
340    let rows: Vec<nexus_storage::models::MemoryRow> = sqlx::query_as(&query_str)
341        .bind(namespace.id)
342        .bind(&search_pattern)
343        .bind(request.limit as i64)
344        .bind(request.offset as i64)
345        .fetch_all(state.pool())
346        .await
347        .map_err(|e| WebError::Storage(e.to_string()))?;
348
349    // Convert rows to memories
350    let memories: Vec<nexus_core::Memory> = rows
351        .into_iter()
352        .map(row_to_memory)
353        .collect::<crate::error::Result<Vec<_>>>()?;
354
355    let results: Vec<MemoryResponse> = memories.into_iter().map(MemoryResponse::from).collect();
356
357    let total = results.len() as i64;
358
359    let filters = json!({
360        "category": request.category.map(|c| c.to_string()),
361        "memory_lane_type": request.memory_lane_type.map(|t| t.to_string()),
362        "threshold": request.threshold,
363    });
364
365    Ok(Json(SearchResponse {
366        success: true,
367        results,
368        total,
369        query: request.query,
370        agent_type: request.agent_type,
371        filters,
372        error: None,
373    }))
374}
375
376/// Convert a database row to a Memory
377fn row_to_memory(
378    row: nexus_storage::models::MemoryRow,
379) -> crate::error::Result<nexus_core::Memory> {
380    use nexus_core::{Memory, MemoryCategory, MemoryLaneType};
381
382    let labels: Vec<String> = serde_json::from_str(&row.labels).map_err(|e| {
383        crate::error::WebError::Storage(format!("corrupted labels JSON for memory {}: {e}", row.id))
384    })?;
385    let metadata: serde_json::Value = serde_json::from_str(&row.metadata).map_err(|e| {
386        crate::error::WebError::Storage(format!(
387            "corrupted metadata JSON for memory {}: {e}",
388            row.id
389        ))
390    })?;
391    let embedding: Option<Vec<f32>> = row
392        .content_embedding
393        .map(|e| {
394            serde_json::from_str(&e).map_err(|err| {
395                crate::error::WebError::Storage(format!(
396                    "corrupted embedding JSON for memory {}: {err}",
397                    row.id
398                ))
399            })
400        })
401        .transpose()?;
402
403    Ok(Memory {
404        id: row.id,
405        namespace_id: row.namespace_id,
406        content: row.content,
407        category: MemoryCategory::parse(&row.category).ok_or_else(|| {
408            WebError::Storage(format!(
409                "Unknown memory category '{}' persisted in database; row may be corrupted",
410                row.category
411            ))
412        })?,
413        memory_lane_type: match &row.memory_lane_type {
414            Some(s) => Some(MemoryLaneType::parse(s).ok_or_else(|| {
415                WebError::Storage(format!(
416                    "Unknown memory_lane_type '{}' persisted in database; row may be corrupted",
417                    s
418                ))
419            })?),
420            None => None,
421        },
422        labels,
423        metadata,
424        similarity_score: row.similarity_score,
425        relevance_score: row.relevance_score,
426        content_embedding: embedding,
427        embedding_model: row.embedding_model,
428        created_at: row.created_at,
429        updated_at: row.updated_at,
430        last_accessed: row.last_accessed,
431        is_active: row.is_active,
432        is_archived: row.is_archived,
433        access_count: row.access_count,
434    })
435}