Skip to main content

nu_command/filters/
join.rs

1use nu_engine::command_prelude::*;
2use nu_protocol::Config;
3use std::{
4    cmp::max,
5    collections::{HashMap, HashSet},
6};
7
/// The `join` command: joins the input (left) table with another (right) table.
#[derive(Clone)]
pub struct Join;

/// Which rows a join keeps when a row has no match in the other table.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum JoinType {
    /// Only rows that have a match in both tables (the default).
    Inner,
    /// All rows of the left table; unmatched ones are padded with nulls.
    Left,
    /// All rows of the right table; unmatched ones are padded with nulls.
    Right,
    /// All rows of both tables; unmatched ones are padded with nulls.
    Outer,
}

/// Whether a join pass emits the rows that matched in both tables.
///
/// The second pass of an outer join uses `No`, so that matched rows (already
/// emitted by the first pass) are not emitted twice.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum IncludeInner {
    No,
    Yes,
}

/// Optional renaming applied to the columns taken from the right table.
#[derive(Debug, Default)]
struct RightColumnRename {
    /// Prepended to right-table column names (`--prefix`).
    prefix: Option<String>,
    /// Appended to right-table column names (`--suffix`).
    suffix: Option<String>,
}
28
29impl Command for Join {
30    fn name(&self) -> &str {
31        "join"
32    }
33
34    fn signature(&self) -> Signature {
35        Signature::build("join")
36            .required(
37                "right-table",
38                SyntaxShape::Table([].into()),
39                "The right table in the join.",
40            )
41            .required(
42                "left-on",
43                SyntaxShape::String,
44                "Name of column in input (left) table to join on.",
45            )
46            .optional(
47                "right-on",
48                SyntaxShape::String,
49                "Name of column in right table to join on. Defaults to same column as left table.",
50            )
51            .named(
52                "prefix",
53                SyntaxShape::String,
54                "Prefix columns from the right table with this string (excluding the shared join key).",
55                Some('p'),
56            )
57            .named(
58                "suffix",
59                SyntaxShape::String,
60                "Suffix columns from the right table with this string (excluding the shared join key).",
61                Some('s'),
62            )
63            .switch("inner", "Inner join (default).", Some('i'))
64            .switch("left", "Left-outer join.", Some('l'))
65            .switch("right", "Right-outer join.", Some('r'))
66            .switch("outer", "Outer join.", Some('o'))
67            .input_output_types(vec![(Type::table(), Type::table())])
68            .category(Category::Filters)
69    }
70
71    fn description(&self) -> &str {
72        "Join two tables."
73    }
74
75    fn search_terms(&self) -> Vec<&str> {
76        vec!["sql"]
77    }
78
79    fn run(
80        &self,
81        engine_state: &EngineState,
82        stack: &mut Stack,
83        call: &Call,
84        input: PipelineData,
85    ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
86        let metadata = input.metadata();
87        let table_2: Value = call.req(engine_state, stack, 0)?;
88        let l_on: Value = call.req(engine_state, stack, 1)?;
89        let r_on: Value = call
90            .opt(engine_state, stack, 2)?
91            .unwrap_or_else(|| l_on.clone());
92        let span = call.head;
93        let join_type = join_type(engine_state, stack, call)?;
94        let rename = RightColumnRename {
95            prefix: call.get_flag(engine_state, stack, "prefix")?,
96            suffix: call.get_flag(engine_state, stack, "suffix")?,
97        };
98
99        // FIXME: we should handle ListStreams properly instead of collecting
100        let collected_input = input.into_value(span)?;
101
102        match (&collected_input, &table_2, &l_on, &r_on) {
103            (
104                Value::List { vals: rows_1, .. },
105                Value::List { vals: rows_2, .. },
106                Value::String { val: l_on, .. },
107                Value::String { val: r_on, .. },
108            ) => {
109                let result = join(rows_1, rows_2, l_on, r_on, join_type, &rename, span);
110                Ok(PipelineData::value(result, metadata))
111            }
112            _ => Err(ShellError::UnsupportedInput {
113                msg: "(PipelineData<table>, table, string, string)".into(),
114                input: format!(
115                    "({:?}, {:?}, {:?} {:?})",
116                    collected_input,
117                    table_2.get_type(),
118                    l_on.get_type(),
119                    r_on.get_type(),
120                ),
121                msg_span: span,
122                input_span: span,
123            }),
124        }
125    }
126
127    fn examples(&self) -> Vec<Example<'_>> {
128        vec![
129            Example {
130                description: "Join two tables",
131                example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
132                result: Some(Value::test_list(vec![Value::test_record(record! {
133                    "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
134                })])),
135            },
136            Example {
137                description: "Join multiple tables with distinct suffixes for the right table's columns",
138                example: "[{id: 1 x: 10}] | join --suffix _a [{id: 1 x: 20}] id | join --suffix _b [{id: 1 x: 30}] id",
139                result: Some(Value::test_list(vec![Value::test_record(record! {
140                    "id" => Value::test_int(1),
141                    "x" => Value::test_int(10),
142                    "x_a" => Value::test_int(20),
143                    "x_b" => Value::test_int(30),
144                })])),
145            },
146            Example {
147                description: "Join multiple tables with a prefix for the right table's columns",
148                example: "[{id: 1 x: 10}] | join --prefix r_ [{id: 1 x: 20}] id",
149                result: Some(Value::test_list(vec![Value::test_record(record! {
150                    "id" => Value::test_int(1),
151                    "x" => Value::test_int(10),
152                    "r_x" => Value::test_int(20),
153                })])),
154            },
155        ]
156    }
157}
158
159fn join_type(
160    engine_state: &EngineState,
161    stack: &mut Stack,
162    call: &Call,
163) -> Result<JoinType, nu_protocol::ShellError> {
164    match (
165        call.has_flag(engine_state, stack, "inner")?,
166        call.has_flag(engine_state, stack, "left")?,
167        call.has_flag(engine_state, stack, "right")?,
168        call.has_flag(engine_state, stack, "outer")?,
169    ) {
170        (_, false, false, false) => Ok(JoinType::Inner),
171        (false, true, false, false) => Ok(JoinType::Left),
172        (false, false, true, false) => Ok(JoinType::Right),
173        (false, false, false, true) => Ok(JoinType::Outer),
174        _ => Err(ShellError::UnsupportedInput {
175            msg: "Choose one of: --inner, --left, --right, --outer".into(),
176            input: "".into(),
177            msg_span: call.head,
178            input_span: call.head,
179        }),
180    }
181}
182
183fn join(
184    left: &[Value],
185    right: &[Value],
186    left_join_key: &str,
187    right_join_key: &str,
188    join_type: JoinType,
189    rename: &RightColumnRename,
190    span: Span,
191) -> Value {
192    // Inner / Right Join
193    // ------------------
194    // Make look-up table from rows on left
195    // For each row r on right:
196    //    If any matching rows on left:
197    //        For each matching row l on left:
198    //            Emit (l, r)
199    //    Else if RightJoin:
200    //        Emit (null, r)
201
202    // Left Join
203    // ----------
204    // Make look-up table from rows on right
205    // For each row l on left:
206    //    If any matching rows on right:
207    //        For each matching row r on right:
208    //            Emit (l, r)
209    //    Else:
210    //        Emit (l, null)
211
212    // Outer Join
213    // ----------
214    // Perform Left Join procedure
215    // Perform Right Join procedure, but excluding rows in Inner Join
216
217    let config = Config::default();
218    let sep = ",";
219    let cap = max(left.len(), right.len());
220    let shared_join_key = if left_join_key == right_join_key {
221        Some(left_join_key)
222    } else {
223        None
224    };
225
226    // For the "other" table, create a map from value in `on` column to a list of the
227    // rows having that value.
228    let mut result: Vec<Value> = Vec::new();
229    let is_outer = matches!(join_type, JoinType::Outer);
230    let (this, this_join_key, other, other_keys, join_type) = match join_type {
231        JoinType::Left | JoinType::Outer => (
232            left,
233            left_join_key,
234            lookup_table(right, right_join_key, sep, cap, &config),
235            column_names(right),
236            // For Outer we do a Left pass and a Right pass; this is the Left
237            // pass.
238            JoinType::Left,
239        ),
240        JoinType::Inner | JoinType::Right => (
241            right,
242            right_join_key,
243            lookup_table(left, left_join_key, sep, cap, &config),
244            column_names(left),
245            join_type,
246        ),
247    };
248    join_rows(
249        &mut result,
250        this,
251        this_join_key,
252        other,
253        other_keys,
254        shared_join_key,
255        &join_type,
256        IncludeInner::Yes,
257        sep,
258        &config,
259        rename,
260        span,
261    );
262    if is_outer {
263        let (this, this_join_key, other, other_names, join_type) = (
264            right,
265            right_join_key,
266            lookup_table(left, left_join_key, sep, cap, &config),
267            column_names(left),
268            JoinType::Right,
269        );
270        join_rows(
271            &mut result,
272            this,
273            this_join_key,
274            other,
275            other_names,
276            shared_join_key,
277            &join_type,
278            IncludeInner::No,
279            sep,
280            &config,
281            rename,
282            span,
283        );
284    }
285    Value::list(result, span)
286}
287
288// Join rows of `this` (a nushell table) to rows of `other` (a lookup-table
289// containing rows of a nushell table).
290#[allow(clippy::too_many_arguments)]
291fn join_rows(
292    result: &mut Vec<Value>,
293    this: &[Value],
294    this_join_key: &str,
295    other: HashMap<String, Vec<&Record>>,
296    other_keys: Vec<&String>,
297    shared_join_key: Option<&str>,
298    join_type: &JoinType,
299    include_inner: IncludeInner,
300    sep: &str,
301    config: &Config,
302    rename: &RightColumnRename,
303    span: Span,
304) {
305    if !this
306        .iter()
307        .any(|this_record| match this_record.as_record() {
308            Ok(record) => record.contains(this_join_key),
309            Err(_) => false,
310        })
311    {
312        // `this` table does not contain the join column; do nothing
313        return;
314    }
315    for this_row in this {
316        if let Value::Record {
317            val: this_record, ..
318        } = this_row
319        {
320            if let Some(this_valkey) = this_record.get(this_join_key)
321                && let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config))
322            {
323                if let IncludeInner::Yes = include_inner {
324                    for other_record in other_rows {
325                        // `other` table contains rows matching `this` row on the join column
326                        let record = match join_type {
327                            JoinType::Inner | JoinType::Right => merge_records(
328                                other_record, // `other` (lookup) is the left input table
329                                this_record,
330                                shared_join_key,
331                                rename,
332                            ),
333                            JoinType::Left => merge_records(
334                                this_record, // `this` is the left input table
335                                other_record,
336                                shared_join_key,
337                                rename,
338                            ),
339                            _ => panic!("not implemented"),
340                        };
341                        result.push(Value::record(record, span))
342                    }
343                }
344                continue;
345            }
346            if !matches!(join_type, JoinType::Inner) {
347                // Either `this` row is missing a value for the join column or
348                // `other` table did not contain any rows matching
349                // `this` row on the join column; emit a single joined
350                // row with null values for columns not present
351                let other_record = other_keys
352                    .iter()
353                    .map(|&key| {
354                        let val = if Some(key.as_ref()) == shared_join_key {
355                            this_record
356                                .get(key)
357                                .cloned()
358                                .unwrap_or_else(|| Value::nothing(span))
359                        } else {
360                            Value::nothing(span)
361                        };
362
363                        (key.clone(), val)
364                    })
365                    .collect();
366
367                let record = match join_type {
368                    JoinType::Inner | JoinType::Right => {
369                        merge_records(&other_record, this_record, shared_join_key, rename)
370                    }
371                    JoinType::Left => {
372                        merge_records(this_record, &other_record, shared_join_key, rename)
373                    }
374                    _ => panic!("not implemented"),
375                };
376
377                result.push(Value::record(record, span))
378            }
379        };
380    }
381}
382
383// Return column names (i.e. ordered keys from the first row; we assume that
384// these are the same for all rows).
385fn column_names(table: &[Value]) -> Vec<&String> {
386    table
387        .iter()
388        .find_map(|val| match val {
389            Value::Record { val, .. } => Some(val.columns().collect()),
390            _ => None,
391        })
392        .unwrap_or_default()
393}
394
395// Create a map from value in `on` column to a list of the rows having that
396// value.
397fn lookup_table<'a>(
398    rows: &'a [Value],
399    on: &str,
400    sep: &str,
401    cap: usize,
402    config: &Config,
403) -> HashMap<String, Vec<&'a Record>> {
404    let mut map = HashMap::<String, Vec<&'a Record>>::with_capacity(cap);
405    for row in rows {
406        if let Value::Record { val: record, .. } = row
407            && let Some(val) = record.get(on)
408        {
409            let valkey = val.to_expanded_string(sep, config);
410            map.entry(valkey).or_default().push(record);
411        };
412    }
413    map
414}
415
416// Merge `left` and `right` records, renaming keys in `right` where they clash
417// with keys in `left`. If `shared_key` is supplied then it is the name of a key
418// that should not be renamed (its values are guaranteed to be equal).
419fn merge_records(
420    left: &Record,
421    right: &Record,
422    shared_key: Option<&str>,
423    rename: &RightColumnRename,
424) -> Record {
425    let cap = max(left.len(), right.len());
426    let mut seen = HashSet::with_capacity(cap);
427    let mut record = Record::with_capacity(cap);
428    for (k, v) in left {
429        record.push(k.clone(), v.clone());
430        seen.insert(k.clone());
431    }
432
433    for (k, v) in right {
434        let k_shared = shared_key == Some(k.as_str());
435        // Do not output shared join key twice
436        if k_shared && seen.contains(k) {
437            continue;
438        }
439
440        let mut out_key = if rename.prefix.is_some() || rename.suffix.is_some() {
441            format!(
442                "{}{}{}",
443                rename.prefix.as_deref().unwrap_or(""),
444                k,
445                rename.suffix.as_deref().unwrap_or("")
446            )
447        } else if seen.contains(k) {
448            format!("{k}_")
449        } else {
450            k.clone()
451        };
452
453        // Ensure the output key is truly unique. If not, keep appending "_" until it is.
454        while seen.contains(&out_key) {
455            out_key.push('_');
456        }
457
458        record.push(out_key.clone(), v.clone());
459        seen.insert(out_key);
460    }
461    record
462}
463
#[cfg(test)]
mod test {
    use super::*;

    /// Hands the `join` command to the crate's `test_examples` helper.
    /// Presumably that helper evaluates each entry of `Join::examples` and
    /// compares it with the declared result; confirm in `crate::test_examples`.
    #[test]
    fn test_examples() {
        use crate::test_examples;

        let command = Join;
        test_examples(command)
    }
}