//! Knowledge graph data structures and an in-memory cache for them.
//!
//! Path: `edgehdf5_memory/knowledge.rs`

/// A knowledge graph entity.
///
/// Equality and hashing are derived over every field, so two entities compare equal only
/// when their id, name, type and embedding index all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Entity {
    /// Identifier of the entity; `KnowledgeCache::add_entity` hands these out sequentially.
    pub id: u64,
    /// Name of the entity, stored exactly as supplied (e.g. `"Henry"`).
    pub name: String,
    /// Free-form category label such as `"person"` or `"company"`.
    pub entity_type: String,
    /// Index into the memory embeddings array, or -1 if none.
    pub embedding_idx: i64,
}

/// A knowledge graph relation between two entities.
///
/// The relation is directed: it points from `src` to `tgt`. Only `PartialEq` is derived
/// (not `Eq`/`Hash`) because `weight` and `ts` are floating-point values.
#[derive(Debug, Clone, PartialEq)]
pub struct Relation {
    /// Id of the entity the relation starts from.
    pub src: u64,
    /// Id of the entity the relation points to.
    pub tgt: u64,
    /// Free-form label naming the kind of relation.
    pub relation: String,
    /// Caller-supplied weight of the relation.
    pub weight: f32,
    /// Timestamp of the relation. `KnowledgeCache::add_relation` fills this with the number
    /// of microseconds since the Unix epoch at insertion time.
    pub ts: f64,
}

