hedl_core/visitor/
utils.rs

1// Dweve HEDL - Hierarchical Entity Data Language
2//
3// Copyright (c) 2025 Dweve IP B.V. and individual contributors.
4//
5// SPDX-License-Identifier: Apache-2.0
6
7//! Utility visitors for common patterns.
8
9use crate::visitor::{VisitDecision, Visitor, VisitorContext};
10use crate::{Node, Reference, Value};
11use std::collections::HashMap;
12
/// Visitor that tracks the largest context depth it has been shown.
///
/// The counter starts at `0` and only ever grows; read the result with
/// [`DepthCounter::max_depth`] once traversal has finished.
///
/// # Example
///
/// ```
/// use hedl_core::visitor::{utils::DepthCounter, traverse, TraversalConfig};
/// use hedl_core::Document;
///
/// let doc = Document::new((1, 0));
/// let mut counter = DepthCounter::new();
/// traverse(&doc, &mut counter, &TraversalConfig::default());
/// assert_eq!(counter.max_depth(), 0);
/// ```
#[derive(Debug, Default)]
pub struct DepthCounter {
    /// Deepest `ctx.depth` observed so far (`0` before any visit).
    max_depth: usize,
}

impl DepthCounter {
    /// Builds a counter whose recorded maximum depth is `0`.
    pub fn new() -> Self {
        Self { max_depth: 0 }
    }

    /// Returns the deepest depth recorded so far.
    pub fn max_depth(&self) -> usize {
        let Self { max_depth } = *self;
        max_depth
    }
}
42
43impl Visitor for DepthCounter {
44    fn visit_node(&mut self, _node: &Node, ctx: &VisitorContext<'_>) -> VisitDecision {
45        self.max_depth = self.max_depth.max(ctx.depth);
46        VisitDecision::Continue
47    }
48
49    fn visit_scalar(
50        &mut self,
51        _key: &str,
52        _value: &Value,
53        ctx: &VisitorContext<'_>,
54    ) -> VisitDecision {
55        self.max_depth = self.max_depth.max(ctx.depth);
56        VisitDecision::Continue
57    }
58}
59
60/// Visitor that collects all nodes of a specific type.
61///
62/// # Example
63///
64/// ```
65/// use hedl_core::visitor::{utils::NodeCollector, traverse, TraversalConfig};
66/// use hedl_core::{Document, Node};
67///
68/// let doc = Document::new((1, 0));
69/// let mut collector = NodeCollector::new("User");
70/// traverse(&doc, &mut collector, &TraversalConfig::default());
71/// let users = collector.into_nodes();
72/// ```
73#[derive(Debug)]
74pub struct NodeCollector {
75    type_filter: Option<String>,
76    nodes: Vec<Node>,
77}
78
79impl NodeCollector {
80    /// Create a collector for all nodes.
81    pub fn new_all() -> Self {
82        Self {
83            type_filter: None,
84            nodes: Vec::new(),
85        }
86    }
87
88    /// Create a collector for a specific node type.
89    pub fn new(type_name: impl Into<String>) -> Self {
90        Self {
91            type_filter: Some(type_name.into()),
92            nodes: Vec::new(),
93        }
94    }
95
96    /// Get the collected nodes.
97    pub fn nodes(&self) -> &[Node] {
98        &self.nodes
99    }
100
101    /// Consume the collector and return collected nodes.
102    pub fn into_nodes(self) -> Vec<Node> {
103        self.nodes
104    }
105
106    /// Get the count of collected nodes.
107    pub fn count(&self) -> usize {
108        self.nodes.len()
109    }
110}
111
112impl Visitor for NodeCollector {
113    fn visit_node(&mut self, node: &Node, _ctx: &VisitorContext<'_>) -> VisitDecision {
114        let should_collect = if let Some(ref filter) = self.type_filter {
115            &node.type_name == filter
116        } else {
117            true
118        };
119
120        if should_collect {
121            self.nodes.push(node.clone());
122        }
123
124        VisitDecision::Continue
125    }
126}
127
/// Visitor that records the context path of everything it visits.
///
/// Paths are stored as strings in visitation order; no de-duplication is
/// performed by the collector itself.
///
/// # Example
///
/// ```
/// use hedl_core::visitor::{utils::PathCollector, traverse, TraversalConfig};
/// use hedl_core::Document;
///
/// let doc = Document::new((1, 0));
/// let mut collector = PathCollector::new();
/// traverse(&doc, &mut collector, &TraversalConfig::default());
/// let paths = collector.paths();
/// ```
#[derive(Debug, Default)]
pub struct PathCollector {
    /// Recorded path strings, oldest first.
    paths: Vec<String>,
}

