Skip to main content

recall_echo/graph/
query.rs

1//! Hybrid query — combines semantic search, graph expansion, and episode search.
2//!
3//! Pipeline:
4//! 1. **Semantic phase**: HNSW KNN with `limit * 2` to gather candidates
5//! 2. **Graph phase**: 1-hop expansion from top-N results, scored as `parent_score * 0.5`
6//! 3. **Merge + deduplicate** by entity ID, keeping highest score
7//! 4. **Episode search** (optional) — separate KNN on episodes
8
9use std::collections::HashMap;
10
11use surrealdb::Surreal;
12
13use super::embed::Embedder;
14use super::error::GraphError;
15use super::store::Db;
16use super::types::*;
17
18/// Run a hybrid query: semantic search + graph expansion + optional episode search.
19pub async fn query(
20    db: &Surreal<Db>,
21    embedder: &dyn Embedder,
22    query_text: &str,
23    options: &QueryOptions,
24) -> Result<QueryResult, GraphError> {
25    let limit = if options.limit == 0 {
26        10
27    } else {
28        options.limit
29    };
30
31    // Phase 1: Semantic search with 2x limit to get candidates
32    let semantic_options = SearchOptions {
33        limit: limit * 2,
34        entity_type: options.entity_type.clone(),
35        keyword: options.keyword.clone(),
36    };
37    let semantic_results =
38        super::search::search_with_options(db, embedder, query_text, &semantic_options).await?;
39
40    // Collect into dedup map (id -> ScoredEntity)
41    let mut entity_map: HashMap<String, ScoredEntity> = HashMap::new();
42    for result in semantic_results {
43        entity_map.insert(result.entity.id_string(), result);
44    }
45
46    // Phase 2: Graph expansion — 1-hop from top-N semantic results
47    if options.graph_depth > 0 {
48        let top_n: Vec<(String, f64)> = {
49            let mut entries: Vec<_> = entity_map
50                .values()
51                .map(|e| (e.entity.id_string(), e.score))
52                .collect();
53            entries.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
54            entries.truncate(3); // Expand from top 3
55            entries
56        };
57
58        for (parent_id, parent_score) in &top_n {
59            let parent_name = entity_map
60                .get(parent_id)
61                .map(|e| e.entity.name.clone())
62                .unwrap_or_default();
63
64            let neighbors = get_neighbor_details(db, parent_id).await?;
65
66            for (neighbor, rel_type) in neighbors {
67                let neighbor_id = neighbor.id_string();
68                if entity_map.contains_key(&neighbor_id) {
69                    continue; // Already in results
70                }
71
72                // Apply type filter
73                if let Some(ref et) = options.entity_type {
74                    if neighbor.entity_type.to_string() != *et {
75                        continue;
76                    }
77                }
78
79                let graph_score = parent_score * 0.5;
80                entity_map.insert(
81                    neighbor_id,
82                    ScoredEntity {
83                        entity: neighbor,
84                        score: graph_score,
85                        source: MatchSource::Graph {
86                            parent: parent_name.clone(),
87                            rel_type,
88                        },
89                    },
90                );
91            }
92        }
93    }
94
95    // Sort by score descending, truncate to limit
96    let mut entities: Vec<ScoredEntity> = entity_map.into_values().collect();
97    entities.sort_by(|a, b| {
98        b.score
99            .partial_cmp(&a.score)
100            .unwrap_or(std::cmp::Ordering::Equal)
101    });
102    entities.truncate(limit);
103
104    // Phase 3: Episode search (optional)
105    let episodes = if options.include_episodes {
106        super::search::search_episodes(db, embedder, query_text, limit).await?
107    } else {
108        vec![]
109    };
110
111    Ok(QueryResult { entities, episodes })
112}
113
114/// Get 1-hop neighbors as L1 (EntityDetail) with the relationship type.
115async fn get_neighbor_details(
116    db: &Surreal<Db>,
117    entity_id: &str,
118) -> Result<Vec<(EntityDetail, String)>, GraphError> {
119    // Outgoing
120    let mut response = db
121        .query(
122            r#"
123            SELECT rel_type, out AS target_id
124            FROM relates_to
125            WHERE in = type::record($id) AND valid_until IS NONE
126            "#,
127        )
128        .bind(("id", entity_id.to_string()))
129        .await?;
130
131    let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
132
133    // Incoming
134    let mut response = db
135        .query(
136            r#"
137            SELECT rel_type, in AS target_id
138            FROM relates_to
139            WHERE out = type::record($id) AND valid_until IS NONE
140            "#,
141        )
142        .bind(("id", entity_id.to_string()))
143        .await?;
144
145    let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
146
147    let mut results = Vec::new();
148    let all_edges: Vec<_> = outgoing.into_iter().chain(incoming).collect();
149
150    for edge in all_edges {
151        let tid = match &edge.target_id {
152            serde_json::Value::String(s) => s.clone(),
153            other => other.to_string(),
154        };
155
156        if let Some(detail) = super::crud::get_entity_detail(db, &tid).await? {
157            results.push((detail, edge.rel_type));
158        }
159    }
160
161    Ok(results)
162}
163
/// One row of a `relates_to` edge query: the relationship type plus the id of the
/// entity at the other end of the edge.
#[derive(serde::Deserialize)]
struct RelTarget {
    /// Relationship type label stored on the edge.
    rel_type: String,
    /// Id of the opposite endpoint (`out` or `in`, aliased as `target_id` by the query).
    /// Kept as raw JSON; callers use a string value verbatim and stringify anything else.
    target_id: serde_json::Value,
}
169
170// ── Pipeline queries ─────────────────────────────────────────────────
171
172/// Get all pipeline entities for a given stage, optionally filtered by status.
173pub async fn pipeline_entities(
174    db: &Surreal<Db>,
175    stage: &str,
176    status: Option<&str>,
177) -> Result<Vec<EntityDetail>, GraphError> {
178    let query = match status {
179        Some(_) => {
180            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
181               FROM entity
182               WHERE attributes.pipeline_stage = $stage
183                 AND attributes.pipeline_status = $status
184               ORDER BY updated_at DESC"#
185        }
186        None => {
187            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
188               FROM entity
189               WHERE attributes.pipeline_stage = $stage
190               ORDER BY updated_at DESC"#
191        }
192    };
193
194    let stage_owned = stage.to_string();
195    let mut response = match status {
196        Some(s) => {
197            let status_owned = s.to_string();
198            db.query(query)
199                .bind(("stage", stage_owned))
200                .bind(("status", status_owned))
201                .await?
202        }
203        None => db.query(query).bind(("stage", stage_owned)).await?,
204    };
205
206    let entities: Vec<EntityDetail> = super::deserialize_take(&mut response, 0)?;
207    Ok(entities)
208}
209
210/// Get pipeline stats: counts by (stage, status), stale entities.
211pub async fn pipeline_stats(
212    db: &Surreal<Db>,
213    staleness_days: u32,
214) -> Result<PipelineGraphStats, GraphError> {
215    // Count by stage and status
216    let mut response = db
217        .query(
218            r#"SELECT
219                 attributes.pipeline_stage AS stage,
220                 attributes.pipeline_status AS status,
221                 count() AS count
222               FROM entity
223               WHERE attributes.pipeline_stage IS NOT NONE
224               GROUP BY attributes.pipeline_stage, attributes.pipeline_status"#,
225        )
226        .await?;
227
228    let rows: Vec<StageStatusCount> = super::deserialize_take(&mut response, 0)?;
229
230    let mut by_stage: std::collections::HashMap<String, std::collections::HashMap<String, u64>> =
231        std::collections::HashMap::new();
232    let mut total = 0u64;
233
234    for row in rows {
235        total += row.count;
236        by_stage
237            .entry(row.stage)
238            .or_default()
239            .insert(row.status, row.count);
240    }
241
242    // Find stale thoughts (active, not updated in staleness_days)
243    let mut stale_response = db
244        .query(
245            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
246               FROM entity
247               WHERE attributes.pipeline_stage = 'thoughts'
248                 AND attributes.pipeline_status = 'active'
249                 AND updated_at < time::now() - type::duration($threshold)
250               ORDER BY updated_at ASC"#,
251        )
252        .bind(("threshold", format!("{}d", staleness_days)))
253        .await?;
254
255    let stale_thoughts: Vec<EntityDetail> = super::deserialize_take(&mut stale_response, 0)?;
256
257    // Find stale questions
258    let mut stale_q_response = db
259        .query(
260            r#"SELECT id, name, entity_type, abstract, overview, attributes, access_count, updated_at, source
261               FROM entity
262               WHERE attributes.pipeline_stage = 'curiosity'
263                 AND attributes.pipeline_status = 'active'
264                 AND attributes.sub_type IS NONE
265                 AND updated_at < time::now() - type::duration($threshold)
266               ORDER BY updated_at ASC"#,
267        )
268        .bind(("threshold", format!("{}d", staleness_days * 2)))
269        .await?;
270
271    let stale_questions: Vec<EntityDetail> = super::deserialize_take(&mut stale_q_response, 0)?;
272
273    // Last movement (most recent graduated/dissolved/explored entity)
274    let mut movement_response = db
275        .query(
276            r#"SELECT updated_at
277               FROM entity
278               WHERE attributes.pipeline_status IN ['graduated', 'dissolved', 'explored']
279               ORDER BY updated_at DESC
280               LIMIT 1"#,
281        )
282        .await?;
283
284    let movement_rows: Vec<UpdatedAtRow> = super::deserialize_take(&mut movement_response, 0)?;
285    let last_movement = movement_rows.first().map(|r| match &r.updated_at {
286        serde_json::Value::String(s) => s.clone(),
287        other => other.to_string(),
288    });
289
290    Ok(PipelineGraphStats {
291        by_stage,
292        stale_thoughts,
293        stale_questions,
294        total_entities: total,
295        last_movement,
296    })
297}
298
299/// Trace the lineage of a pipeline entity through relationship chains.
300pub async fn pipeline_flow(
301    db: &Surreal<Db>,
302    entity_name: &str,
303) -> Result<Vec<(EntityDetail, String, EntityDetail)>, GraphError> {
304    // Get the entity
305    let entity = super::crud::get_entity_by_name(db, entity_name)
306        .await?
307        .ok_or_else(|| GraphError::NotFound(format!("entity: {}", entity_name)))?;
308
309    let entity_id = entity.id_string();
310    let mut chain = Vec::new();
311
312    // Get all pipeline relationships (both directions)
313    let pipeline_rel_types = [
314        "EVOLVED_FROM",
315        "CRYSTALLIZED_FROM",
316        "INFORMED_BY",
317        "GRADUATED_TO",
318        "CONNECTED_TO",
319        "EXPLORES",
320        "ARCHIVED_FROM",
321    ];
322    let rel_types_str = pipeline_rel_types
323        .iter()
324        .map(|r| format!("'{}'", r))
325        .collect::<Vec<_>>()
326        .join(", ");
327
328    // Outgoing relationships
329    let query_out = format!(
330        r#"SELECT rel_type, out AS target_id
331           FROM relates_to
332           WHERE in = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
333        rel_types_str
334    );
335    let mut response = db.query(&query_out).bind(("id", entity_id.clone())).await?;
336    let outgoing: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
337
338    for edge in &outgoing {
339        let tid = match &edge.target_id {
340            serde_json::Value::String(s) => s.clone(),
341            other => other.to_string(),
342        };
343        if let Some(target) = super::crud::get_entity_detail(db, &tid).await? {
344            let source_detail = super::crud::get_entity_detail(db, &entity_id)
345                .await?
346                .unwrap();
347            chain.push((source_detail, edge.rel_type.clone(), target));
348        }
349    }
350
351    // Incoming relationships
352    let query_in = format!(
353        r#"SELECT rel_type, in AS target_id
354           FROM relates_to
355           WHERE out = type::record($id) AND rel_type IN [{}] AND valid_until IS NONE"#,
356        rel_types_str
357    );
358    let mut response = db.query(&query_in).bind(("id", entity_id.clone())).await?;
359    let incoming: Vec<RelTarget> = super::deserialize_take(&mut response, 0)?;
360
361    for edge in &incoming {
362        let tid = match &edge.target_id {
363            serde_json::Value::String(s) => s.clone(),
364            other => other.to_string(),
365        };
366        if let Some(source) = super::crud::get_entity_detail(db, &tid).await? {
367            let target_detail = super::crud::get_entity_detail(db, &entity_id)
368                .await?
369                .unwrap();
370            chain.push((source, edge.rel_type.clone(), target_detail));
371        }
372    }
373
374    Ok(chain)
375}
376
/// One row of the grouped count query issued by `pipeline_stats`.
#[derive(serde::Deserialize)]
struct StageStatusCount {
    /// Value of `attributes.pipeline_stage` for this group.
    stage: String,
    /// Value of `attributes.pipeline_status` for this group.
    status: String,
    /// Number of entities in this (stage, status) group.
    count: u64,
}
383
/// Row shape for the "last movement" query in `pipeline_stats`, which selects only
/// `updated_at`.
#[derive(serde::Deserialize)]
struct UpdatedAtRow {
    /// Raw `updated_at` value; the caller uses a string verbatim and stringifies
    /// anything else.
    updated_at: serde_json::Value,
}