use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};

use crate::entity::{Entity, EntityId, EntityType};
use crate::relation::{Relation, RelationType};
8
/// Directed graph of typed entities connected by typed relations.
///
/// Every relation added through `add_relation` is stored twice: once under
/// its source in `outgoing` and a clone under its target in `incoming`, so
/// both forward and backward neighbor lookups are a single map access.
///
/// The derived serde impls serialize all four maps, so each relation appears
/// twice in the encoded form. Field order is part of the wire format for
/// non-self-describing encoders such as postcard (used in the tests), so do
/// not reorder these fields casually.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeGraph {
    /// All entities, keyed by their id.
    entities: HashMap<EntityId, Entity>,
    /// Relations grouped by their `source` entity id.
    outgoing: HashMap<EntityId, Vec<Relation>>,
    /// Clones of the same relations grouped by their `target` entity id.
    incoming: HashMap<EntityId, Vec<Relation>>,
    /// Entity ids grouped by entity type; backs `entities_by_type`.
    type_index: HashMap<EntityType, Vec<EntityId>>,
}
47
48impl KnowledgeGraph {
49 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn add_entity(&mut self, entity: Entity) -> EntityId {
56 let id = entity.id;
57 self.type_index
58 .entry(entity.entity_type.clone())
59 .or_default()
60 .push(id);
61 self.entities.insert(id, entity);
62 id
63 }
64
65 pub fn add_relation(&mut self, relation: Relation) {
67 self.incoming
68 .entry(relation.target)
69 .or_default()
70 .push(relation.clone());
71 self.outgoing
72 .entry(relation.source)
73 .or_default()
74 .push(relation);
75 }
76
77 pub fn entity(&self, id: EntityId) -> Option<&Entity> {
79 self.entities.get(&id)
80 }
81
82 pub fn entities_by_type(&self, entity_type: &EntityType) -> Vec<&Entity> {
84 self.type_index
85 .get(entity_type)
86 .map(|ids| ids.iter().filter_map(|id| self.entities.get(id)).collect())
87 .unwrap_or_default()
88 }
89
90 pub fn neighbors(
92 &self,
93 entity_id: EntityId,
94 relation_type: Option<RelationType>,
95 ) -> Vec<(&Entity, &Relation)> {
96 let relations = self.outgoing.get(&entity_id);
97 match relations {
98 Some(rels) => rels
99 .iter()
100 .filter(|r| {
101 relation_type
102 .as_ref()
103 .map(|rt| r.relation_type == *rt)
104 .unwrap_or(true)
105 })
106 .filter_map(|r| self.entities.get(&r.target).map(|e| (e, r)))
107 .collect(),
108 None => vec![],
109 }
110 }
111
112 pub fn incoming_neighbors(
114 &self,
115 entity_id: EntityId,
116 relation_type: Option<RelationType>,
117 ) -> Vec<(&Entity, &Relation)> {
118 let relations = self.incoming.get(&entity_id);
119 match relations {
120 Some(rels) => rels
121 .iter()
122 .filter(|r| {
123 relation_type
124 .as_ref()
125 .map(|rt| r.relation_type == *rt)
126 .unwrap_or(true)
127 })
128 .filter_map(|r| self.entities.get(&r.source).map(|e| (e, r)))
129 .collect(),
130 None => vec![],
131 }
132 }
133
134 pub fn traverse(
138 &self,
139 start: EntityId,
140 relation_types: &[RelationType],
141 max_hops: usize,
142 ) -> Vec<(EntityId, usize)> {
143 let mut visited = HashMap::new();
144 let mut frontier = vec![(start, 0usize)];
145
146 while let Some((node, depth)) = frontier.pop() {
147 if depth > max_hops {
148 continue;
149 }
150 if visited.contains_key(&node) {
151 continue;
152 }
153 visited.insert(node, depth);
154
155 if let Some(rels) = self.outgoing.get(&node) {
156 for rel in rels {
157 if relation_types.contains(&rel.relation_type)
158 && !visited.contains_key(&rel.target)
159 {
160 frontier.push((rel.target, depth + 1));
161 }
162 }
163 }
164 }
165
166 visited.remove(&start);
167 let mut result: Vec<_> = visited.into_iter().collect();
168 result.sort_by_key(|&(_, d)| d);
169 result
170 }
171
172 pub fn find_path(
176 &self,
177 from: EntityId,
178 to: EntityId,
179 relation_types: &[RelationType],
180 max_hops: usize,
181 ) -> Option<Vec<EntityId>> {
182 let mut visited = HashMap::new();
183 let mut frontier = vec![(from, vec![from])];
184
185 while let Some((node, path)) = frontier.pop() {
186 if node == to {
187 return Some(path);
188 }
189 if path.len() > max_hops + 1 {
190 continue;
191 }
192 if visited.contains_key(&node) {
193 continue;
194 }
195 visited.insert(node, true);
196
197 if let Some(rels) = self.outgoing.get(&node) {
198 for rel in rels {
199 if relation_types.contains(&rel.relation_type)
200 && !visited.contains_key(&rel.target)
201 {
202 let mut new_path = path.clone();
203 new_path.push(rel.target);
204 frontier.push((rel.target, new_path));
205 }
206 }
207 }
208 }
209
210 None
211 }
212
213 pub fn task_plan(&self, task_id: EntityId) -> Vec<EntityId> {
217 let required = self.neighbors(task_id, Some(RelationType::Requires));
219 if required.is_empty() {
220 return vec![];
221 }
222
223 let step_ids: Vec<EntityId> = required.iter().map(|(e, _)| e.id).collect();
225 let mut first = step_ids[0];
226 for &sid in &step_ids {
227 let predecessors = self.incoming_neighbors(sid, Some(RelationType::Precedes));
228 if predecessors.is_empty()
229 || predecessors.iter().all(|(e, _)| !step_ids.contains(&e.id))
230 {
231 first = sid;
232 break;
233 }
234 }
235
236 let mut plan = vec![first];
238 let mut current = first;
239 for _ in 0..100 {
240 let next = self.neighbors(current, Some(RelationType::Precedes));
242 if let Some((entity, _)) = next.first() {
243 plan.push(entity.id);
244 current = entity.id;
245 } else {
246 break;
247 }
248 }
249
250 plan
251 }
252
253 pub fn n_entities(&self) -> usize {
255 self.entities.len()
256 }
257
258 pub fn n_relations(&self) -> usize {
260 self.outgoing.values().map(|r| r.len()).sum()
261 }
262
263 pub fn stats(&self) -> String {
265 let mut type_counts: HashMap<&EntityType, usize> = HashMap::new();
266 for e in self.entities.values() {
267 *type_counts.entry(&e.entity_type).or_default() += 1;
268 }
269 let types: Vec<String> = type_counts
270 .iter()
271 .map(|(t, c)| format!("{t:?}={c}"))
272 .collect();
273 format!(
274 "{} entities ({}), {} relations",
275 self.n_entities(),
276 types.join(", "),
277 self.n_relations()
278 )
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Action steps of the `heat_then_place` task, listed in execution order.
    const HEAT_STEPS: [(EntityId, &str); 7] = [
        (10, "find"),
        (11, "take"),
        (12, "go_microwave"),
        (13, "heat"),
        (14, "take_heated"),
        (15, "go_target"),
        (16, "put"),
    ];

    /// Builds a graph with one task (id 1) that `Requires` the first step,
    /// plus a linear `Precedes` chain through every entry of `HEAT_STEPS`.
    fn build_task_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        graph.add_entity(Entity::new(1, EntityType::Task, "heat_then_place"));
        for &(id, name) in HEAT_STEPS.iter() {
            graph.add_entity(Entity::new(id, EntityType::Action, name));
        }

        graph.add_relation(Relation::new(1, HEAT_STEPS[0].0, RelationType::Requires, 1.0));
        for window in HEAT_STEPS.windows(2) {
            let (from, to) = (window[0].0, window[1].0);
            graph.add_relation(Relation::new(from, to, RelationType::Precedes, 1.0));
        }
        graph
    }

    #[test]
    fn graph_structure() {
        let graph = build_task_graph();
        // One task plus seven actions, joined by one Requires and six Precedes edges.
        assert_eq!(graph.n_entities(), 8);
        assert_eq!(graph.n_relations(), 7);
    }

    #[test]
    fn neighbors_filtered() {
        let graph = build_task_graph();

        let required = graph.neighbors(1, Some(RelationType::Requires));
        assert_eq!(required.len(), 1);
        let (first_step, _) = required[0];
        assert_eq!(first_step.name, "find");

        // "find" has exactly one outgoing edge of any type.
        assert_eq!(graph.neighbors(10, None).len(), 1);
    }

    #[test]
    fn traverse_chain() {
        let graph = build_task_graph();
        let reachable = graph.traverse(10, &[RelationType::Precedes], 10);
        assert_eq!(reachable.len(), 6);

        let depth_of = |target: EntityId| {
            reachable
                .iter()
                .find(|(id, _)| *id == target)
                .map(|&(_, depth)| depth)
        };
        assert_eq!(depth_of(11), Some(1));
        assert_eq!(depth_of(16), Some(6));
    }

    #[test]
    fn traverse_limited_hops() {
        let graph = build_task_graph();
        // Two hops from "find" reach only "take" and "go_microwave".
        assert_eq!(graph.traverse(10, &[RelationType::Precedes], 2).len(), 2);
    }

    #[test]
    fn find_path() {
        let graph = build_task_graph();
        let path = graph
            .find_path(10, 16, &[RelationType::Precedes], 10)
            .expect("find -> put is connected by Precedes edges");
        assert_eq!(path.first(), Some(&10));
        assert_eq!(path.last(), Some(&16));
        assert_eq!(path.len(), 7);
    }

    #[test]
    fn find_path_no_route() {
        let graph = build_task_graph();
        // Precedes edges only point forward, so there is no route back.
        assert!(graph.find_path(16, 10, &[RelationType::Precedes], 10).is_none());
    }

    #[test]
    fn task_plan() {
        let plan = build_task_graph().task_plan(1);
        assert_eq!(plan.len(), 7);
        assert_eq!(plan.first(), Some(&10));
        assert_eq!(plan.last(), Some(&16));
    }

    #[test]
    fn entities_by_type() {
        let graph = build_task_graph();
        assert_eq!(graph.entities_by_type(&EntityType::Action).len(), 7);
        assert_eq!(graph.entities_by_type(&EntityType::Task).len(), 1);
    }

    #[test]
    fn incoming_neighbors() {
        let graph = build_task_graph();
        let predecessors = graph.incoming_neighbors(13, Some(RelationType::Precedes));
        assert_eq!(predecessors.len(), 1);
        assert_eq!(predecessors[0].0.name, "go_microwave");
    }

    #[test]
    fn shared_sub_plans() {
        let mut graph = build_task_graph();

        graph.add_entity(Entity::new(2, EntityType::Task, "clean_then_place"));
        graph.add_entity(Entity::new(20, EntityType::Action, "go_sink"));
        graph.add_entity(Entity::new(21, EntityType::Action, "clean"));

        // The second task reuses "find" as its entry step...
        graph.add_relation(Relation::new(2, 10, RelationType::Requires, 1.0));
        // ...and branches off after "take".
        graph.add_relation(Relation::new(11, 20, RelationType::Precedes, 1.0));
        graph.add_relation(Relation::new(20, 21, RelationType::Precedes, 1.0));

        let one_hop = graph.traverse(10, &[RelationType::Precedes], 1);
        assert_eq!(one_hop.len(), 1);
        assert_eq!(one_hop[0].0, 11);

        // "take" now precedes both "go_microwave" and "go_sink".
        assert_eq!(graph.neighbors(11, Some(RelationType::Precedes)).len(), 2);
    }

    #[test]
    fn serialization() {
        let original = build_task_graph();
        let encoded = postcard::to_allocvec(&original).unwrap();
        let decoded: KnowledgeGraph = postcard::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.n_entities(), 8);
        assert_eq!(decoded.n_relations(), 7);
    }
}