impl PathCollector {
    /// Builds a collector with no recorded paths.
    pub fn new() -> Self {
        Self { paths: Vec::new() }
    }

    /// Borrows the paths recorded so far.
    pub fn paths(&self) -> &[String] {
        self.paths.as_slice()
    }

    /// Consumes the collector, handing back the recorded paths.
    pub fn into_paths(self) -> Vec<String> {
        let Self { paths } = self;
        paths
    }
}
162
163impl Visitor for PathCollector {
164    fn visit_node(&mut self, _node: &Node, ctx: &VisitorContext<'_>) -> VisitDecision {
165        self.paths.push(ctx.path_string());
166        VisitDecision::Continue
167    }
168
169    fn visit_scalar(
170        &mut self,
171        _key: &str,
172        _value: &Value,
173        ctx: &VisitorContext<'_>,
174    ) -> VisitDecision {
175        self.paths.push(ctx.path_string());
176        VisitDecision::Continue
177    }
178}
179
180/// Visitor that collects all references in the document.
181///
182/// # Example
183///
184/// ```
185/// use hedl_core::visitor::{utils::ReferenceCollector, traverse, TraversalConfig};
186/// use hedl_core::Document;
187///
188/// let doc = Document::new((1, 0));
189/// let mut collector = ReferenceCollector::new();
190/// traverse(&doc, &mut collector, &TraversalConfig::default());
191/// let refs = collector.references();
192/// ```
193#[derive(Debug, Default)]
194pub struct ReferenceCollector {
195    references: Vec<Reference>,
196    by_type: HashMap<String, Vec<String>>,
197}
198
199impl ReferenceCollector {
200    /// Create a new reference collector.
201    pub fn new() -> Self {
202        Self::default()
203    }
204
205    /// Get all collected references.
206    pub fn references(&self) -> &[Reference] {
207        &self.references
208    }
209
210    /// Get references grouped by type.
211    pub fn by_type(&self) -> &HashMap<String, Vec<String>> {
212        &self.by_type
213    }
214
215    /// Get the count of collected references.
216    pub fn count(&self) -> usize {
217        self.references.len()
218    }
219
220    /// Consume the collector and return collected references.
221    pub fn into_references(self) -> Vec<Reference> {
222        self.references
223    }
224}
225
226impl Visitor for ReferenceCollector {
227    fn visit_reference(
228        &mut self,
229        reference: &Reference,
230        _ctx: &VisitorContext<'_>,
231    ) -> VisitDecision {
232        self.references.push(reference.clone());
233
234        // Group by type if qualified
235        if let Some(ref type_name) = reference.type_name {
236            self.by_type
237                .entry(type_name.to_string())
238                .or_default()
239                .push(reference.id.to_string());
240        }
241
242        VisitDecision::Continue
243    }
244}
245
246/// Visitor that finds the first node matching a predicate.
247///
248/// # Example
249///
250/// ```
251/// use hedl_core::visitor::{utils::FindNode, traverse, TraversalConfig};
252/// use hedl_core::{Document, Node};
253///
254/// let doc = Document::new((1, 0));
255/// let mut finder = FindNode::new(|node: &Node| node.id == "target");
256/// traverse(&doc, &mut finder, &TraversalConfig::default());
257/// if let Some(node) = finder.found() {
258///     // Found the target node
259/// }
260/// ```
261pub struct FindNode<F>
262where
263    F: Fn(&Node) -> bool,
264{
265    predicate: F,
266    found: Option<Node>,
267}
268
269impl<F> FindNode<F>
270where
271    F: Fn(&Node) -> bool,
272{
273    /// Create a new finder with a predicate.
274    pub fn new(predicate: F) -> Self {
275        Self {
276            predicate,
277            found: None,
278        }
279    }
280
281    /// Get the found node, if any.
282    pub fn found(&self) -> Option<&Node> {
283        self.found.as_ref()
284    }
285
286    /// Consume the finder and return the found node.
287    pub fn into_found(self) -> Option<Node> {
288        self.found
289    }
290}
291
292impl<F> Visitor for FindNode<F>
293where
294    F: Fn(&Node) -> bool,
295{
296    fn visit_node(&mut self, node: &Node, _ctx: &VisitorContext<'_>) -> VisitDecision {
297        if (self.predicate)(node) {
298            self.found = Some(node.clone());
299            VisitDecision::Stop // Early termination
300        } else {
301            VisitDecision::Continue
302        }
303    }
304}
305
/// Visitor that tallies visited nodes per type name.
///
/// # Example
///
/// ```
/// use hedl_core::visitor::{utils::TypeCounter, traverse, TraversalConfig};
/// use hedl_core::Document;
///
/// let doc = Document::new((1, 0));
/// let mut counter = TypeCounter::new();
/// traverse(&doc, &mut counter, &TraversalConfig::default());
/// let counts = counter.counts();
/// ```
#[derive(Debug, Default)]
pub struct TypeCounter {
    /// Number of nodes seen for each type name.
    counts: HashMap<String, usize>,
}

