/// A single named node held by a [`KnowledgeCache`].
#[derive(Clone, Debug)]
pub struct Entity {
    /// Identifier of this entity. `KnowledgeCache::add_entity` hands these
    /// out from a running counter.
    pub id: u64,
    /// Name exactly as supplied at insertion time (casing is kept).
    pub name: String,
    /// Free-form type label; the tests in this file use values such as
    /// `"person"` and `"company"`.
    pub entity_type: String,
    /// Caller-supplied embedding index, stored untouched.
    ///
    /// NOTE(review): the tests in this file pass `-1`, presumably meaning
    /// "no embedding" — confirm against the embedding store.
    pub embedding_idx: i64,
}
12
/// A directed, labelled edge between two entities of a [`KnowledgeCache`].
#[derive(Clone, Debug)]
pub struct Relation {
    /// Id of the entity the edge starts from.
    pub src: u64,
    /// Id of the entity the edge points to.
    pub tgt: u64,
    /// Label describing the relationship.
    pub relation: String,
    /// Caller-supplied weight, stored untouched.
    pub weight: f32,
    /// Creation timestamp. `KnowledgeCache::add_relation` fills this with
    /// the time elapsed since the Unix epoch, in seconds multiplied by
    /// 1_000_000 (i.e. microseconds).
    pub ts: f64,
}
22
23#[derive(Debug, Clone)]
25pub struct KnowledgeCache {
26 pub entities: Vec<Entity>,
27 pub relations: Vec<Relation>,
28 pub alias_strings: Vec<String>,
29 pub alias_entity_ids: Vec<i64>,
30 next_entity_id: u64,
31}
32
33impl KnowledgeCache {
34 pub fn new() -> Self {
35 Self {
36 entities: Vec::new(),
37 relations: Vec::new(),
38 alias_strings: Vec::new(),
39 alias_entity_ids: Vec::new(),
40 next_entity_id: 0,
41 }
42 }
43
44 pub fn new_with_next_id(next_id: u64) -> Self {
45 Self {
46 entities: Vec::new(),
47 relations: Vec::new(),
48 alias_strings: Vec::new(),
49 alias_entity_ids: Vec::new(),
50 next_entity_id: next_id,
51 }
52 }
53
54 pub fn add_entity(&mut self, name: &str, entity_type: &str, embedding_idx: i64) -> u64 {
56 let id = self.next_entity_id;
57 self.next_entity_id += 1;
58 self.entities.push(Entity {
59 id,
60 name: name.to_string(),
61 entity_type: entity_type.to_string(),
62 embedding_idx,
63 });
64 id
65 }
66
67 pub fn add_relation(&mut self, src: u64, tgt: u64, relation: &str, weight: f32) {
69 let ts = std::time::SystemTime::now()
70 .duration_since(std::time::UNIX_EPOCH)
71 .unwrap_or_default()
72 .as_secs_f64()
73 * 1_000_000.0;
74 self.relations.push(Relation {
75 src,
76 tgt,
77 relation: relation.to_string(),
78 weight,
79 ts,
80 });
81 }
82
83 pub fn get_entity(&self, id: u64) -> Option<&Entity> {
85 self.entities.iter().find(|e| e.id == id)
86 }
87
88 pub fn get_relations_from(&self, src_id: u64) -> Vec<&Relation> {
90 self.relations.iter().filter(|r| r.src == src_id).collect()
91 }
92
93 pub fn get_relations_to(&self, tgt_id: u64) -> Vec<&Relation> {
95 self.relations.iter().filter(|r| r.tgt == tgt_id).collect()
96 }
97
98 pub fn add_alias(&mut self, alias: &str, entity_id: i64) {
100 self.alias_strings.push(alias.to_lowercase());
101 self.alias_entity_ids.push(entity_id);
102 }
103
104 pub fn get_aliases(&self, entity_id: i64) -> Vec<&str> {
106 self.alias_strings
107 .iter()
108 .zip(&self.alias_entity_ids)
109 .filter(|&(_, id)| *id == entity_id)
110 .map(|(s, _)| s.as_str())
111 .collect()
112 }
113
114 pub fn get_entity_name(&self, entity_id: i64) -> Option<&str> {
116 self.entities
117 .iter()
118 .find(|e| e.id == entity_id as u64)
119 .map(|e| e.name.as_str())
120 }
121
122 pub fn resolve_aliases(&self, query: &str) -> String {
124 let lower = query.to_lowercase();
125 let mut pairs: Vec<(&str, String)> = self
126 .alias_strings
127 .iter()
128 .zip(&self.alias_entity_ids)
129 .filter_map(|(alias, &eid)| {
130 self.get_entity_name(eid)
131 .map(|name| (alias.as_str(), name.to_lowercase()))
132 })
133 .collect();
134 pairs.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
135
136 let mut result = lower;
137 for (alias, name) in &pairs {
138 result = result.replace(alias, name);
139 }
140 result
141 }
142}
143
144impl Default for KnowledgeCache {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache holding a single "Henry" person entity and returns it
    /// together with that entity's id in the `i64` form the alias API takes.
    fn henry_cache() -> (KnowledgeCache, i64) {
        let mut cache = KnowledgeCache::new();
        let henry = cache.add_entity("Henry", "person", -1) as i64;
        (cache, henry)
    }

    /// Aliases registered for an entity can be listed back.
    #[test]
    fn test_add_and_get_aliases() {
        let (mut cache, henry) = henry_cache();
        for alias in ["my son", "the kid"] {
            cache.add_alias(alias, henry);
        }

        let found = cache.get_aliases(henry);
        assert_eq!(found.len(), 2);
        assert!(found.contains(&"my son"));
        assert!(found.contains(&"the kid"));
    }

    /// One alias in the query is replaced by the lower-cased entity name.
    #[test]
    fn test_resolve_single_alias() {
        let (mut cache, henry) = henry_cache();
        cache.add_alias("my son", henry);

        let resolved = cache.resolve_aliases("what does my son do?");
        assert!(resolved.contains("henry"), "expected 'henry' in '{resolved}'");
    }

    /// Aliases of different entities are all replaced in the same query.
    #[test]
    fn test_resolve_multiple_aliases() {
        let (mut cache, henry) = henry_cache();
        let acme = cache.add_entity("Acme Corp", "company", -1) as i64;
        cache.add_alias("my son", henry);
        cache.add_alias("our main client", acme);

        let resolved = cache.resolve_aliases("what does my son do at our main client?");
        assert!(resolved.contains("henry"), "expected 'henry' in '{resolved}'");
        assert!(resolved.contains("acme corp"), "expected 'acme corp' in '{resolved}'");
    }

    /// When one alias is a prefix of another, the longer one is applied.
    #[test]
    fn test_longest_match_wins() {
        let (mut cache, henry) = henry_cache();
        cache.add_alias("my son", henry);
        cache.add_alias("my son henry", henry);

        assert_eq!(cache.resolve_aliases("ask my son henry"), "ask henry");
    }

    /// A query containing no registered alias comes back unchanged.
    #[test]
    fn test_unregistered_alias_passthrough() {
        let (cache, _henry) = henry_cache();
        assert_eq!(
            cache.resolve_aliases("unknown phrase here"),
            "unknown phrase here"
        );
    }

    /// Matching ignores the casing of the query.
    #[test]
    fn test_case_insensitive_resolve() {
        let (mut cache, henry) = henry_cache();
        cache.add_alias("henry", henry);

        let resolved = cache.resolve_aliases("Tell HENRY about it");
        assert!(resolved.contains("henry"), "expected 'henry' in '{resolved}'");
    }

    /// A cache without aliases leaves the query as it is.
    #[test]
    fn test_empty_aliases() {
        let empty = KnowledgeCache::new();
        assert_eq!(empty.resolve_aliases("hello"), "hello");
    }

    /// Name lookup succeeds for a known id and yields `None` otherwise.
    #[test]
    fn test_get_entity_name() {
        let (cache, henry) = henry_cache();
        assert_eq!(cache.get_entity_name(henry), Some("Henry"));
        assert_eq!(cache.get_entity_name(999), None);
    }
}