Skip to main content

ai_agent/utils/
set.rs

// Source: /data/home/swei/claudecode/openclaudecode/src/utils/set.ts
//! Set utility functions optimized for performance.
//! This code is hot, so it's optimized for speed.
4
/// Returns the difference of two sets: every element of `a` that is not in `b`.
///
/// Neither input is modified. Elements that survive the membership check are
/// cloned into a newly allocated set; elements present in both sets are never
/// cloned.
pub fn difference<T: std::hash::Hash + Eq + Clone>(
    a: &std::collections::HashSet<T>,
    b: &std::collections::HashSet<T>,
) -> std::collections::HashSet<T> {
    // `HashSet::difference` yields `&T`, which avoids the `&&T` that a
    // `filter` closure would hand to `contains` (that form requires
    // `T: Borrow<&T>` and is rejected for a generic `T`).
    a.difference(b).cloned().collect()
}
12
/// Returns `true` if `a` and `b` have at least one element in common.
///
/// Two sets intersect exactly when they are not disjoint, so this defers to
/// [`std::collections::HashSet::is_disjoint`]. The standard library walks the
/// smaller of the two sets and probes the larger one, so the cost is bounded
/// by the smaller set regardless of argument order. If either set is empty
/// the result is `false`.
pub fn intersects<T: std::hash::Hash + Eq>(
    a: &std::collections::HashSet<T>,
    b: &std::collections::HashSet<T>,
) -> bool {
    !a.is_disjoint(b)
}
23
/// Returns `true` when `a` is a subset of `b`, i.e. `b` contains each element of `a`.
///
/// An empty `a` is a subset of any set, so the result is `true` in that case.
pub fn every<T: std::hash::Hash + Eq>(
    a: &std::collections::HashSet<T>,
    b: &std::collections::HashSet<T>,
) -> bool {
    a.is_subset(b)
}
31
/// Returns a new set holding every element that appears in `a`, in `b`, or in both.
///
/// Both inputs are left untouched; elements are cloned into the result, and a
/// value present in both sets appears once.
pub fn union<T: std::hash::Hash + Eq + Clone>(
    a: &std::collections::HashSet<T>,
    b: &std::collections::HashSet<T>,
) -> std::collections::HashSet<T> {
    a.union(b).cloned().collect()
}
41
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds an integer set from a slice so each fixture fits on one line.
    fn set_of(items: &[i32]) -> HashSet<i32> {
        items.iter().copied().collect()
    }

    /// Elements only in the left set are kept; shared elements are dropped.
    #[test]
    fn test_difference() {
        let left = set_of(&[1, 2, 3]);
        let right = set_of(&[2, 4]);

        let result = difference(&left, &right);
        assert!(result.contains(&1));
        assert!(result.contains(&3));
        assert!(!result.contains(&2));
    }

    /// A shared element makes the sets intersect; disjoint sets do not.
    #[test]
    fn test_intersects() {
        let left = set_of(&[1, 2]);
        let overlapping = set_of(&[2, 3]);
        let disjoint = set_of(&[4]);

        assert!(intersects(&left, &overlapping));
        assert!(!intersects(&left, &disjoint));
    }

    /// `every` holds for a superset and fails when one element is missing.
    #[test]
    fn test_every() {
        let subset = set_of(&[1, 2]);
        let superset = set_of(&[1, 2, 3]);
        let partial = set_of(&[1, 4]);

        assert!(every(&subset, &superset));
        assert!(!every(&subset, &partial));
    }

    /// The union holds each distinct element once.
    #[test]
    fn test_union() {
        let left = set_of(&[1, 2]);
        let right = set_of(&[2, 3]);

        let result = union(&left, &right);
        assert_eq!(result.len(), 3);
        for expected in [1, 2, 3].iter() {
            assert!(result.contains(expected));
        }
    }
}