Skip to main content

construct/gateway/
api_memory_graph.rs

1//! Aggregated memory graph endpoint for the Memory Auditor.
2//!
3//! `GET /api/memory/graph` — returns items + edges + spaces in one payload,
4//! ready for the Obsidian-style force-graph visualization.
5//!
6//! **Primary path**: Operator MCP tool (`memory_graph`) via direct SDK/gRPC.
7//! **Fallback path**: HTTP calls to Kumiho FastAPI (used when operator unavailable).
8
9use super::AppState;
10use super::api::require_auth;
11use super::api_agents::build_kumiho_client;
12use super::kumiho_client::ItemResponse;
13use axum::{
14    extract::{Query, State},
15    http::{HeaderMap, StatusCode},
16    response::{IntoResponse, Json},
17};
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21// ── Query parameters ────────────────────────────────────────────────────
22
23#[derive(Deserialize)]
24pub struct MemoryGraphQuery {
25    /// Kumiho project name (default: "CognitiveMemory").
26    pub project: Option<String>,
27    /// Maximum number of items to include (default 100, max 500).
28    pub limit: Option<u32>,
29    /// Comma-separated kind filter (e.g. "decision,fact,preference").
30    pub kinds: Option<String>,
31    /// Space path filter — only include items from this space.
32    pub space: Option<String>,
33    /// Sort mode: "recent" (default), "name".
34    pub sort: Option<String>,
35    /// Search query — if provided, filters to matching items via fulltext search.
36    pub search: Option<String>,
37}
38
39// ── Response types ──────────────────────────────────────────────────────
40
41#[derive(Serialize, Deserialize)]
42pub struct MemoryGraphResponse {
43    pub nodes: Vec<GraphNode>,
44    pub edges: Vec<GraphEdge>,
45    pub spaces: Vec<String>,
46    pub stats: GraphStats,
47}
48
49#[derive(Serialize, Deserialize)]
50pub struct GraphNode {
51    pub id: String,
52    pub name: String,
53    pub kind: String,
54    pub space: String,
55    pub created_at: Option<String>,
56    pub title: Option<String>,
57    pub summary: Option<String>,
58    pub revision_kref: Option<String>,
59}
60
61#[derive(Serialize, Deserialize)]
62pub struct GraphEdge {
63    pub source: String,
64    pub target: String,
65    pub edge_type: String,
66    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
67    pub metadata: HashMap<String, String>,
68}
69
70#[derive(Serialize, Deserialize)]
71pub struct GraphStats {
72    pub total_items: usize,
73    pub total_edges: usize,
74    pub kinds: HashMap<String, usize>,
75}
76
77// ── Handler ─────────────────────────────────────────────────────────────
78
79pub async fn handle_memory_graph(
80    State(state): State<AppState>,
81    headers: HeaderMap,
82    Query(query): Query<MemoryGraphQuery>,
83) -> impl IntoResponse {
84    if let Err(e) = require_auth(&state, &headers) {
85        return e.into_response();
86    }
87
88    // Build MCP tool arguments from query params
89    let mut mcp_args = serde_json::Map::new();
90    if let Some(ref p) = query.project {
91        mcp_args.insert("project".into(), serde_json::Value::String(p.clone()));
92    }
93    if let Some(l) = query.limit {
94        mcp_args.insert(
95            "limit".into(),
96            serde_json::Value::Number(serde_json::Number::from(l)),
97        );
98    }
99    if let Some(ref k) = query.kinds {
100        mcp_args.insert("kinds".into(), serde_json::Value::String(k.clone()));
101    }
102    if let Some(ref s) = query.space {
103        mcp_args.insert("space".into(), serde_json::Value::String(s.clone()));
104    }
105    if let Some(ref s) = query.sort {
106        mcp_args.insert("sort".into(), serde_json::Value::String(s.clone()));
107    }
108    if let Some(ref s) = query.search {
109        mcp_args.insert("search".into(), serde_json::Value::String(s.clone()));
110    }
111
112    // Try operator MCP tool first (direct SDK, no HTTP hop).
113    // Cap at 45s — the memory graph route has its own 60s timeout, not the global 30s.
114    if let Some(ref registry) = state.mcp_registry {
115        let tool_name = format!(
116            "{}__memory_graph",
117            crate::agent::operator::OPERATOR_SERVER_NAME
118        );
119        let mcp_future =
120            registry.call_tool(&tool_name, serde_json::Value::Object(mcp_args.clone()));
121
122        match tokio::time::timeout(std::time::Duration::from_secs(45), mcp_future).await {
123            Ok(Ok(result_str)) => {
124                // MCP tools/call returns {"content": [{"type":"text","text":"..."}]}
125                // Extract the inner text, then parse as MemoryGraphResponse.
126                if let Ok(wrapper) = serde_json::from_str::<serde_json::Value>(&result_str) {
127                    let inner_json = wrapper
128                        .get("content")
129                        .and_then(|c| c.as_array())
130                        .and_then(|arr| arr.first())
131                        .and_then(|item| item.get("text"))
132                        .and_then(|t| t.as_str())
133                        .and_then(|text| serde_json::from_str::<serde_json::Value>(text).ok());
134
135                    if let Some(val) = inner_json {
136                        if val.get("error").and_then(|e| e.as_str()).is_none() {
137                            if let Ok(response) = serde_json::from_value::<MemoryGraphResponse>(val)
138                            {
139                                tracing::info!(
140                                    "memory_graph: operator MCP path succeeded ({} nodes, {} edges)",
141                                    response.nodes.len(),
142                                    response.edges.len()
143                                );
144                                return (StatusCode::OK, Json(response)).into_response();
145                            }
146                        }
147                        tracing::warn!(
148                            "memory_graph: operator returned error or unparseable inner JSON"
149                        );
150                    } else {
151                        tracing::warn!(
152                            "memory_graph: could not extract text from MCP content wrapper"
153                        );
154                    }
155                }
156                // Fall through to HTTP fallback
157            }
158            Ok(Err(e)) => {
159                tracing::warn!("memory_graph: operator tool call failed: {e:#}");
160            }
161            Err(_) => {
162                tracing::warn!("memory_graph: operator tool call timed out (45s)");
163            }
164        }
165    }
166
167    // Fallback: HTTP calls to Kumiho FastAPI
168    http_fallback_memory_graph(&state, &query).await
169}
170
171// ── HTTP Fallback ───────────────────────────────────────────────────────
172
/// Returns `kref` without a leading `kref://` scheme.
///
/// Input that does not start with the scheme is returned unchanged. Only one leading
/// occurrence is removed.
fn strip_kref_scheme(kref: &str) -> &str {
    const SCHEME: &str = "kref://";
    match kref.strip_prefix(SCHEME) {
        Some(rest) => rest,
        None => kref,
    }
}
177
178/// Extract the item-level ID from a revision kref.
179fn revision_kref_to_item_id(rev_kref: &str) -> String {
180    let stripped = strip_kref_scheme(rev_kref);
181    stripped.split('?').next().unwrap_or(stripped).to_string()
182}
183
184/// Extract `space_path` from an item kref.
185fn item_kref_to_space(kref: &str) -> String {
186    let stripped = strip_kref_scheme(kref);
187    match stripped.rfind('/') {
188        Some(pos) => stripped[..pos].to_string(),
189        None => String::new(),
190    }
191}
192
193fn item_to_node(
194    item: &ItemResponse,
195    rev_title: Option<&str>,
196    rev_summary: Option<&str>,
197    rev_kref: Option<&str>,
198) -> GraphNode {
199    let id = strip_kref_scheme(&item.kref).to_string();
200    let space = item_kref_to_space(&item.kref);
201    GraphNode {
202        id,
203        name: item.item_name.clone(),
204        kind: item.kind.clone(),
205        space,
206        created_at: item.created_at.clone(),
207        title: rev_title.map(|s| s.to_string()),
208        summary: rev_summary.map(|s| s.to_string()),
209        revision_kref: rev_kref.map(|s| s.to_string()),
210    }
211}
212
213async fn http_fallback_memory_graph(
214    state: &AppState,
215    query: &MemoryGraphQuery,
216) -> axum::response::Response {
217    let client = build_kumiho_client(state);
218    let default_project = {
219        let config = state.config.lock();
220        config.kumiho.memory_project.clone()
221    };
222    let project = query.project.as_deref().unwrap_or(&default_project);
223    let limit = query.limit.unwrap_or(100).min(500) as usize;
224    let kind_filter: Vec<String> = query
225        .kinds
226        .as_deref()
227        .unwrap_or("")
228        .split(',')
229        .map(|s| s.trim().to_string())
230        .filter(|s| !s.is_empty())
231        .collect();
232    let space_filter = query.space.as_deref().unwrap_or("");
233    let sort_mode = query.sort.as_deref().unwrap_or("recent");
234    let search_query = query.search.as_deref().unwrap_or("");
235
236    // 1. List all spaces recursively
237    let root_path = format!("/{project}");
238    let spaces_result = client.list_spaces(&root_path, true).await;
239    let space_paths: Vec<String> = match spaces_result {
240        Ok(spaces) => {
241            let mut paths = vec![root_path.clone()];
242            paths.extend(spaces.into_iter().map(|s| s.path));
243            paths
244        }
245        Err(e) => {
246            return (
247                StatusCode::BAD_GATEWAY,
248                Json(serde_json::json!({ "error": format!("Failed to list spaces: {e}") })),
249            )
250                .into_response();
251        }
252    };
253
254    let target_spaces: Vec<&str> = if space_filter.is_empty() {
255        space_paths.iter().map(|s| s.as_str()).collect()
256    } else {
257        space_paths
258            .iter()
259            .filter(|s| s.starts_with(space_filter) || *s == space_filter)
260            .map(|s| s.as_str())
261            .collect()
262    };
263
264    // 2. Fetch items
265    let mut all_items: Vec<ItemResponse> = Vec::new();
266
267    if !search_query.is_empty() {
268        match client.search_items(search_query, project, "", false).await {
269            Ok(results) => {
270                all_items = results.into_iter().map(|r| r.item).collect();
271            }
272            Err(e) => {
273                return (
274                    StatusCode::BAD_GATEWAY,
275                    Json(serde_json::json!({ "error": format!("Search failed: {e}") })),
276                )
277                    .into_response();
278            }
279        }
280    } else {
281        for chunk in target_spaces.chunks(10) {
282            let futs: Vec<_> = chunk
283                .iter()
284                .map(|sp| {
285                    let c = client.clone();
286                    let sp = sp.to_string();
287                    async move {
288                        c.list_items_paged(&sp, false, 200, 0)
289                            .await
290                            .unwrap_or_default()
291                    }
292                })
293                .collect();
294            let results = futures_util::future::join_all(futs).await;
295            for items in results {
296                all_items.extend(items);
297            }
298            if all_items.len() > limit * 2 {
299                break;
300            }
301        }
302    }
303
304    // 3. Apply kind filter
305    if !kind_filter.is_empty() {
306        all_items.retain(|item| kind_filter.contains(&item.kind));
307    }
308
309    // 4. Sort
310    match sort_mode {
311        "name" => all_items.sort_by(|a, b| a.item_name.cmp(&b.item_name)),
312        _ => {
313            all_items.sort_by(|a, b| {
314                let a_date = a.created_at.as_deref().unwrap_or("");
315                let b_date = b.created_at.as_deref().unwrap_or("");
316                b_date.cmp(a_date)
317            });
318        }
319    }
320
321    let mut kind_counts: HashMap<String, usize> = HashMap::new();
322    for item in &all_items {
323        *kind_counts.entry(item.kind.clone()).or_insert(0) += 1;
324    }
325    let total_items_count = all_items.len();
326
327    // 5. Truncate
328    all_items.truncate(limit);
329
330    // 6. Batch-fetch revisions
331    let item_krefs: Vec<String> = all_items.iter().map(|i| i.kref.clone()).collect();
332    let rev_map = client
333        .batch_get_revisions(&item_krefs, "latest")
334        .await
335        .unwrap_or_default();
336
337    // 7. Build nodes
338    let mut nodes: Vec<GraphNode> = Vec::with_capacity(all_items.len());
339    let mut item_id_set: std::collections::HashSet<String> = std::collections::HashSet::new();
340    let mut rev_krefs: Vec<String> = Vec::new();
341
342    for item in &all_items {
343        let rev = rev_map.get(&item.kref);
344        let title = rev.and_then(|r| r.metadata.get("title").map(|s| s.as_str()));
345        let summary = rev.and_then(|r| r.metadata.get("summary").map(|s| s.as_str()));
346        let rev_kref = rev.map(|r| r.kref.as_str());
347        nodes.push(item_to_node(item, title, summary, rev_kref));
348        item_id_set.insert(strip_kref_scheme(&item.kref).to_string());
349        if let Some(r) = rev {
350            rev_krefs.push(r.kref.clone());
351        }
352    }
353
354    // 8. Fetch edges
355    let mut edge_results = Vec::new();
356    for chunk in rev_krefs.chunks(10) {
357        let futs: Vec<_> = chunk
358            .iter()
359            .map(|rk| {
360                let c = client.clone();
361                let rk = rk.clone();
362                async move {
363                    c.list_edges(&rk, None, Some("both"))
364                        .await
365                        .unwrap_or_default()
366                }
367            })
368            .collect();
369        edge_results.extend(futures_util::future::join_all(futs).await);
370    }
371
372    // 9. Deduplicate edges
373    let mut seen_edges: std::collections::HashSet<(String, String, String)> =
374        std::collections::HashSet::new();
375    let mut edges: Vec<GraphEdge> = Vec::new();
376
377    for edge_list in edge_results {
378        for edge in edge_list {
379            let source_id = revision_kref_to_item_id(&edge.source_kref);
380            let target_id = revision_kref_to_item_id(&edge.target_kref);
381            if source_id == target_id {
382                continue;
383            }
384            if !item_id_set.contains(&source_id) || !item_id_set.contains(&target_id) {
385                continue;
386            }
387            let key = (source_id.clone(), target_id.clone(), edge.edge_type.clone());
388            if seen_edges.contains(&key) {
389                continue;
390            }
391            seen_edges.insert(key);
392            edges.push(GraphEdge {
393                source: source_id,
394                target: target_id,
395                edge_type: edge.edge_type,
396                metadata: edge.metadata.unwrap_or_default(),
397            });
398        }
399    }
400
401    let total_edges = edges.len();
402
403    let response = MemoryGraphResponse {
404        nodes,
405        edges,
406        spaces: space_paths
407            .into_iter()
408            .map(|s| s.trim_start_matches('/').to_string())
409            .collect(),
410        stats: GraphStats {
411            total_items: total_items_count,
412            total_edges,
413            kinds: kind_counts,
414        },
415    };
416
417    (StatusCode::OK, Json(response)).into_response()
418}