23/// In-memory cache for the /knowledge_graph group.
24#[derive(Debug, Clone)]
25pub struct KnowledgeCache {
26    pub entities: Vec<Entity>,
27    pub relations: Vec<Relation>,
28    pub alias_strings: Vec<String>,
29    pub alias_entity_ids: Vec<i64>,
30    next_entity_id: u64,
31}
32
33impl KnowledgeCache {
34    pub fn new() -> Self {
35        Self {
36            entities: Vec::new(),
37            relations: Vec::new(),
38            alias_strings: Vec::new(),
39            alias_entity_ids: Vec::new(),
40            next_entity_id: 0,
41        }
42    }
43
44    pub fn new_with_next_id(next_id: u64) -> Self {
45        Self {
46            entities: Vec::new(),
47            relations: Vec::new(),
48            alias_strings: Vec::new(),
49            alias_entity_ids: Vec::new(),
50            next_entity_id: next_id,
51        }
52    }
53
54    /// Add an entity, returns its assigned ID.
55    pub fn add_entity(&mut self, name: &str, entity_type: &str, embedding_idx: i64) -> u64 {
56        let id = self.next_entity_id;
57        self.next_entity_id += 1;
58        self.entities.push(Entity {
59            id,
60            name: name.to_string(),
61            entity_type: entity_type.to_string(),
62            embedding_idx,
63        });
64        id
65    }
66
67    /// Add a relation between two entities.
68    pub fn add_relation(&mut self, src: u64, tgt: u64, relation: &str, weight: f32) {
69        let ts = std::time::SystemTime::now()
70            .duration_since(std::time::UNIX_EPOCH)
71            .unwrap_or_default()
72            .as_secs_f64()
73            * 1_000_000.0;
74        self.relations.push(Relation {
75            src,
76            tgt,
77            relation: relation.to_string(),
78            weight,
79            ts,
80        });
81    }
82
83    /// Find an entity by ID.
84    pub fn get_entity(&self, id: u64) -> Option<&Entity> {
85        self.entities.iter().find(|e| e.id == id)
86    }
87
88    /// Find all relations where the given entity is the source.
89    pub fn get_relations_from(&self, src_id: u64) -> Vec<&Relation> {
90        self.relations.iter().filter(|r| r.src == src_id).collect()
91    }
92
93    /// Find all relations where the given entity is the target.
94    pub fn get_relations_to(&self, tgt_id: u64) -> Vec<&Relation> {
95        self.relations.iter().filter(|r| r.tgt == tgt_id).collect()
96    }
97
98    /// Register an alias for an entity. Case-insensitive storage.
99    pub fn add_alias(&mut self, alias: &str, entity_id: i64) {
100        self.alias_strings.push(alias.to_lowercase());
101        self.alias_entity_ids.push(entity_id);
102    }
103
104    /// Get all aliases for a given entity.
105    pub fn get_aliases(&self, entity_id: i64) -> Vec<&str> {
106        self.alias_strings
107            .iter()
108            .zip(&self.alias_entity_ids)
109            .filter(|&(_, id)| *id == entity_id)
110            .map(|(s, _)| s.as_str())
111            .collect()
112    }
113
114    /// Get entity name by ID.
115    pub fn get_entity_name(&self, entity_id: i64) -> Option<&str> {
116        self.entities
117            .iter()
118            .find(|e| e.id == entity_id as u64)
119            .map(|e| e.name.as_str())
120    }
121
122    /// Resolve aliases in free text — greedy longest-match replacement.
123    pub fn resolve_aliases(&self, query: &str) -> String {
124        let lower = query.to_lowercase();
125        let mut pairs: Vec<(&str, String)> = self
126            .alias_strings
127            .iter()
128            .zip(&self.alias_entity_ids)
129            .filter_map(|(alias, &eid)| {
130                self.get_entity_name(eid)
131                    .map(|name| (alias.as_str(), name.to_lowercase()))
132            })
133            .collect();
134        pairs.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
135
136        let mut result = lower;
137        for (alias, name) in &pairs {
138            result = result.replace(alias, name);
139        }
140        result
141    }
142}
143
144impl Default for KnowledgeCache {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
/// Unit tests for `KnowledgeCache`: entity and relation bookkeeping, alias storage, and
/// alias resolution in free text.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_and_get_aliases() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", -1);
        cache.add_alias("my son", id as i64);
        cache.add_alias("the kid", id as i64);

        let aliases = cache.get_aliases(id as i64);
        assert_eq!(aliases.len(), 2);
        assert!(aliases.contains(&"my son"));
        assert!(aliases.contains(&"the kid"));
    }

    #[test]
    fn test_aliases_are_stored_lowercase() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", -1);
        cache.add_alias("My Son", id as i64);

        assert_eq!(cache.get_aliases(id as i64), vec!["my son"]);
    }

    #[test]
    fn test_resolve_single_alias() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", -1);
        cache.add_alias("my son", id as i64);

        let resolved = cache.resolve_aliases("what does my son do?");
        assert!(resolved.contains("henry"), "expected 'henry' in '{resolved}'");
    }

    #[test]
    fn test_resolve_multiple_aliases() {
        let mut cache = KnowledgeCache::new();
        let h = cache.add_entity("Henry", "person", -1);
        let a = cache.add_entity("Acme Corp", "company", -1);
        cache.add_alias("my son", h as i64);
        cache.add_alias("our main client", a as i64);

        let resolved = cache.resolve_aliases("what does my son do at our main client?");
        assert!(resolved.contains("henry"), "expected 'henry' in '{resolved}'");
        assert!(resolved.contains("acme corp"), "expected 'acme corp' in '{resolved}'");
    }

    #[test]
    fn test_longest_match_wins() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", -1);
        cache.add_alias("my son", id as i64);
        cache.add_alias("my son henry", id as i64);

        let resolved = cache.resolve_aliases("ask my son henry");
        // The longer alias "my son henry" should match first, producing "ask henry"
        // If "my son" matched first, we'd get "ask henry henry" — wrong.
        assert_eq!(resolved, "ask henry");
    }

    #[test]
    fn test_unregistered_alias_passthrough() {
        let mut cache = KnowledgeCache::new();
        cache.add_entity("Henry", "person", -1);
        // No aliases registered
        let resolved = cache.resolve_aliases("unknown phrase here");
        assert_eq!(resolved, "unknown phrase here");
    }

    #[test]
    fn test_case_insensitive_resolve() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", -1);
        cache.add_alias("henry", id as i64);

        let resolved = cache.resolve_aliases("Tell HENRY about it");
        assert!(resolved.contains("henry"), "expected 'henry' in '{resolved}'");
    }

    #[test]
    fn test_empty_aliases() {
        let cache = KnowledgeCache::new();
        let resolved = cache.resolve_aliases("hello");
        assert_eq!(resolved, "hello");
    }

    #[test]
    fn test_get_entity_name() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", -1);
        assert_eq!(cache.get_entity_name(id as i64), Some("Henry"));
        assert_eq!(cache.get_entity_name(999), None);
    }

    #[test]
    fn test_get_entity() {
        let mut cache = KnowledgeCache::new();
        let id = cache.add_entity("Henry", "person", 3);

        let entity = cache.get_entity(id).expect("entity was just added");
        assert_eq!(entity.id, id);
        assert_eq!(entity.name, "Henry");
        assert_eq!(entity.entity_type, "person");
        assert_eq!(entity.embedding_idx, 3);
        assert!(cache.get_entity(999).is_none());
    }

    #[test]
    fn test_new_with_next_id_sets_first_id() {
        let mut cache = KnowledgeCache::new_with_next_id(42);
        assert_eq!(cache.add_entity("Henry", "person", -1), 42);
        assert_eq!(cache.add_entity("Acme Corp", "company", -1), 43);
    }

    #[test]
    fn test_default_is_empty_and_starts_at_zero() {
        let mut cache = KnowledgeCache::default();
        assert!(cache.entities.is_empty());
        assert!(cache.relations.is_empty());
        assert!(cache.alias_strings.is_empty());
        assert!(cache.alias_entity_ids.is_empty());
        assert_eq!(cache.add_entity("Henry", "person", -1), 0);
    }

    #[test]
    fn test_relation_lookups() {
        let mut cache = KnowledgeCache::new();
        let h = cache.add_entity("Henry", "person", -1);
        let a = cache.add_entity("Acme Corp", "company", -1);
        cache.add_relation(h, a, "works_at", 0.5);

        let outgoing = cache.get_relations_from(h);
        assert_eq!(outgoing.len(), 1);
        assert_eq!(outgoing[0].src, h);
        assert_eq!(outgoing[0].tgt, a);
        assert_eq!(outgoing[0].relation, "works_at");
        assert_eq!(outgoing[0].weight, 0.5);
        assert!(outgoing[0].ts >= 0.0);

        assert_eq!(cache.get_relations_to(a).len(), 1);
        // Relations are directed: nothing starts at `a` and nothing ends at `h`.
        assert!(cache.get_relations_from(a).is_empty());
        assert!(cache.get_relations_to(h).is_empty());
    }
}