Skip to main content

ainl_memory/
query.rs

//! Graph traversal and querying utilities.
//!
//! Higher-level query functions built on top of `GraphStore`.
4
5use crate::node::{AinlMemoryNode, AinlNodeType};
6use crate::store::GraphStore;
7use uuid::Uuid;
8
9/// Walk the graph from a starting node, following edges with a specific label
10///
11/// # Arguments
12/// * `store` - The graph store to query
13/// * `start_id` - Node ID to start from
14/// * `edge_label` - Label of edges to follow
15/// * `max_depth` - Maximum depth to traverse (prevents infinite loops)
16///
17/// # Returns
18/// Vector of nodes encountered during the walk, in breadth-first order
19pub fn walk_from(
20    store: &dyn GraphStore,
21    start_id: Uuid,
22    edge_label: &str,
23    max_depth: usize,
24) -> Result<Vec<AinlMemoryNode>, String> {
25    let mut visited = std::collections::HashSet::new();
26    let mut result = Vec::new();
27    let mut current_level = vec![start_id];
28
29    for _ in 0..max_depth {
30        if current_level.is_empty() {
31            break;
32        }
33
34        let mut next_level = Vec::new();
35
36        for node_id in current_level {
37            if visited.contains(&node_id) {
38                continue;
39            }
40            visited.insert(node_id);
41
42            if let Some(node) = store.read_node(node_id)? {
43                result.push(node.clone());
44
45                // Follow edges with the specified label
46                for next_node in store.walk_edges(node_id, edge_label)? {
47                    if !visited.contains(&next_node.id) {
48                        next_level.push(next_node.id);
49                    }
50                }
51            }
52        }
53
54        current_level = next_level;
55    }
56
57    Ok(result)
58}
59
60/// Recall recent episodes, optionally filtered by tool usage
61///
62/// # Arguments
63/// * `store` - The graph store to query
64/// * `since_timestamp` - Only return episodes after this timestamp (Unix seconds)
65/// * `limit` - Maximum number of episodes to return
66/// * `tool_filter` - If Some, only return episodes that used this tool
67pub fn recall_recent(
68    store: &dyn GraphStore,
69    since_timestamp: i64,
70    limit: usize,
71    tool_filter: Option<&str>,
72) -> Result<Vec<AinlMemoryNode>, String> {
73    let episodes = store.query_episodes_since(since_timestamp, limit)?;
74
75    if let Some(tool_name) = tool_filter {
76        Ok(episodes
77            .into_iter()
78            .filter(|node| match &node.node_type {
79                AinlNodeType::Episode { tool_calls, .. } => tool_calls.contains(&tool_name.to_string()),
80                _ => false,
81            })
82            .collect())
83    } else {
84        Ok(episodes)
85    }
86}
87
88/// Find procedural patterns by name prefix
89///
90/// # Arguments
91/// * `store` - The graph store to query
92/// * `name_prefix` - Pattern name prefix to match
93///
94/// # Returns
95/// Vector of procedural nodes whose pattern_name starts with the prefix
96pub fn find_patterns(
97    store: &dyn GraphStore,
98    name_prefix: &str,
99) -> Result<Vec<AinlMemoryNode>, String> {
100    let all_procedural = store.find_by_type("procedural")?;
101
102    Ok(all_procedural
103        .into_iter()
104        .filter(|node| match &node.node_type {
105            AinlNodeType::Procedural { pattern_name, .. } => {
106                pattern_name.starts_with(name_prefix)
107            }
108            _ => false,
109        })
110        .collect())
111}
112
113/// Find semantic facts with confidence above a threshold
114///
115/// # Arguments
116/// * `store` - The graph store to query
117/// * `min_confidence` - Minimum confidence score (0.0-1.0)
118///
119/// # Returns
120/// Vector of semantic nodes with confidence >= min_confidence
121pub fn find_high_confidence_facts(
122    store: &dyn GraphStore,
123    min_confidence: f32,
124) -> Result<Vec<AinlMemoryNode>, String> {
125    let all_semantic = store.find_by_type("semantic")?;
126
127    Ok(all_semantic
128        .into_iter()
129        .filter(|node| match &node.node_type {
130            AinlNodeType::Semantic { confidence, .. } => *confidence >= min_confidence,
131            _ => false,
132        })
133        .collect())
134}
135
136/// Find persona traits sorted by strength
137///
138/// # Arguments
139/// * `store` - The graph store to query
140///
141/// # Returns
142/// Vector of persona nodes sorted by strength (descending)
143pub fn find_strong_traits(store: &dyn GraphStore) -> Result<Vec<AinlMemoryNode>, String> {
144    let mut all_persona = store.find_by_type("persona")?;
145
146    all_persona.sort_by(|a, b| {
147        let strength_a = match &a.node_type {
148            AinlNodeType::Persona { strength, .. } => *strength,
149            _ => 0.0,
150        };
151        let strength_b = match &b.node_type {
152            AinlNodeType::Persona { strength, .. } => *strength,
153            _ => 0.0,
154        };
155        strength_b.partial_cmp(&strength_a).unwrap_or(std::cmp::Ordering::Equal)
156    });
157
158    Ok(all_persona)
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::AinlMemoryNode;
    use crate::store::SqliteGraphStore;

    /// Build a per-process temp database path and remove any stale file left at it.
    ///
    /// The process id in the file name keeps concurrent test processes that share a temp
    /// directory from opening the same database.
    fn fresh_db_path(name: &str) -> std::path::PathBuf {
        let db_path = std::env::temp_dir().join(format!("{}_{}.db", name, std::process::id()));
        let _ = std::fs::remove_file(&db_path);
        db_path
    }

    #[test]
    fn test_recall_recent_with_tool_filter() {
        let db_path = fresh_db_path("ainl_query_test_recall");

        let store = SqliteGraphStore::open(&db_path).expect("Failed to open store");

        let now = chrono::Utc::now().timestamp();

        // Create episodes with different tools
        let node1 = AinlMemoryNode::new_episode(
            uuid::Uuid::new_v4(),
            now,
            vec!["file_read".to_string()],
            None,
            None,
        );

        let node2 = AinlMemoryNode::new_episode(
            uuid::Uuid::new_v4(),
            now + 1,
            vec!["agent_delegate".to_string()],
            Some("agent-B".to_string()),
            None,
        );

        store.write_node(&node1).expect("Write failed");
        store.write_node(&node2).expect("Write failed");

        // Query with tool filter
        let delegations = recall_recent(&store, now - 100, 10, Some("agent_delegate"))
            .expect("Query failed");

        // Check both the count and that the surviving episode is the delegating one.
        assert_eq!(delegations.len(), 1);
        assert_eq!(delegations[0].id, node2.id);

        // Close the store before deleting its file (required on some platforms).
        drop(store);
        let _ = std::fs::remove_file(&db_path);
    }

    #[test]
    fn test_find_high_confidence_facts() {
        let db_path = fresh_db_path("ainl_query_test_facts");

        let store = SqliteGraphStore::open(&db_path).expect("Failed to open store");

        let turn_id = uuid::Uuid::new_v4();

        let fact1 = AinlMemoryNode::new_fact("User prefers Rust".to_string(), 0.95, turn_id);
        let fact2 = AinlMemoryNode::new_fact("User dislikes Python".to_string(), 0.45, turn_id);

        store.write_node(&fact1).expect("Write failed");
        store.write_node(&fact2).expect("Write failed");

        let high_conf = find_high_confidence_facts(&store, 0.7).expect("Query failed");

        // Only the 0.95-confidence fact clears the 0.7 threshold.
        assert_eq!(high_conf.len(), 1);
        assert_eq!(high_conf[0].id, fact1.id);

        drop(store);
        let _ = std::fs::remove_file(&db_path);
    }
}