impl TypeCounter {
    /// Builds a counter with an empty tally.
    pub fn new() -> Self {
        Self {
            counts: HashMap::new(),
        }
    }

    /// Borrows the full type-name → count map.
    pub fn counts(&self) -> &HashMap<String, usize> {
        &self.counts
    }

    /// Returns the tally for `type_name`, or `0` when that type has not been
    /// counted.
    pub fn count_for(&self, type_name: &str) -> usize {
        match self.counts.get(type_name) {
            Some(&seen) => seen,
            None => 0,
        }
    }

    /// Returns the sum of all per-type tallies.
    pub fn total(&self) -> usize {
        self.counts.values().copied().sum()
    }

    /// Consumes the counter, handing back the type-name → count map.
    pub fn into_counts(self) -> HashMap<String, usize> {
        let Self { counts } = self;
        counts
    }
}
350
351impl Visitor for TypeCounter {
352    fn visit_node(&mut self, node: &Node, _ctx: &VisitorContext<'_>) -> VisitDecision {
353        *self.counts.entry(node.type_name.clone()).or_insert(0) += 1;
354        VisitDecision::Continue
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::visitor::PathSegment;
    use crate::Document;

    /// Builds the `Document::new((1, 0))` fixture each test hangs its
    /// `VisitorContext` on.
    fn fixture_doc() -> Document {
        Document::new((1, 0))
    }

    /// The depth counter follows the context depth, not the node itself.
    #[test]
    fn test_depth_counter() {
        let doc = fixture_doc();
        let root_ctx = VisitorContext::new(&doc);
        let user = Node::new("User", "1", vec![]);
        let mut depth = DepthCounter::new();

        // Nothing visited yet.
        assert_eq!(depth.max_depth(), 0);

        // A visit with the root context leaves the maximum at zero.
        depth.visit_node(&user, &root_ctx);
        assert_eq!(depth.max_depth(), 0);

        // A visit with a child context raises the maximum to one.
        let nested_ctx = root_ctx.child(PathSegment::Key("a".to_string()));
        depth.visit_node(&user, &nested_ctx);
        assert_eq!(depth.max_depth(), 1);
    }

    /// Without a filter, every visited node is collected.
    #[test]
    fn test_node_collector_all() {
        let doc = fixture_doc();
        let ctx = VisitorContext::new(&doc);
        let mut all_nodes = NodeCollector::new_all();

        let user = Node::new("User", "1", vec![]);
        let post = Node::new("Post", "2", vec![]);
        all_nodes.visit_node(&user, &ctx);
        all_nodes.visit_node(&post, &ctx);

        assert_eq!(all_nodes.count(), 2);
    }

    /// With a filter, only nodes of the requested type are collected.
    #[test]
    fn test_node_collector_filtered() {
        let doc = fixture_doc();
        let ctx = VisitorContext::new(&doc);
        let mut users_only = NodeCollector::new("User");

        let user = Node::new("User", "1", vec![]);
        let post = Node::new("Post", "2", vec![]);
        users_only.visit_node(&user, &ctx);
        users_only.visit_node(&post, &ctx);

        assert_eq!(users_only.count(), 1);
        assert_eq!(users_only.nodes()[0].type_name, "User");
    }

    /// Each scalar visit records the path string of its context.
    #[test]
    fn test_path_collector() {
        let doc = fixture_doc();
        let root_ctx = VisitorContext::new(&doc);
        let mut recorded = PathCollector::new();

        recorded.visit_scalar("a", &Value::Int(1), &root_ctx);

        let nested_ctx = root_ctx.child(PathSegment::Key("b".to_string()));
        recorded.visit_scalar("c", &Value::Int(2), &nested_ctx);

        let paths = recorded.paths();
        assert_eq!(paths.len(), 2);
        assert_eq!(paths[0], "root");
        assert_eq!(paths[1], "b");
    }

    /// Qualified references are counted and grouped by their type name.
    #[test]
    fn test_reference_collector() {
        let doc = fixture_doc();
        let ctx = VisitorContext::new(&doc);
        let mut refs = ReferenceCollector::new();

        let alice = Reference::qualified("User", "alice");
        let bob = Reference::qualified("User", "bob");
        let post = Reference::qualified("Post", "post1");

        refs.visit_reference(&alice, &ctx);
        refs.visit_reference(&bob, &ctx);
        refs.visit_reference(&post, &ctx);

        assert_eq!(refs.count(), 3);
        assert_eq!(refs.by_type().get("User").unwrap().len(), 2);
        assert_eq!(refs.by_type().get("Post").unwrap().len(), 1);
    }

    /// The finder continues past non-matches and stops on the first match.
    #[test]
    fn test_find_node() {
        let doc = fixture_doc();
        let ctx = VisitorContext::new(&doc);
        let mut finder = FindNode::new(|node: &Node| node.id == "target");

        let miss = Node::new("User", "alice", vec![]);
        let hit = Node::new("User", "target", vec![]);

        assert_eq!(finder.visit_node(&miss, &ctx), VisitDecision::Continue);
        assert_eq!(finder.found(), None);

        assert_eq!(finder.visit_node(&hit, &ctx), VisitDecision::Stop);
        assert!(finder.found().is_some());
        assert_eq!(finder.found().unwrap().id, "target");
    }

    /// Counts are kept per type name and add up to the number of visits.
    #[test]
    fn test_type_counter() {
        let doc = fixture_doc();
        let ctx = VisitorContext::new(&doc);
        let mut tally = TypeCounter::new();

        let first_user = Node::new("User", "1", vec![]);
        let second_user = Node::new("User", "2", vec![]);
        let post = Node::new("Post", "1", vec![]);

        tally.visit_node(&first_user, &ctx);
        tally.visit_node(&second_user, &ctx);
        tally.visit_node(&post, &ctx);

        assert_eq!(tally.count_for("User"), 2);
        assert_eq!(tally.count_for("Post"), 1);
        assert_eq!(tally.total(), 3);
    }

    /// A fresh counter reports zero for any type and for the total.
    #[test]
    fn test_type_counter_empty() {
        let tally = TypeCounter::new();
        assert_eq!(tally.count_for("User"), 0);
        assert_eq!(tally.total(), 0);
    }
}
488}