Skip to main content

lago_api/routes/
memory.rs

1//! Auth-protected memory endpoints scoped to the user's Lago session.
2//!
3//! These routes provide the HTTP API for the context engine: file CRUD,
4//! server-side scored search, wikilink resolution, and graph traversal.
5
6use std::sync::Arc;
7
8use axum::Json;
9use axum::body::Bytes;
10use axum::extract::{Extension, Path, State};
11use axum::http::StatusCode;
12use serde::{Deserialize, Serialize};
13
14use lago_auth::UserContext;
15use lago_core::ManifestEntry;
16use lago_knowledge::{KnowledgeIndex, SearchResult, TraversalResult};
17
18use crate::error::ApiError;
19use crate::state::AppState;
20
21// Re-use the file helpers from the files module
22use super::files::{FileWriteResponse, ManifestResponse};
23
24// ---------------------------------------------------------------------------
25// Request / Response types
26// ---------------------------------------------------------------------------
27
28#[derive(Deserialize)]
29pub struct SearchRequest {
30    pub query: String,
31    #[serde(default = "default_max_results")]
32    pub max_results: usize,
33    #[serde(default)]
34    pub follow_links: bool,
35}
36
/// Serde default for `SearchRequest::max_results`.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
40
41#[derive(Serialize)]
42pub struct SearchResponse {
43    pub results: Vec<SearchResult>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub linked_notes: Option<Vec<LinkedNote>>,
46}
47
48#[derive(Serialize)]
49pub struct LinkedNote {
50    pub path: String,
51    pub name: String,
52    pub depth: usize,
53    pub links: Vec<String>,
54}
55
56#[derive(Deserialize)]
57pub struct TraverseRequest {
58    pub target: String,
59    #[serde(default = "default_depth")]
60    pub depth: usize,
61    #[serde(default = "default_max_notes")]
62    pub max_notes: usize,
63}
64
/// Serde default for `TraverseRequest::depth`.
fn default_depth() -> usize {
    const DEFAULT_DEPTH: usize = 1;
    DEFAULT_DEPTH
}
68
/// Serde default for `TraverseRequest::max_notes`.
fn default_max_notes() -> usize {
    const DEFAULT_MAX_NOTES: usize = 15;
    DEFAULT_MAX_NOTES
}
72
73#[derive(Serialize)]
74pub struct TraverseResponse {
75    pub notes: Vec<TraversalResult>,
76}
77
78#[derive(Serialize)]
79pub struct NoteResponse {
80    pub path: String,
81    pub name: String,
82    #[serde(with = "yaml_as_json")]
83    pub frontmatter: serde_yaml::Value,
84    pub body: String,
85    pub links: Vec<String>,
86}
87
88// ---------------------------------------------------------------------------
89// Helpers
90// ---------------------------------------------------------------------------
91
92/// Build a manifest for the user's vault session.
93async fn user_manifest(
94    state: &Arc<AppState>,
95    user: &UserContext,
96) -> Result<Vec<ManifestEntry>, ApiError> {
97    let session_id = &user.lago_session_id;
98    let branch_id = lago_core::BranchId::from_string("main".to_string());
99
100    let query = lago_core::EventQuery::new()
101        .session(session_id.clone())
102        .branch(branch_id);
103    let events = state.journal.read(query).await?;
104
105    let mut manifest = lago_fs::Manifest::new();
106    for event in &events {
107        match &event.payload {
108            lago_core::EventPayload::FileWrite {
109                path,
110                blob_hash,
111                size_bytes,
112                content_type,
113            } => {
114                manifest.apply_write(
115                    path.clone(),
116                    lago_core::BlobHash::from_hex(blob_hash.as_str()),
117                    *size_bytes,
118                    content_type.clone(),
119                    event.timestamp,
120                );
121            }
122            lago_core::EventPayload::FileDelete { path } => {
123                manifest.apply_delete(path);
124            }
125            lago_core::EventPayload::FileRename { old_path, new_path } => {
126                manifest.apply_rename(old_path, new_path.clone());
127            }
128            _ => {}
129        }
130    }
131
132    Ok(manifest.entries().values().cloned().collect())
133}
134
135/// Build a knowledge index for the user's vault, with 30s TTL caching.
136fn build_knowledge_index(
137    manifest: &[ManifestEntry],
138    state: &Arc<AppState>,
139) -> Result<KnowledgeIndex, ApiError> {
140    KnowledgeIndex::build(manifest, &state.blob_store)
141        .map_err(|e| ApiError::Internal(format!("failed to build knowledge index: {e}")))
142}
143
/// Return `path` unchanged if it already begins with `/`; otherwise return
/// it with one `/` prepended.
fn normalize_path(path: &str) -> String {
    let mut normalized = String::with_capacity(path.len() + 1);
    if !path.starts_with('/') {
        normalized.push('/');
    }
    normalized.push_str(path);
    normalized
}
152
153// ---------------------------------------------------------------------------
154// Handlers
155// ---------------------------------------------------------------------------
156
157/// GET /v1/memory/manifest — list all files in the user's vault.
158pub async fn get_manifest(
159    State(state): State<Arc<AppState>>,
160    Extension(user): Extension<UserContext>,
161) -> Result<Json<ManifestResponse>, ApiError> {
162    let entries = user_manifest(&state, &user).await?;
163
164    Ok(Json(ManifestResponse {
165        session_id: user.lago_session_id.to_string(),
166        entries,
167    }))
168}
169
170/// GET /v1/memory/files/{*path} — read a file from the user's vault.
171pub async fn read_file(
172    State(state): State<Arc<AppState>>,
173    Extension(user): Extension<UserContext>,
174    Path(file_path): Path<String>,
175) -> Result<axum::http::Response<axum::body::Body>, ApiError> {
176    let file_path = normalize_path(&file_path);
177    let manifest = user_manifest(&state, &user).await?;
178
179    let entry = manifest
180        .iter()
181        .find(|e| e.path == file_path)
182        .ok_or_else(|| ApiError::NotFound(format!("file not found: {file_path}")))?;
183
184    let data = state
185        .blob_store
186        .get(&entry.blob_hash)
187        .map_err(|e| ApiError::Internal(format!("failed to read blob: {e}")))?;
188
189    let content_type = entry
190        .content_type
191        .clone()
192        .unwrap_or_else(|| "text/markdown".to_string());
193
194    Ok(axum::http::Response::builder()
195        .status(StatusCode::OK)
196        .header("content-type", content_type)
197        .header("x-blob-hash", entry.blob_hash.as_str())
198        .body(axum::body::Body::from(data))
199        .unwrap())
200}
201
202/// PUT /v1/memory/files/{*path} — write a file to the user's vault.
203pub async fn write_file(
204    State(state): State<Arc<AppState>>,
205    Extension(user): Extension<UserContext>,
206    Path(file_path): Path<String>,
207    body: Bytes,
208) -> Result<(StatusCode, Json<FileWriteResponse>), ApiError> {
209    let file_path = normalize_path(&file_path);
210    let session_id = user.lago_session_id.clone();
211    let branch_id = lago_core::BranchId::from_string("main".to_string());
212
213    let blob_hash = state
214        .blob_store
215        .put(&body)
216        .map_err(|e| ApiError::Internal(format!("failed to store blob: {e}")))?;
217
218    let size_bytes = body.len() as u64;
219
220    let event = lago_core::event::EventEnvelope {
221        event_id: lago_core::EventId::new(),
222        session_id,
223        branch_id,
224        run_id: None,
225        seq: 0,
226        timestamp: lago_core::event::EventEnvelope::now_micros(),
227        parent_id: None,
228        payload: lago_core::EventPayload::FileWrite {
229            path: file_path.clone(),
230            blob_hash: blob_hash.clone().into(),
231            size_bytes,
232            content_type: Some("text/markdown".to_string()),
233        },
234        metadata: std::collections::HashMap::new(),
235        schema_version: 1,
236    };
237
238    state.journal.append(event).await?;
239
240    Ok((
241        StatusCode::CREATED,
242        Json(FileWriteResponse {
243            path: file_path,
244            blob_hash: blob_hash.to_string(),
245            size_bytes,
246        }),
247    ))
248}
249
250/// DELETE /v1/memory/files/{*path} — delete a file from the user's vault.
251pub async fn delete_file(
252    State(state): State<Arc<AppState>>,
253    Extension(user): Extension<UserContext>,
254    Path(file_path): Path<String>,
255) -> Result<StatusCode, ApiError> {
256    let file_path = normalize_path(&file_path);
257    let session_id = user.lago_session_id.clone();
258    let branch_id = lago_core::BranchId::from_string("main".to_string());
259
260    let event = lago_core::event::EventEnvelope {
261        event_id: lago_core::EventId::new(),
262        session_id,
263        branch_id,
264        run_id: None,
265        seq: 0,
266        timestamp: lago_core::event::EventEnvelope::now_micros(),
267        parent_id: None,
268        payload: lago_core::EventPayload::FileDelete { path: file_path },
269        metadata: std::collections::HashMap::new(),
270        schema_version: 1,
271    };
272
273    state.journal.append(event).await?;
274
275    Ok(StatusCode::NO_CONTENT)
276}
277
278/// POST /v1/memory/search — search with scoring + optional graph traversal.
279pub async fn search(
280    State(state): State<Arc<AppState>>,
281    Extension(user): Extension<UserContext>,
282    Json(body): Json<SearchRequest>,
283) -> Result<Json<SearchResponse>, ApiError> {
284    let manifest = user_manifest(&state, &user).await?;
285    let index = build_knowledge_index(&manifest, &state)?;
286
287    let results = index.search(&body.query, body.max_results);
288
289    let linked_notes = if body.follow_links && !results.is_empty() {
290        // Collect unique wikilink targets from top results
291        let mut seen_paths: std::collections::HashSet<String> =
292            results.iter().map(|r| r.path.clone()).collect();
293
294        let mut linked = Vec::new();
295        for result in &results {
296            for link in &result.links {
297                if let Some(note) = index.resolve_wikilink(link) {
298                    if seen_paths.insert(note.path.clone()) {
299                        linked.push(LinkedNote {
300                            path: note.path.clone(),
301                            name: note.name.clone(),
302                            depth: 1,
303                            links: note.links.clone(),
304                        });
305                    }
306                }
307                if linked.len() >= 10 {
308                    break;
309                }
310            }
311        }
312
313        if linked.is_empty() {
314            None
315        } else {
316            Some(linked)
317        }
318    } else {
319        None
320    };
321
322    Ok(Json(SearchResponse {
323        results,
324        linked_notes,
325    }))
326}
327
328/// POST /v1/memory/traverse — BFS graph traversal from a note.
329pub async fn traverse(
330    State(state): State<Arc<AppState>>,
331    Extension(user): Extension<UserContext>,
332    Json(body): Json<TraverseRequest>,
333) -> Result<Json<TraverseResponse>, ApiError> {
334    let manifest = user_manifest(&state, &user).await?;
335    let index = build_knowledge_index(&manifest, &state)?;
336
337    let notes = index.traverse(&body.target, body.depth, body.max_notes);
338
339    Ok(Json(TraverseResponse { notes }))
340}
341
342/// GET /v1/memory/note/{name} — resolve a wikilink to a full note.
343pub async fn read_note(
344    State(state): State<Arc<AppState>>,
345    Extension(user): Extension<UserContext>,
346    Path(name): Path<String>,
347) -> Result<Json<NoteResponse>, ApiError> {
348    let manifest = user_manifest(&state, &user).await?;
349    let index = build_knowledge_index(&manifest, &state)?;
350
351    let note = index
352        .resolve_wikilink(&name)
353        .ok_or_else(|| ApiError::NotFound(format!("note not found: {name}")))?;
354
355    Ok(Json(NoteResponse {
356        path: note.path.clone(),
357        name: note.name.clone(),
358        frontmatter: note.frontmatter.clone(),
359        body: note.body.clone(),
360        links: note.links.clone(),
361    }))
362}
363
364// ---------------------------------------------------------------------------
365// Serde helper for YAML → JSON serialization
366// ---------------------------------------------------------------------------
367
368mod yaml_as_json {
369    use serde::{Serialize, Serializer};
370
371    pub fn serialize<S>(value: &serde_yaml::Value, serializer: S) -> Result<S::Ok, S::Error>
372    where
373        S: Serializer,
374    {
375        let json = yaml_to_json(value);
376        json.serialize(serializer)
377    }
378
379    fn yaml_to_json(value: &serde_yaml::Value) -> serde_json::Value {
380        match value {
381            serde_yaml::Value::Null => serde_json::Value::Null,
382            serde_yaml::Value::Bool(b) => serde_json::Value::Bool(*b),
383            serde_yaml::Value::Number(n) => {
384                if let Some(i) = n.as_i64() {
385                    serde_json::Value::Number(i.into())
386                } else if let Some(f) = n.as_f64() {
387                    serde_json::json!(f)
388                } else {
389                    serde_json::Value::Null
390                }
391            }
392            serde_yaml::Value::String(s) => serde_json::Value::String(s.clone()),
393            serde_yaml::Value::Sequence(seq) => {
394                serde_json::Value::Array(seq.iter().map(yaml_to_json).collect())
395            }
396            serde_yaml::Value::Mapping(map) => {
397                let mut obj = serde_json::Map::new();
398                for (k, v) in map {
399                    let key = match k {
400                        serde_yaml::Value::String(s) => s.clone(),
401                        _ => format!("{k:?}"),
402                    };
403                    obj.insert(key, yaml_to_json(v));
404                }
405                serde_json::Value::Object(obj)
406            }
407            serde_yaml::Value::Tagged(tagged) => yaml_to_json(&tagged.value),
408        }
409    }
410}