Skip to main content

nu_command/filters/intersect.rs

1use super::utils;
2use nu_engine::command_prelude::*;
3use std::collections::HashSet;
4
5#[derive(Clone)]
6pub struct Intersect;
7
8impl Command for Intersect {
9    fn name(&self) -> &str {
10        "intersect"
11    }
12
13    fn signature(&self) -> Signature {
14        Signature::build("intersect")
15            .input_output_types(vec![
16                (
17                    Type::List(Box::new(Type::Any)),
18                    Type::List(Box::new(Type::Any)),
19                ),
20                (Type::table(), Type::table()),
21            ])
22            .required(
23                "other",
24                SyntaxShape::List(Box::new(SyntaxShape::Any)),
25                "The other list to intersect with.",
26            )
27            .category(Category::Filters)
28    }
29
30    fn description(&self) -> &str {
31        "Returns a list of unique elements present in both the input and the provided list."
32    }
33
34    fn search_terms(&self) -> Vec<&str> {
35        vec!["common", "shared", "overlap", "filter"]
36    }
37
38    fn run(
39        &self,
40        engine_state: &EngineState,
41        stack: &mut Stack,
42        call: &Call,
43        mut input: PipelineData,
44    ) -> Result<PipelineData, ShellError> {
45        let head = call.head;
46        let metadata = input.take_metadata();
47
48        let other_vals = utils::extract_other_list(engine_state, stack, call, head)?;
49
50        let other_set: HashSet<String> = other_vals
51            .iter()
52            .map(|v| utils::value_to_key(engine_state, v, head))
53            .collect::<Result<HashSet<_>, _>>()?;
54
55        let signals = engine_state.signals().clone();
56        let mut seen: HashSet<String> = HashSet::new();
57        let mut result = Vec::new();
58
59        for val in input {
60            signals.check(&head)?;
61            let key = utils::value_to_key(engine_state, &val, head)?;
62            if other_set.contains(&key) && seen.insert(key) {
63                result.push(val);
64            }
65        }
66
67        Ok(PipelineData::Value(Value::list(result, head), metadata))
68    }
69
70    fn examples(&self) -> Vec<Example<'static>> {
71        vec![
72            Example {
73                example: "[1 2 3 4] | intersect [3 4 5 6]",
74                description: "Return the intersection of two lists",
75                result: Some(Value::test_list(vec![
76                    Value::test_int(3),
77                    Value::test_int(4),
78                ])),
79            },
80            Example {
81                example: "[1 2 3] | intersect [4 5 6]",
82                description: "Intersection with no common elements",
83                result: Some(Value::test_list(vec![])),
84            },
85            Example {
86                example: "[{a:1} {a:2} {a:3}] | intersect [{a:2} {a:3} {a:4}]",
87                description: "Intersection of two tables",
88                result: Some(Value::test_list(vec![
89                    Value::test_record(record!("a" => Value::test_int(2))),
90                    Value::test_record(record!("a" => Value::test_int(3))),
91                ])),
92            },
93        ]
94    }
95}
96
#[cfg(test)]
mod test {
    use super::Intersect;
    use nu_protocol::record;
    use nu_test_support::prelude::*;

    /// Runs every entry of `Intersect::examples()` and checks its declared result.
    #[test]
    fn test_examples() -> nu_test_support::Result {
        nu_test_support::test().examples(Intersect)
    }

    #[test]
    fn intersect_basic() -> Result {
        test()
            .run("[1 2 3 4] | intersect [3 4 5 6]")
            .expect_value_eq([3, 4])
    }

    #[test]
    fn intersect_no_common() -> Result {
        test()
            .run("[1 2 3] | intersect [4 5 6]")
            .expect_value_eq(Value::test_list(vec![]))
    }

    #[test]
    fn intersect_all_common() -> Result {
        test()
            .run("[1 2 3] | intersect [1 2 3]")
            .expect_value_eq([1, 2, 3])
    }

    #[test]
    fn intersect_empty_input() -> Result {
        test()
            .run("[] | intersect [1 2 3]")
            .expect_value_eq(Value::test_list(vec![]))
    }

    #[test]
    fn intersect_empty_other() -> Result {
        test()
            .run("[1 2 3] | intersect []")
            .expect_value_eq(Value::test_list(vec![]))
    }

    #[test]
    fn intersect_dedups_output() -> Result {
        test()
            .run("[1 1 2 3] | intersect [1 2 2 3]")
            .expect_value_eq([1, 2, 3])
    }

    /// An input key repeated several times is emitted once, even when `other`
    /// contains it only once.
    #[test]
    fn intersect_repeated_input_single_other() -> Result {
        test()
            .run("[2 2 2] | intersect [2]")
            .expect_value_eq([2])
    }

    #[test]
    fn intersect_preserves_input_order() -> Result {
        test()
            .run("[c a b d] | intersect [a d] | str join '-'")
            .expect_value_eq("a-d")
    }

    /// Input produced by `each` arrives as a stream rather than a literal list
    /// value. This exercises the streaming path of the command's input loop.
    #[test]
    fn intersect_stream_input() -> Result {
        test()
            .run("[1 2 3 4] | each {|x| $x } | intersect [2 4]")
            .expect_value_eq([2, 4])
    }

    #[test]
    fn intersect_tables() -> Result {
        let result: Value = test().run("[{a:1} {a:2} {a:3}] | intersect [{a:2} {a:3} {a:4}]")?;
        assert_eq!(
            result,
            Value::test_list(vec![
                Value::test_record(record!("a" => Value::test_int(2))),
                Value::test_record(record!("a" => Value::test_int(3))),
            ])
        );
        Ok(())
    }

    #[test]
    fn intersect_mixed_types() -> Result {
        test()
            .run("[1 a 2.5 true] | intersect [2.5 b true]")
            .expect_value_eq((2.5f64, true))
    }

    #[test]
    fn intersect_other_not_a_list() {
        let result: nu_test_support::Result = test().run("[1 2] | intersect 42");
        assert!(result.is_err());
    }
}
182}