Skip to main content

cvx_graph/
graph.rs

1//! Knowledge graph: typed property graph with traversal and query.
2
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};

use crate::entity::{Entity, EntityId, EntityType};
use crate::relation::{Relation, RelationType};
8
/// A knowledge graph with typed entities and directed, typed relations.
///
/// Entities are stored by ID. Every relation is recorded twice: once under its
/// source in the outgoing index and once under its target in the incoming
/// index. This makes both forward and reverse neighbor queries a single map
/// lookup. A per-type index of entity IDs backs type queries.
///
/// # Example
///
/// ```
/// use cvx_graph::{KnowledgeGraph, Entity, EntityType, Relation, RelationType};
///
/// let mut kg = KnowledgeGraph::new();
///
/// // Define a task plan
/// let task = kg.add_entity(Entity::new(1, EntityType::Task, "heat_then_place"));
/// let find = kg.add_entity(Entity::new(2, EntityType::Action, "find"));
/// let take = kg.add_entity(Entity::new(3, EntityType::Action, "take"));
/// let heat = kg.add_entity(Entity::new(4, EntityType::Action, "heat"));
///
/// kg.add_relation(Relation::new(task, find, RelationType::Requires, 1.0));
/// kg.add_relation(Relation::new(find, take, RelationType::Precedes, 1.0));
/// kg.add_relation(Relation::new(take, heat, RelationType::Precedes, 1.0));
///
/// // Query: what steps does heat_then_place require?
/// let steps = kg.neighbors(task, Some(RelationType::Requires));
/// assert_eq!(steps.len(), 1);
///
/// // Multi-hop: what comes after find? (the start node itself is not reported)
/// let chain = kg.traverse(find, &[RelationType::Precedes], 3);
/// assert_eq!(chain.len(), 2); // take (1 hop), heat (2 hops)
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeGraph {
    /// Entities indexed by ID.
    entities: HashMap<EntityId, Entity>,
    /// Outgoing relations: source ID → relations leaving it, in insertion order.
    outgoing: HashMap<EntityId, Vec<Relation>>,
    /// Incoming relations: target ID → relations entering it, in insertion order.
    /// Mirrors `outgoing`; both indexes are serialized.
    incoming: HashMap<EntityId, Vec<Relation>>,
    /// Entity IDs grouped by their entity type.
    type_index: HashMap<EntityType, Vec<EntityId>>,
}
47
48impl KnowledgeGraph {
49    /// Create an empty graph.
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Add an entity. Returns its ID.
55    pub fn add_entity(&mut self, entity: Entity) -> EntityId {
56        let id = entity.id;
57        self.type_index
58            .entry(entity.entity_type.clone())
59            .or_default()
60            .push(id);
61        self.entities.insert(id, entity);
62        id
63    }
64
65    /// Add a directed relation.
66    pub fn add_relation(&mut self, relation: Relation) {
67        self.incoming
68            .entry(relation.target)
69            .or_default()
70            .push(relation.clone());
71        self.outgoing
72            .entry(relation.source)
73            .or_default()
74            .push(relation);
75    }
76
77    /// Get an entity by ID.
78    pub fn entity(&self, id: EntityId) -> Option<&Entity> {
79        self.entities.get(&id)
80    }
81
82    /// Get all entities of a given type.
83    pub fn entities_by_type(&self, entity_type: &EntityType) -> Vec<&Entity> {
84        self.type_index
85            .get(entity_type)
86            .map(|ids| ids.iter().filter_map(|id| self.entities.get(id)).collect())
87            .unwrap_or_default()
88    }
89
90    /// Get outgoing neighbors, optionally filtered by relation type.
91    pub fn neighbors(
92        &self,
93        entity_id: EntityId,
94        relation_type: Option<RelationType>,
95    ) -> Vec<(&Entity, &Relation)> {
96        let relations = self.outgoing.get(&entity_id);
97        match relations {
98            Some(rels) => rels
99                .iter()
100                .filter(|r| {
101                    relation_type
102                        .as_ref()
103                        .map(|rt| r.relation_type == *rt)
104                        .unwrap_or(true)
105                })
106                .filter_map(|r| self.entities.get(&r.target).map(|e| (e, r)))
107                .collect(),
108            None => vec![],
109        }
110    }
111
112    /// Get incoming neighbors (reverse edges).
113    pub fn incoming_neighbors(
114        &self,
115        entity_id: EntityId,
116        relation_type: Option<RelationType>,
117    ) -> Vec<(&Entity, &Relation)> {
118        let relations = self.incoming.get(&entity_id);
119        match relations {
120            Some(rels) => rels
121                .iter()
122                .filter(|r| {
123                    relation_type
124                        .as_ref()
125                        .map(|rt| r.relation_type == *rt)
126                        .unwrap_or(true)
127                })
128                .filter_map(|r| self.entities.get(&r.source).map(|e| (e, r)))
129                .collect(),
130            None => vec![],
131        }
132    }
133
134    /// Multi-hop traversal following specified relation types.
135    ///
136    /// Returns all reachable entities with their hop distance.
137    pub fn traverse(
138        &self,
139        start: EntityId,
140        relation_types: &[RelationType],
141        max_hops: usize,
142    ) -> Vec<(EntityId, usize)> {
143        let mut visited = HashMap::new();
144        let mut frontier = vec![(start, 0usize)];
145
146        while let Some((node, depth)) = frontier.pop() {
147            if depth > max_hops {
148                continue;
149            }
150            if visited.contains_key(&node) {
151                continue;
152            }
153            visited.insert(node, depth);
154
155            if let Some(rels) = self.outgoing.get(&node) {
156                for rel in rels {
157                    if relation_types.contains(&rel.relation_type)
158                        && !visited.contains_key(&rel.target)
159                    {
160                        frontier.push((rel.target, depth + 1));
161                    }
162                }
163            }
164        }
165
166        visited.remove(&start);
167        let mut result: Vec<_> = visited.into_iter().collect();
168        result.sort_by_key(|&(_, d)| d);
169        result
170    }
171
172    /// Find a path between two entities following given relation types.
173    ///
174    /// Returns the path as a sequence of entity IDs, or None if no path.
175    pub fn find_path(
176        &self,
177        from: EntityId,
178        to: EntityId,
179        relation_types: &[RelationType],
180        max_hops: usize,
181    ) -> Option<Vec<EntityId>> {
182        let mut visited = HashMap::new();
183        let mut frontier = vec![(from, vec![from])];
184
185        while let Some((node, path)) = frontier.pop() {
186            if node == to {
187                return Some(path);
188            }
189            if path.len() > max_hops + 1 {
190                continue;
191            }
192            if visited.contains_key(&node) {
193                continue;
194            }
195            visited.insert(node, true);
196
197            if let Some(rels) = self.outgoing.get(&node) {
198                for rel in rels {
199                    if relation_types.contains(&rel.relation_type)
200                        && !visited.contains_key(&rel.target)
201                    {
202                        let mut new_path = path.clone();
203                        new_path.push(rel.target);
204                        frontier.push((rel.target, new_path));
205                    }
206                }
207            }
208        }
209
210        None
211    }
212
213    /// Get the ordered sequence of steps for a task.
214    ///
215    /// Follows `Requires` from the task, then `Precedes` between steps.
216    pub fn task_plan(&self, task_id: EntityId) -> Vec<EntityId> {
217        // Find the first step (required by task, not preceded by anything else in this task)
218        let required = self.neighbors(task_id, Some(RelationType::Requires));
219        if required.is_empty() {
220            return vec![];
221        }
222
223        // Find the step that has no predecessor in this set
224        let step_ids: Vec<EntityId> = required.iter().map(|(e, _)| e.id).collect();
225        let mut first = step_ids[0];
226        for &sid in &step_ids {
227            let predecessors = self.incoming_neighbors(sid, Some(RelationType::Precedes));
228            if predecessors.is_empty()
229                || predecessors.iter().all(|(e, _)| !step_ids.contains(&e.id))
230            {
231                first = sid;
232                break;
233            }
234        }
235
236        // Walk the Precedes chain
237        let mut plan = vec![first];
238        let mut current = first;
239        for _ in 0..100 {
240            // safety limit
241            let next = self.neighbors(current, Some(RelationType::Precedes));
242            if let Some((entity, _)) = next.first() {
243                plan.push(entity.id);
244                current = entity.id;
245            } else {
246                break;
247            }
248        }
249
250        plan
251    }
252
253    /// Number of entities.
254    pub fn n_entities(&self) -> usize {
255        self.entities.len()
256    }
257
258    /// Number of relations.
259    pub fn n_relations(&self) -> usize {
260        self.outgoing.values().map(|r| r.len()).sum()
261    }
262
263    /// Summary statistics.
264    pub fn stats(&self) -> String {
265        let mut type_counts: HashMap<&EntityType, usize> = HashMap::new();
266        for e in self.entities.values() {
267            *type_counts.entry(&e.entity_type).or_default() += 1;
268        }
269        let types: Vec<String> = type_counts
270            .iter()
271            .map(|(t, c)| format!("{t:?}={c}"))
272            .collect();
273        format!(
274            "{} entities ({}), {} relations",
275            self.n_entities(),
276            types.join(", "),
277            self.n_relations()
278        )
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `heat_then_place` fixture.
    ///
    /// The fixture has one task (id 1) that `Requires` its first step (id 10),
    /// followed by a linear `Precedes` chain through steps 10 → 16.
    fn build_task_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();
        graph.add_entity(Entity::new(1, EntityType::Task, "heat_then_place"));

        // Steps in execution order; consecutive pairs become `Precedes` edges.
        let steps = [
            (10, "find"),
            (11, "take"),
            (12, "go_microwave"),
            (13, "heat"),
            (14, "take_heated"),
            (15, "go_target"),
            (16, "put"),
        ];
        for (id, name) in steps {
            graph.add_entity(Entity::new(id, EntityType::Action, name));
        }

        // The task points only at its first step.
        graph.add_relation(Relation::new(1, 10, RelationType::Requires, 1.0));
        for pair in steps.windows(2) {
            graph.add_relation(Relation::new(
                pair[0].0,
                pair[1].0,
                RelationType::Precedes,
                1.0,
            ));
        }

        graph
    }

    #[test]
    fn graph_structure() {
        let graph = build_task_graph();
        // 1 task + 7 steps; 1 `Requires` + 6 `Precedes`.
        assert_eq!((graph.n_entities(), graph.n_relations()), (8, 7));
    }

    #[test]
    fn neighbors_filtered() {
        let graph = build_task_graph();

        let required = graph.neighbors(1, Some(RelationType::Requires));
        match required.as_slice() {
            [(entity, _)] => assert_eq!(entity.name, "find"),
            other => panic!("expected exactly one required step, got {}", other.len()),
        }

        // `None` disables the type filter; `find` has a single outgoing edge (to `take`).
        assert_eq!(graph.neighbors(10, None).len(), 1);
    }

    #[test]
    fn traverse_chain() {
        let graph = build_task_graph();
        let reachable = graph.traverse(10, &[RelationType::Precedes], 10);
        // take, go_microwave, heat, take_heated, go_target, put
        assert_eq!(reachable.len(), 6);

        let hops_to = |target: EntityId| {
            reachable
                .iter()
                .find(|(id, _)| *id == target)
                .map(|&(_, hops)| hops)
        };
        assert_eq!(hops_to(11), Some(1)); // find → take
        assert_eq!(hops_to(16), Some(6)); // find → put
    }

    #[test]
    fn traverse_limited_hops() {
        let graph = build_task_graph();
        // Two hops from `find` reach only `take` and `go_microwave`.
        let within_two = graph.traverse(10, &[RelationType::Precedes], 2);
        assert_eq!(within_two.len(), 2);
    }

    #[test]
    fn find_path() {
        let graph = build_task_graph();
        let path = graph
            .find_path(10, 16, &[RelationType::Precedes], 10)
            .expect("put should be reachable from find along Precedes");
        // find, take, go_microwave, heat, take_heated, go_target, put
        assert_eq!(path.len(), 7);
        assert_eq!((path.first(), path.last()), (Some(&10), Some(&16)));
    }

    #[test]
    fn find_path_no_route() {
        let graph = build_task_graph();
        // `Precedes` edges only run forward, so put → find has no route.
        assert_eq!(graph.find_path(16, 10, &[RelationType::Precedes], 10), None);
    }

    #[test]
    fn task_plan() {
        let graph = build_task_graph();
        let plan = graph.task_plan(1);
        assert_eq!(plan.len(), 7);
        assert_eq!(plan.first(), Some(&10)); // find
        assert_eq!(plan.last(), Some(&16)); // put
    }

    #[test]
    fn entities_by_type() {
        let graph = build_task_graph();
        let count = |kind: &EntityType| graph.entities_by_type(kind).len();
        assert_eq!(count(&EntityType::Action), 7);
        assert_eq!(count(&EntityType::Task), 1);
    }

    #[test]
    fn incoming_neighbors() {
        let graph = build_task_graph();
        let predecessors = graph.incoming_neighbors(13, Some(RelationType::Precedes));
        assert_eq!(predecessors.len(), 1);
        let (entity, _) = predecessors[0];
        assert_eq!(entity.name, "go_microwave");
    }

    #[test]
    fn shared_sub_plans() {
        let mut graph = build_task_graph();

        // clean_then_place reuses the same `find` step as its first requirement.
        graph.add_entity(Entity::new(2, EntityType::Task, "clean_then_place"));
        graph.add_entity(Entity::new(20, EntityType::Action, "go_sink"));
        graph.add_entity(Entity::new(21, EntityType::Action, "clean"));
        graph.add_relation(Relation::new(2, 10, RelationType::Requires, 1.0));

        // After `take`, the clean branch diverges from the heat branch.
        graph.add_relation(Relation::new(11, 20, RelationType::Precedes, 1.0));
        graph.add_relation(Relation::new(20, 21, RelationType::Precedes, 1.0));

        // The find → take prefix is shared: one hop from `find` reaches only `take`.
        let one_hop: Vec<_> = graph
            .traverse(10, &[RelationType::Precedes], 1)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert_eq!(one_hop, vec![11]);

        // `take` now fans out to both go_microwave and go_sink.
        assert_eq!(graph.neighbors(11, Some(RelationType::Precedes)).len(), 2);
    }

    #[test]
    fn serialization() {
        let original = build_task_graph();
        let encoded = postcard::to_allocvec(&original).expect("graph should serialize");
        let decoded: KnowledgeGraph =
            postcard::from_bytes(&encoded).expect("graph should deserialize");
        assert_eq!((decoded.n_entities(), decoded.n_relations()), (8, 7));
    }
}