Skip to main content

phago_rag/
query.rs

//! Query engine — traverses the Hebbian knowledge graph to retrieve relevant concepts.
//!
//! The query engine:
//! 1. Tokenizes the query into terms
//! 2. Finds seed nodes in the graph (exact label match first, then substring match)
//! 3. Traverses outward from the seeds, exploring the strongest edges first (best-first search)
//! 4. Collects and ranks results by score (term-overlap bonus plus hop-decayed edge weight)
//! 5. Optionally reinforces seed↔result edges (the graph learns from queries)
9
10use phago_core::topology::TopologyGraph;
11use phago_core::types::*;
12use phago_runtime::colony::Colony;
13use serde::Serialize;
14
/// Parameters for one lookup against the knowledge graph.
///
/// Build one with [`Query::new`] and adjust it with the chained `with_*` /
/// `without_*` methods.
#[derive(Clone, Debug)]
pub struct Query {
    /// Raw, untokenized query text.
    pub text: String,
    /// Upper bound on the number of results returned.
    pub max_results: usize,
    /// Upper bound on how many hops away from a seed node traversal may go.
    pub max_depth: usize,
    /// When `true`, running the query reinforces the paths it traversed
    /// (the graph learns from queries).
    pub reinforce: bool,
}

impl Query {
    /// Creates a query for `text` with the defaults: 10 results, depth 3,
    /// reinforcement enabled.
    pub fn new(text: &str) -> Self {
        Query {
            text: String::from(text),
            max_results: 10,
            max_depth: 3,
            reinforce: true,
        }
    }

    /// Returns the query with `max_results` replaced by `n`.
    pub fn with_max_results(self, n: usize) -> Self {
        Query {
            max_results: n,
            ..self
        }
    }

    /// Returns the query with `max_depth` replaced by `d`.
    pub fn with_max_depth(self, d: usize) -> Self {
        Query {
            max_depth: d,
            ..self
        }
    }

