Skip to main content

recall_echo/graph/
query.rs

//! Hybrid query — combines semantic search, graph expansion, and episode search.
//!
//! Pipeline:
//! 1. **Semantic phase**: HNSW KNN with `limit * 2` to gather candidates
//! 2. **Graph phase**: 1-hop expansion from the top 3 results, scored as
//!    `parent_score * confidence` (the confidence of the connecting relationship)
//! 3. **Merge + deduplicate** by entity ID (a semantic hit always wins over a
//!    graph hit), then rank by score and truncate to `limit`
//! 4. **Episode search** (optional) — separate KNN on episodes
8
9use std::collections::HashMap;
10
11use surrealdb::Surreal;
12
13use super::embed::Embedder;
14use super::error::GraphError;
15use super::store::Db;
16use super::types::*;
17
18/// Run a hybrid query: semantic search + graph expansion + optional episode search.
19pub async fn query(
20    db: &Surreal<Db>,
21    embedder: &dyn Embedder,
22    query_text: &str,
23    options: &QueryOptions,
24) -> Result<QueryResult, GraphError> {
25    let limit = if options.limit == 0 {
26        10
27    } else {
28        options.limit
29    };
30
31    // Phase 1: Semantic search with 2x limit to get candidates
32    let semantic_options = SearchOptions {
33        limit: limit * 2,
34        entity_type: options.entity_type.clone(),
35        keyword: options.keyword.clone(),
36    };
37    let semantic_results =
38        super::search::search_with_options(db, embedder, query_text, &semantic_options).await?;
39
40    // Collect into dedup map (id -> ScoredEntity)
41    let mut entity_map: HashMap<String, ScoredEntity> = HashMap::new();
42    for result in semantic_results {
43        entity_map.insert(result.entity.id_string(), result);
44    }
45
46    // Phase 2: Graph expansion — 1-hop from top-N semantic results
47    if options.graph_depth > 0 {
48        let top_n: Vec<(String, f64)> = {
49            let mut entries: Vec<_> = entity_map
50                .values()
51                .map(|e| (e.entity.id_string(), e.score))
52                .collect();
53            entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
54            entries.truncate(3); // Expand from top 3
55            entries
56        };
57
58        for (parent_id, parent_score) in &top_n {
59            let parent_name = entity_map
60                .get(parent_id)
61                .map(|e| e.entity.name.clone())
62                .unwrap_or_default();
63
64            let neighbors = get_neighbor_details(db, parent_id).await?;
65
66            for (neighbor, rel_type, confidence) in neighbors {
67                let neighbor_id = neighbor.id_string();
68                if entity_map.contains_key(&neighbor_id) {
69                    continue; // Already in results
70                }
71
72                // Apply type filter
73                if let Some(ref et) = options.entity_type {
74                    if neighbor.entity_type.to_string() != *et {
75                        continue;
76                    }
77                }
78
79                let graph_score = parent_score * confidence;
80                entity_map.insert(
81                    neighbor_id,
82                    ScoredEntity {
83                        entity: neighbor,
84                        score: graph_score,
85                        source: MatchSource::Graph {
86                            parent: parent_name.clone(),
87                            rel_type,
88                        },
89                    },
90                );
91            }
92        }
93    }
94
95    // Sort by score descending, truncate to limit
96    let mut entities: Vec<ScoredEntity> = entity_map.into_values().collect();
97    entities.sort_by(|a, b| {
98        b.score
99            .partial_cmp(&a.score)
100            .unwrap_or(std::cmp::Ordering::Equal)
101    });
102    entities.truncate(limit);
103
104    // Phase 3: Episode search (optional)
105    let episodes = if options.include_episodes {
106        super::search::search_episodes(db, embedder, query_text, limit).await?
107    } else {
108        vec![]
109    };
110
111    Ok(QueryResult { entities, episodes })
112}
113
114/// Get 1-hop neighbors as L1 (EntityDetail) with the relationship type and confidence.
115async fn get_neighbor_details(
116    db: &Surreal<Db>,
117    entity_id: &str,
118) -> Result<Vec<(EntityDetail, String, f64)>, GraphError> {
119    // Outgoing
120    let mut response = db
121        .query(
122            r#"
123            SELECT rel_type, confidence, out AS target_id
124            FROM relates_to
125            WHERE in = type::record($id) AND valid_until IS NONE AND confidence >= 0.1
126            "#,
127        )
128        .bind(("id", entity_id.to_string()))
129        .await?;
130
131    let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
132
133    // Incoming
134    let mut response = db
135        .query(
136            r#"
137            SELECT rel_type, confidence, in AS target_id
138            FROM relates_to
139            WHERE out = type::record($id) AND valid_until IS NONE AND confidence >= 0.1
140            "#,
141        )
142        .bind(("id", entity_id.to_string()))
143        .await?;
144
145    let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
146
147    let mut results = Vec::new();
148    let all_edges: Vec<_> = outgoing.into_iter().chain(incoming).collect();
149
150    for edge in all_edges {
151        let tid = match &edge.target_id {
152            serde_json::Value::String(s) => s.clone(),
153            other => other.to_string(),
154        };
155
156        if let Some(detail) = super::crud::get_entity_detail(db, &tid).await? {
157            results.push((detail, edge.rel_type, edge.confidence));
158        }
159    }
160
161    Ok(results)
162}
163
/// Serde fallback for `RelTarget::confidence`: an edge row that carries no
/// `confidence` column is treated as fully confident.
fn default_rel_confidence() -> f64 {
    1.0_f64
}
167
168#[derive(serde::Deserialize)]
169struct RelTarget {
170    rel_type: String,
171    target_id: serde_json::Value,
172    #[serde(default = "default_rel_confidence")]
173    confidence: f64,
174}
175
176// ── Pipeline queries ─────────────────────────────────────────────────
177
178/// Get all pipeline entities for a given stage, optionally filtered by status.
179pub async fn pipeline_entities(
180    db: &Surreal<Db>,
181    stage: &str,
182    status: Option<&str>,
183) -> Result<Vec<EntityDetail>, GraphError> {
184    let query = match status {
185        Some(_) => {
186            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
187               FROM entity
188               WHERE attributes.pipeline_stage = $stage
189                 AND attributes.pipeline_status = $status
190               ORDER BY updated_at DESC"#
191        }
192        None => {
193            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
194               FROM entity
195               WHERE attributes.pipeline_stage = $stage
196               ORDER BY updated_at DESC"#
197        }
198    };
199
200    let stage_owned = stage.to_string();
201    let mut response = match status {
202        Some(s) => {
203            let status_owned = s.to_string();
204            db.query(query)
205                .bind(("stage", stage_owned))
206                .bind(("status", status_owned))
207                .await?
208        }
209        None => db.query(query).bind(("stage", stage_owned)).await?,
210    };
211
212    let entities: Vec<EntityDetail> = super::deserialize_take(&mut response, 0)?;
213    Ok(entities)
214}
215
216/// Get pipeline stats: counts by (stage, status), stale entities.
217pub async fn pipeline_stats(
218    db: &Surreal<Db>,
219    staleness_days: u32,
220) -> Result<PipelineGraphStats, GraphError> {
221    // Count by stage and status
222    let mut response = db
223        .query(
224            r#"SELECT
225                 attributes.pipeline_stage AS stage,
226                 attributes.pipeline_status AS status,
227                 count() AS count
228               FROM entity
229               WHERE attributes.pipeline_stage IS NOT NONE
230               GROUP BY attributes.pipeline_stage, attributes.pipeline_status"#,
231        )
232        .await?;
233
234    let rows: Vec<StageStatusCount> = super::deserialize_take(&mut response, 0)?;
235
236    let mut by_stage: std::collections::HashMap<String, std::collections::HashMap<String, u64>> =
237        std::collections::HashMap::new();
238    let mut total = 0u64;
239
240    for row in rows {
241        total += row.count;
242        by_stage
243            .entry(row.stage)
244            .or_default()
245            .insert(row.status, row.count);
246    }
247
248    // Find stale thoughts (active, not updated in staleness_days)
249    let mut stale_response = db
250        .query(
251            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
252               FROM entity
253               WHERE attributes.pipeline_stage = 'thoughts'
254                 AND attributes.pipeline_status = 'active'
255                 AND updated_at < time::now() - type::duration($threshold)
256               ORDER BY updated_at ASC"#,
257        )
258        .bind(("threshold", format!("{}d", staleness_days)))
259        .await?;
260
261    let stale_thoughts: Vec<EntityDetail> = super::deserialize_take(&mut stale_response, 0)?;
262
263    // Find stale questions
264    let mut stale_q_response = db
265        .query(
266            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
267               FROM entity
268               WHERE attributes.pipeline_stage = 'curiosity'
269                 AND attributes.pipeline_status = 'active'
270                 AND attributes.sub_type IS NONE
271                 AND updated_at < time::now() - type::duration($threshold)
272               ORDER BY updated_at ASC"#,
273        )
274        .bind(("threshold", format!("{}d", staleness_days * 2)))
275        .await?;
276
277    let stale_questions: Vec<EntityDetail> = super::deserialize_take(&mut stale_q_response, 0)?;
278
279    // Last movement (most recent graduated/dissolved/explored entity)
280    let mut movement_response = db
281        .query(
282            r#"SELECT updated_at
283               FROM entity
284               WHERE attributes.pipeline_status IN ['graduated', 'dissolved', 'explored']
285               ORDER BY updated_at DESC
286               LIMIT 1"#,
287        )
288        .await?;
289
290    let movement_rows: Vec<UpdatedAtRow> = super::deserialize_take(&mut movement_response, 0)?;
291    let last_movement = movement_rows.first().map(|r| match &r.updated_at {
292        serde_json::Value::String(s) => s.clone(),
293        other => other.to_string(),
294    });
295
296    Ok(PipelineGraphStats {
297        by_stage,
298        stale_thoughts,
299        stale_questions,
300        total_entities: total,
301        last_movement,
302    })
303}
304
305/// Trace the lineage of a pipeline entity through relationship chains.
306pub async fn pipeline_flow(
307    db: &Surreal<Db>,
308    entity_name: &str,
309) -> Result<Vec<(EntityDetail, String, EntityDetail)>, GraphError> {
310    // Get the entity
311    let entity = super::crud::get_entity_by_name(db, entity_name)
312        .await?
313        .ok_or_else(|| GraphError::NotFound(format!("entity: {}", entity_name)))?;
314
315    let entity_id = entity.id_string();
316    let mut chain = Vec::new();
317
318    // Get all pipeline relationships (both directions)
319    let pipeline_rel_types = [
320        "EVOLVED_FROM",
321        "CRYSTALLIZED_FROM",
322        "INFORMED_BY",
323        "GRADUATED_TO",
324        "CONNECTED_TO",
325        "EXPLORES",
326        "ARCHIVED_FROM",
327    ];
328    let rel_types_str = pipeline_rel_types
329        .iter()
330        .map(|r| format!("'{}'", r))
331        .collect::<Vec<_>>()
332        .join(", ");
333
334    // Outgoing relationships
335    let query_out = format!(
336        r#"SELECT rel_type, out AS target_id
337           FROM relates_to
338           WHERE in = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
339        rel_types_str
340    );
341    let mut response = db.query(&query_out).bind(("id", entity_id.clone())).await?;
342    let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
343
344    for edge in &outgoing {
345        let tid = match &edge.target_id {
346            serde_json::Value::String(s) => s.clone(),
347            other => other.to_string(),
348        };
349        if let Some(target) = super::crud::get_entity_detail(db, &tid).await? {
350            let source_detail = super::crud::get_entity_detail(db, &entity_id)
351                .await?
352                .unwrap();
353            chain.push((source_detail, edge.rel_type.clone(), target));
354        }
355    }
356
357    // Incoming relationships
358    let query_in = format!(
359        r#"SELECT rel_type, in AS target_id
360           FROM relates_to
361           WHERE out = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
362        rel_types_str
363    );
364    let mut response = db.query(&query_in).bind(("id", entity_id.clone())).await?;
365    let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
366
367    for edge in &incoming {
368        let tid = match &edge.target_id {
369            serde_json::Value::String(s) => s.clone(),
370            other => other.to_string(),
371        };
372        if let Some(source) = super::crud::get_entity_detail(db, &tid).await? {
373            let target_detail = super::crud::get_entity_detail(db, &entity_id)
374                .await?
375                .unwrap();
376            chain.push((source, edge.rel_type.clone(), target_detail));
377        }
378    }
379
380    Ok(chain)
381}
382
383#[derive(serde::Deserialize)]
384struct StageStatusCount {
385    stage: String,
386    status: String,
387    count: u64,
388}
389
390#[derive(serde::Deserialize)]
391struct UpdatedAtRow {
392    updated_at: serde_json::Value,
393}