nu_command/filters/
join.rs

1use nu_engine::command_prelude::*;
2use nu_protocol::Config;
3use std::{
4    cmp::max,
5    collections::{HashMap, HashSet},
6};
7
// Unit struct that carries the `join` filter command; all behavior lives in
// its `Command` implementation.
#[derive(Clone)]
pub struct Join;
10
// Kind of join requested by the caller, chosen from the `--inner`, `--left`,
// `--right` and `--outer` switches.
enum JoinType {
    // Emit only rows that have a match in both tables (the default).
    Inner,
    // Emit every row of the left (input) table, null-filling missing matches.
    Left,
    // Emit every row of the right table, null-filling missing matches.
    Right,
    // Emit every row of both tables; performed as a Left pass followed by a
    // Right pass that skips the already-emitted matched rows.
    Outer,
}
17
// Tells `join_rows` whether rows that found a match should be emitted.
enum IncludeInner {
    // Skip matched rows; used for the second pass of an outer join so that
    // matches are not emitted twice.
    No,
    // Emit matched rows.
    Yes,
}
22
23impl Command for Join {
24    fn name(&self) -> &str {
25        "join"
26    }
27
28    fn signature(&self) -> Signature {
29        Signature::build("join")
30            .required(
31                "right-table",
32                SyntaxShape::Table([].into()),
33                "The right table in the join.",
34            )
35            .required(
36                "left-on",
37                SyntaxShape::String,
38                "Name of column in input (left) table to join on.",
39            )
40            .optional(
41                "right-on",
42                SyntaxShape::String,
43                "Name of column in right table to join on. Defaults to same column as left table.",
44            )
45            .switch("inner", "Inner join (default)", Some('i'))
46            .switch("left", "Left-outer join", Some('l'))
47            .switch("right", "Right-outer join", Some('r'))
48            .switch("outer", "Outer join", Some('o'))
49            .input_output_types(vec![(Type::table(), Type::table())])
50            .category(Category::Filters)
51    }
52
53    fn description(&self) -> &str {
54        "Join two tables."
55    }
56
57    fn search_terms(&self) -> Vec<&str> {
58        vec!["sql"]
59    }
60
61    fn run(
62        &self,
63        engine_state: &EngineState,
64        stack: &mut Stack,
65        call: &Call,
66        input: PipelineData,
67    ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
68        let metadata = input.metadata();
69        let table_2: Value = call.req(engine_state, stack, 0)?;
70        let l_on: Value = call.req(engine_state, stack, 1)?;
71        let r_on: Value = call
72            .opt(engine_state, stack, 2)?
73            .unwrap_or_else(|| l_on.clone());
74        let span = call.head;
75        let join_type = join_type(engine_state, stack, call)?;
76
77        // FIXME: we should handle ListStreams properly instead of collecting
78        let collected_input = input.into_value(span)?;
79
80        match (&collected_input, &table_2, &l_on, &r_on) {
81            (
82                Value::List { vals: rows_1, .. },
83                Value::List { vals: rows_2, .. },
84                Value::String { val: l_on, .. },
85                Value::String { val: r_on, .. },
86            ) => {
87                let result = join(rows_1, rows_2, l_on, r_on, join_type, span);
88                Ok(PipelineData::value(result, metadata))
89            }
90            _ => Err(ShellError::UnsupportedInput {
91                msg: "(PipelineData<table>, table, string, string)".into(),
92                input: format!(
93                    "({:?}, {:?}, {:?} {:?})",
94                    collected_input,
95                    table_2.get_type(),
96                    l_on.get_type(),
97                    r_on.get_type(),
98                ),
99                msg_span: span,
100                input_span: span,
101            }),
102        }
103    }
104
105    fn examples(&self) -> Vec<Example> {
106        vec![Example {
107            description: "Join two tables",
108            example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
109            result: Some(Value::test_list(vec![Value::test_record(record! {
110                "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
111            })])),
112        }]
113    }
114}
115
116fn join_type(
117    engine_state: &EngineState,
118    stack: &mut Stack,
119    call: &Call,
120) -> Result<JoinType, nu_protocol::ShellError> {
121    match (
122        call.has_flag(engine_state, stack, "inner")?,
123        call.has_flag(engine_state, stack, "left")?,
124        call.has_flag(engine_state, stack, "right")?,
125        call.has_flag(engine_state, stack, "outer")?,
126    ) {
127        (_, false, false, false) => Ok(JoinType::Inner),
128        (false, true, false, false) => Ok(JoinType::Left),
129        (false, false, true, false) => Ok(JoinType::Right),
130        (false, false, false, true) => Ok(JoinType::Outer),
131        _ => Err(ShellError::UnsupportedInput {
132            msg: "Choose one of: --inner, --left, --right, --outer".into(),
133            input: "".into(),
134            msg_span: call.head,
135            input_span: call.head,
136        }),
137    }
138}
139
140fn join(
141    left: &[Value],
142    right: &[Value],
143    left_join_key: &str,
144    right_join_key: &str,
145    join_type: JoinType,
146    span: Span,
147) -> Value {
148    // Inner / Right Join
149    // ------------------
150    // Make look-up table from rows on left
151    // For each row r on right:
152    //    If any matching rows on left:
153    //        For each matching row l on left:
154    //            Emit (l, r)
155    //    Else if RightJoin:
156    //        Emit (null, r)
157
158    // Left Join
159    // ----------
160    // Make look-up table from rows on right
161    // For each row l on left:
162    //    If any matching rows on right:
163    //        For each matching row r on right:
164    //            Emit (l, r)
165    //    Else:
166    //        Emit (l, null)
167
168    // Outer Join
169    // ----------
170    // Perform Left Join procedure
171    // Perform Right Join procedure, but excluding rows in Inner Join
172
173    let config = Config::default();
174    let sep = ",";
175    let cap = max(left.len(), right.len());
176    let shared_join_key = if left_join_key == right_join_key {
177        Some(left_join_key)
178    } else {
179        None
180    };
181
182    // For the "other" table, create a map from value in `on` column to a list of the
183    // rows having that value.
184    let mut result: Vec<Value> = Vec::new();
185    let is_outer = matches!(join_type, JoinType::Outer);
186    let (this, this_join_key, other, other_keys, join_type) = match join_type {
187        JoinType::Left | JoinType::Outer => (
188            left,
189            left_join_key,
190            lookup_table(right, right_join_key, sep, cap, &config),
191            column_names(right),
192            // For Outer we do a Left pass and a Right pass; this is the Left
193            // pass.
194            JoinType::Left,
195        ),
196        JoinType::Inner | JoinType::Right => (
197            right,
198            right_join_key,
199            lookup_table(left, left_join_key, sep, cap, &config),
200            column_names(left),
201            join_type,
202        ),
203    };
204    join_rows(
205        &mut result,
206        this,
207        this_join_key,
208        other,
209        other_keys,
210        shared_join_key,
211        &join_type,
212        IncludeInner::Yes,
213        sep,
214        &config,
215        span,
216    );
217    if is_outer {
218        let (this, this_join_key, other, other_names, join_type) = (
219            right,
220            right_join_key,
221            lookup_table(left, left_join_key, sep, cap, &config),
222            column_names(left),
223            JoinType::Right,
224        );
225        join_rows(
226            &mut result,
227            this,
228            this_join_key,
229            other,
230            other_names,
231            shared_join_key,
232            &join_type,
233            IncludeInner::No,
234            sep,
235            &config,
236            span,
237        );
238    }
239    Value::list(result, span)
240}
241
242// Join rows of `this` (a nushell table) to rows of `other` (a lookup-table
243// containing rows of a nushell table).
244#[allow(clippy::too_many_arguments)]
245fn join_rows(
246    result: &mut Vec<Value>,
247    this: &[Value],
248    this_join_key: &str,
249    other: HashMap<String, Vec<&Record>>,
250    other_keys: Vec<&String>,
251    shared_join_key: Option<&str>,
252    join_type: &JoinType,
253    include_inner: IncludeInner,
254    sep: &str,
255    config: &Config,
256    span: Span,
257) {
258    if !this
259        .iter()
260        .any(|this_record| match this_record.as_record() {
261            Ok(record) => record.contains(this_join_key),
262            Err(_) => false,
263        })
264    {
265        // `this` table does not contain the join column; do nothing
266        return;
267    }
268    for this_row in this {
269        if let Value::Record {
270            val: this_record, ..
271        } = this_row
272        {
273            if let Some(this_valkey) = this_record.get(this_join_key) {
274                if let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config)) {
275                    if matches!(include_inner, IncludeInner::Yes) {
276                        for other_record in other_rows {
277                            // `other` table contains rows matching `this` row on the join column
278                            let record = match join_type {
279                                JoinType::Inner | JoinType::Right => merge_records(
280                                    other_record, // `other` (lookup) is the left input table
281                                    this_record,
282                                    shared_join_key,
283                                ),
284                                JoinType::Left => merge_records(
285                                    this_record, // `this` is the left input table
286                                    other_record,
287                                    shared_join_key,
288                                ),
289                                _ => panic!("not implemented"),
290                            };
291                            result.push(Value::record(record, span))
292                        }
293                    }
294                    continue;
295                }
296            }
297            if !matches!(join_type, JoinType::Inner) {
298                // Either `this` row is missing a value for the join column or
299                // `other` table did not contain any rows matching
300                // `this` row on the join column; emit a single joined
301                // row with null values for columns not present
302                let other_record = other_keys
303                    .iter()
304                    .map(|&key| {
305                        let val = if Some(key.as_ref()) == shared_join_key {
306                            this_record
307                                .get(key)
308                                .cloned()
309                                .unwrap_or_else(|| Value::nothing(span))
310                        } else {
311                            Value::nothing(span)
312                        };
313
314                        (key.clone(), val)
315                    })
316                    .collect();
317
318                let record = match join_type {
319                    JoinType::Inner | JoinType::Right => {
320                        merge_records(&other_record, this_record, shared_join_key)
321                    }
322                    JoinType::Left => merge_records(this_record, &other_record, shared_join_key),
323                    _ => panic!("not implemented"),
324                };
325
326                result.push(Value::record(record, span))
327            }
328        };
329    }
330}
331
332// Return column names (i.e. ordered keys from the first row; we assume that
333// these are the same for all rows).
334fn column_names(table: &[Value]) -> Vec<&String> {
335    table
336        .iter()
337        .find_map(|val| match val {
338            Value::Record { val, .. } => Some(val.columns().collect()),
339            _ => None,
340        })
341        .unwrap_or_default()
342}
343
344// Create a map from value in `on` column to a list of the rows having that
345// value.
346fn lookup_table<'a>(
347    rows: &'a [Value],
348    on: &str,
349    sep: &str,
350    cap: usize,
351    config: &Config,
352) -> HashMap<String, Vec<&'a Record>> {
353    let mut map = HashMap::<String, Vec<&'a Record>>::with_capacity(cap);
354    for row in rows {
355        if let Value::Record { val: record, .. } = row {
356            if let Some(val) = record.get(on) {
357                let valkey = val.to_expanded_string(sep, config);
358                map.entry(valkey).or_default().push(record);
359            }
360        };
361    }
362    map
363}
364
365// Merge `left` and `right` records, renaming keys in `right` where they clash
366// with keys in `left`. If `shared_key` is supplied then it is the name of a key
367// that should not be renamed (its values are guaranteed to be equal).
368fn merge_records(left: &Record, right: &Record, shared_key: Option<&str>) -> Record {
369    let cap = max(left.len(), right.len());
370    let mut seen = HashSet::with_capacity(cap);
371    let mut record = Record::with_capacity(cap);
372    for (k, v) in left {
373        record.push(k.clone(), v.clone());
374        seen.insert(k);
375    }
376
377    for (k, v) in right {
378        let k_seen = seen.contains(k);
379        let k_shared = shared_key == Some(k.as_str());
380        // Do not output shared join key twice
381        if !(k_seen && k_shared) {
382            record.push(if k_seen { format!("{k}_") } else { k.clone() }, v.clone());
383        }
384    }
385    record
386}
387
// Unit tests for the `join` command.
#[cfg(test)]
mod test {
    use super::*;

    // Hands the command to the crate's example harness; presumably this runs
    // each entry of `Join::examples()` and compares it with its declared
    // result — confirm against `crate::test_examples`.
    #[test]
    fn test_examples() {
        // Imported inside the function on purpose: this test function has the
        // same name as the helper, so a module-level import would clash.
        use crate::test_examples;

        test_examples(Join {})
    }
}
398}