    /// Returns the query with reinforcement switched off.
    pub fn without_reinforcement(self) -> Self {
        Query {
            reinforce: false,
            ..self
        }
    }
}
53
54/// A single result from a query.
55#[derive(Debug, Clone, Serialize)]
56pub struct QueryResult {
57    /// The concept label.
58    pub label: String,
59    /// Node type (Concept, Insight, Anomaly).
60    pub node_type: NodeType,
61    /// How many times this node has been accessed/reinforced.
62    pub access_count: u64,
63    /// Relevance score (path weight × access count).
64    pub score: f64,
65    /// The path from a seed term to this result.
66    pub path: Vec<String>,
67    /// Node ID for further traversal.
68    pub node_id: NodeId,
69}
70
71/// The query engine — traverses the Hebbian graph.
72pub struct QueryEngine;
73
74impl QueryEngine {
75    /// Execute a query against the colony's knowledge graph.
76    ///
77    /// Returns results ranked by score (highest first).
78    pub fn query(colony: &mut Colony, q: &Query) -> Vec<QueryResult> {
79        let terms = tokenize(&q.text);
80        let graph = colony.substrate().graph();
81
82        // Phase 1: Find seed nodes (fuzzy substring matching)
83        // Uses substring matching so queries like "membrane" find nodes
84        // labeled "cell_membrane", "membrane_proteins", etc.
85        let mut seed_nodes: Vec<(NodeId, String)> = Vec::new();
86        let mut seed_seen: std::collections::HashSet<NodeId> = std::collections::HashSet::new();
87        for term in &terms {
88            // First try exact match
89            for nid in graph.find_nodes_by_exact_label(term) {
90                if seed_seen.insert(*nid) {
91                    if let Some(node) = graph.get_node(nid) {
92                        seed_nodes.push((*nid, node.label.clone()));
93                    }
94                }
95            }
96            // Then try substring match for broader coverage
97            for nid in graph.find_nodes_by_label(term) {
98                if seed_seen.insert(nid) {
99                    if let Some(node) = graph.get_node(&nid) {
100                        seed_nodes.push((nid, node.label.clone()));
101                    }
102                }
103            }
104        }
105
106        if seed_nodes.is_empty() {
107            return Vec::new();
108        }
109
110        // Phase 2: Priority-queue traversal weighted by edge weight.
111        // Uses a sorted frontier (best-first search) with limited expansion budget.
112        // Stronger edges get explored first, so reinforcement directly affects
113        // which nodes appear in results.
114        let mut results: Vec<QueryResult> = Vec::new();
115        let mut visited: std::collections::HashSet<NodeId> = std::collections::HashSet::new();
116
117        // (priority=cumulative_weight, node_id, path, depth)
118        // Using Vec as a max-heap (sort and pop from end)
119        let mut frontier: Vec<(f64, NodeId, Vec<String>, usize)> = Vec::new();
120        let max_expansions: usize = 200; // Budget limits how much of the graph we explore
121        let mut expansions = 0;
122
123        // Compute median edge weight to filter out weak edges during traversal
124        let all_edges = graph.all_edges();
125        let edge_threshold = if all_edges.is_empty() {
126            0.0
127        } else {
128            let mut weights: Vec<f64> = all_edges.iter().map(|(_, _, e)| e.weight).collect();
129            weights.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
130            weights[weights.len() * 75 / 100]
131        };
132
133        for (nid, label) in &seed_nodes {
134            visited.insert(*nid);
135            if let Some(node) = graph.get_node(nid) {
136                // Seeds that match query terms get a high base score
137                let term_overlap = terms.iter()
138                    .filter(|t| node.label.to_lowercase().contains(t.as_str()))
139                    .count() as f64;
140                results.push(QueryResult {
141                    label: node.label.clone(),
142                    node_type: node.node_type.clone(),
143                    access_count: node.access_count,
144                    score: 10.0 + term_overlap * 5.0,
145                    path: vec![label.clone()],
146                    node_id: *nid,
147                });
148            }
149            frontier.push((1.0, *nid, vec![label.clone()], 0));
150        }
151
152        // Best-first search with edge filtering and additive hop-decay scoring.
153        // Only follows edges above the 75th percentile weight to avoid noise
154        // in dense graphs. Uses additive scoring (base_weight * decay^depth)
155        // instead of multiplicative cumulative weight which decays too fast.
156        while !frontier.is_empty() && expansions < max_expansions {
157            frontier.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
158            let (weight, current, path, depth) = frontier.pop().unwrap();
159            expansions += 1;
160
161            if depth >= q.max_depth {
162                continue;
163            }
164
165            let neighbors = graph.neighbors(&current);
166            for (nid, edge) in &neighbors {
167                if visited.contains(nid) {
168                    continue;
169                }
170                // Only follow strong edges (above 75th percentile)
171                if edge.weight < edge_threshold {
172                    continue;
173                }
174                visited.insert(*nid);
175
176                if let Some(node) = graph.get_node(nid) {
177                    // Additive scoring: edge weight decays per hop, term overlap dominates
178                    let hop_decay = 0.5_f64.powi((depth + 1) as i32);
179                    let graph_score = edge.weight * hop_decay * (1.0 + edge.co_activations as f64 * 0.1);
180                    let term_overlap = terms.iter()
181                        .filter(|t| node.label.to_lowercase().contains(t.as_str()))
182                        .count() as f64;
183                    let term_bonus = term_overlap * 5.0;
184                    let score = graph_score + term_bonus;
185
186                    let mut node_path = path.clone();
187                    node_path.push(node.label.clone());
188
189                    results.push(QueryResult {
190                        label: node.label.clone(),
191                        node_type: node.node_type.clone(),
192                        access_count: node.access_count,
193                        score,
194                        path: node_path.clone(),
195                        node_id: *nid,
196                    });
197
198                    frontier.push((weight * edge.weight, *nid, node_path, depth + 1));
199                }
200            }
201        }
202
203        // Phase 3: Rank by score
204        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
205        results.truncate(q.max_results);
206
207        // Phase 4: Reinforce traversed paths (the graph learns from queries)
208        // True Hebbian: "neurons that fire together wire together"
209        //
210        // Key insight: A result node that connects to MULTIPLE seed terms
211        // is more likely relevant (it bridges query concepts). We reinforce
212        // these multi-seed results more heavily.
213        if q.reinforce && !results.is_empty() {
214            let seed_ids: Vec<NodeId> = seed_nodes.iter().map(|(nid, _)| *nid).collect();
215
216            // Count how many seed nodes each result is connected to
217            let graph_ref = colony.substrate().graph();
218            let mut result_seed_connections: Vec<(NodeId, usize)> = Vec::new();
219            for result in &results {
220                let mut seed_count = 0;
221                for seed_id in &seed_ids {
222                    if *seed_id == result.node_id {
223                        seed_count += 1;
224                    } else if graph_ref.get_edge(seed_id, &result.node_id).is_some() {
225                        seed_count += 1;
226                    }
227                }
228                result_seed_connections.push((result.node_id, seed_count));
229            }
230
231            let graph_mut = colony.substrate_mut().graph_mut();
232
233            // Reinforce based on seed connectivity (Hebbian correlation)
234            for (result_id, seed_count) in &result_seed_connections {
235                if *seed_count == 0 {
236                    continue;
237                }
238                let multi_seed_bonus = *seed_count as f64;
239
240                // Boost access count proportional to seed connectivity
241                if let Some(node) = graph_mut.get_node_mut(result_id) {
242                    node.access_count += (*seed_count as u64) * 2;
243                }
244
245                // Strengthen all seed↔result edges
246                for seed_id in &seed_ids {
247                    if seed_id == result_id { continue; }
248                    if let Some(edge) = graph_mut.get_edge_mut(seed_id, result_id) {
249                        let boost = 0.05 * multi_seed_bonus;
250                        edge.weight = (edge.weight + boost).min(1.0);
251                        edge.co_activations += 1;
252                    }
253                }
254            }
255        }
256
257        results
258    }
259}
260
/// Simple tokenizer — lowercase, split on whitespace, strip surrounding
/// punctuation, then drop stopwords and short words.
///
/// Non-alphanumeric characters are trimmed from both ends of each word
/// *before* the stopword check, so punctuated stopwords such as `"the,"` or
/// `"(and)"` are filtered out as well. Tokens shorter than three bytes after
/// trimming are discarded.
///
/// Returns the surviving tokens in input order (duplicates are kept).
fn tokenize(text: &str) -> Vec<String> {
    const MIN_TOKEN_LEN: usize = 3;
    const STOPWORDS: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could",
        "should", "may", "might", "shall", "can", "need", "dare", "ought",
        "used", "to", "of", "in", "for", "on", "with", "at", "by", "from",
        "as", "into", "through", "during", "before", "after", "above", "below",
        "between", "out", "off", "over", "under", "again", "further", "then",
        "once", "here", "there", "when", "where", "why", "how", "all", "each",
        "every", "both", "few", "more", "most", "other", "some", "such", "no",
        "nor", "not", "only", "own", "same", "so", "than", "too", "very",
        "and", "but", "or", "if", "while", "what", "which", "who", "this",
        "that", "these", "those", "it", "its",
    ];

    text.to_lowercase()
        .split_whitespace()
        // Trim first so the checks below see the bare word.
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|word| word.len() >= MIN_TOKEN_LEN && !STOPWORDS.contains(word))
        .map(String::from)
        .collect()
}
284
#[cfg(test)]
mod tests {
    use super::*;
    use phago_agents::digester::Digester;

    /// Builds a colony that has ingested a single "Biology" document with the
    /// given body, spawned one digester at the origin, and run for 15 ticks.
    fn digested_colony(body: &str) -> Colony {
        let mut colony = Colony::new();
        colony.ingest_document("Biology", body, Position::new(0.0, 0.0));
        colony.spawn(Box::new(Digester::new(Position::new(0.0, 0.0)).with_max_idle(80)));
        colony.run(15);
        colony
    }

    /// Access count of the top hit for "cell", probed without reinforcement
    /// (0 when there are no hits).
    fn top_cell_access_count(colony: &mut Colony) -> u64 {
        let probe = Query::new("cell").without_reinforcement();
        let hits = QueryEngine::query(colony, &probe);
        hits.first().map_or(0, |hit| hit.access_count)
    }

    #[test]
    fn query_returns_results_from_digested_documents() {
        let mut colony = digested_colony(
            "The cell membrane controls transport of molecules. Proteins serve as channels \
             and receptors for signaling cascades in the cellular environment.",
        );

        let probe = Query::new("cell membrane").without_reinforcement();
        let hits = QueryEngine::query(&mut colony, &probe);

        assert!(!hits.is_empty(), "query should return results");
        assert!(hits[0].score > 0.0, "results should have positive scores");
    }

    #[test]
    fn query_reinforces_traversed_nodes() {
        let mut colony = digested_colony(
            "The cell membrane controls transport of molecules. Proteins serve as channels.",
        );

        // Baseline access count, read without reinforcing.
        let initial_access = top_cell_access_count(&mut colony);

        // One reinforcing query.
        let reinforcing = Query::new("cell");
        let _ = QueryEngine::query(&mut colony, &reinforcing);

        // Read again without reinforcing; the count must have grown.
        let after_access = top_cell_access_count(&mut colony);

        assert!(after_access > initial_access, "reinforcement should increase access count");
    }

    #[test]
    fn empty_query_returns_empty() {
        let mut colony = Colony::new();
        let probe = Query::new("nonexistent term xyz");
        assert!(QueryEngine::query(&mut colony, &probe).is_empty());
    }

    #[test]
    fn tokenizer_filters_stopwords() {
        let tokens = tokenize("the cell is a membrane");
        for kept in ["cell", "membrane"] {
            assert!(tokens.contains(&kept.to_string()));
        }
        for dropped in ["the", "is"] {
            assert!(!tokens.contains(&dropped.to_string()));
        }
    }
}