Skip to main content

nu_command/filters/
join.rs

1use nu_engine::command_prelude::*;
2use nu_protocol::Config;
3use std::{
4    cmp::max,
5    collections::{HashMap, HashSet},
6};
7
/// The `join` command: joins the input (left) table with a right table on a
/// key column. See the `Command` impl below for flags and examples.
#[derive(Clone)]
pub struct Join;
10
/// The kind of join requested via the `--inner`/`--left`/`--right`/`--outer`
/// switches.
///
/// Small and field-less, so it is `Copy`: it can be passed around by value
/// instead of by reference.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum JoinType {
    /// Emit only rows that have a match in both tables (the default).
    Inner,
    /// Emit every left row; right columns are null where there is no match.
    Left,
    /// Emit every right row; left columns are null where there is no match.
    Right,
    /// Emit the rows of a left join plus the unmatched rows of a right join.
    Outer,
}
17
/// Whether a join pass emits rows whose join value has a match in the other
/// table.
///
/// `No` is used for the second (right) pass of an outer join. That pass only
/// adds the unmatched right rows, because the matched rows were already
/// emitted by the first (left) pass.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum IncludeInner {
    No,
    Yes,
}
22
/// Renaming applied to right-table column names when they are merged into the
/// joined output (populated from the `--prefix` / `--suffix` flags).
#[derive(Default, Debug)]
struct RightColumnRename {
    /// String prepended to right-table column names, if any.
    prefix: Option<String>,
    /// String appended to right-table column names, if any.
    suffix: Option<String>,
}
28
29impl Command for Join {
30    fn name(&self) -> &str {
31        "join"
32    }
33
34    fn signature(&self) -> Signature {
35        Signature::build("join")
36            .required(
37                "right-table",
38                SyntaxShape::Table([].into()),
39                "The right table in the join.",
40            )
41            .required(
42                "left-on",
43                SyntaxShape::String,
44                "Name of column in input (left) table to join on.",
45            )
46            .optional(
47                "right-on",
48                SyntaxShape::String,
49                "Name of column in right table to join on. Defaults to same column as left table.",
50            )
51            .named(
52                "prefix",
53                SyntaxShape::String,
54                "Prefix columns from the right table with this string (excluding the shared join key).",
55                Some('p'),
56            )
57            .named(
58                "suffix",
59                SyntaxShape::String,
60                "Suffix columns from the right table with this string (excluding the shared join key).",
61                Some('s'),
62            )
63            .switch("inner", "Inner join (default).", Some('i'))
64            .switch("left", "Left-outer join.", Some('l'))
65            .switch("right", "Right-outer join.", Some('r'))
66            .switch("outer", "Outer join.", Some('o'))
67            .input_output_types(vec![(Type::table(), Type::table())])
68            .category(Category::Filters)
69    }
70
71    fn description(&self) -> &str {
72        "Join two tables."
73    }
74
75    fn search_terms(&self) -> Vec<&str> {
76        vec!["sql"]
77    }
78
79    fn run(
80        &self,
81        engine_state: &EngineState,
82        stack: &mut Stack,
83        call: &Call,
84        input: PipelineData,
85    ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
86        let mut input = input.into_stream_or_original(engine_state);
87
88        let metadata = input.take_metadata();
89        let table_2: Value = call.req(engine_state, stack, 0)?;
90        let l_on: Value = call.req(engine_state, stack, 1)?;
91        let r_on: Value = call
92            .opt(engine_state, stack, 2)?
93            .unwrap_or_else(|| l_on.clone());
94        let span = call.head;
95        let join_type = join_type(engine_state, stack, call)?;
96        let rename = RightColumnRename {
97            prefix: call.get_flag(engine_state, stack, "prefix")?,
98            suffix: call.get_flag(engine_state, stack, "suffix")?,
99        };
100
101        // FIXME: we should handle ListStreams properly instead of collecting
102        let collected_input = input.into_value(span)?;
103
104        match (&collected_input, &table_2, &l_on, &r_on) {
105            (
106                Value::List { vals: rows_1, .. },
107                Value::List { vals: rows_2, .. },
108                Value::String { val: l_on, .. },
109                Value::String { val: r_on, .. },
110            ) => {
111                let result = join(rows_1, rows_2, l_on, r_on, join_type, &rename, span);
112                Ok(PipelineData::value(result, metadata))
113            }
114            _ => Err(ShellError::UnsupportedInput {
115                msg: "(PipelineData<table>, table, string, string)".into(),
116                input: format!(
117                    "({:?}, {:?}, {:?} {:?})",
118                    collected_input,
119                    table_2.get_type(),
120                    l_on.get_type(),
121                    r_on.get_type(),
122                ),
123                msg_span: span,
124                input_span: span,
125            }),
126        }
127    }
128
129    fn examples(&self) -> Vec<Example<'_>> {
130        vec![
131            Example {
132                description: "Join two tables",
133                example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
134                result: Some(Value::test_list(vec![Value::test_record(record! {
135                    "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
136                })])),
137            },
138            Example {
139                description: "Join multiple tables with distinct suffixes for the right table's columns",
140                example: "[{id: 1 x: 10}] | join --suffix _a [{id: 1 x: 20}] id | join --suffix _b [{id: 1 x: 30}] id",
141                result: Some(Value::test_list(vec![Value::test_record(record! {
142                    "id" => Value::test_int(1),
143                    "x" => Value::test_int(10),
144                    "x_a" => Value::test_int(20),
145                    "x_b" => Value::test_int(30),
146                })])),
147            },
148            Example {
149                description: "Join multiple tables with a prefix for the right table's columns",
150                example: "[{id: 1 x: 10}] | join --prefix r_ [{id: 1 x: 20}] id",
151                result: Some(Value::test_list(vec![Value::test_record(record! {
152                    "id" => Value::test_int(1),
153                    "x" => Value::test_int(10),
154                    "r_x" => Value::test_int(20),
155                })])),
156            },
157        ]
158    }
159}
160
161fn join_type(
162    engine_state: &EngineState,
163    stack: &mut Stack,
164    call: &Call,
165) -> Result<JoinType, nu_protocol::ShellError> {
166    match (
167        call.has_flag(engine_state, stack, "inner")?,
168        call.has_flag(engine_state, stack, "left")?,
169        call.has_flag(engine_state, stack, "right")?,
170        call.has_flag(engine_state, stack, "outer")?,
171    ) {
172        (_, false, false, false) => Ok(JoinType::Inner),
173        (false, true, false, false) => Ok(JoinType::Left),
174        (false, false, true, false) => Ok(JoinType::Right),
175        (false, false, false, true) => Ok(JoinType::Outer),
176        _ => Err(ShellError::UnsupportedInput {
177            msg: "Choose one of: --inner, --left, --right, --outer".into(),
178            input: "".into(),
179            msg_span: call.head,
180            input_span: call.head,
181        }),
182    }
183}
184
185fn join(
186    left: &[Value],
187    right: &[Value],
188    left_join_key: &str,
189    right_join_key: &str,
190    join_type: JoinType,
191    rename: &RightColumnRename,
192    span: Span,
193) -> Value {
194    // Inner / Right Join
195    // ------------------
196    // Make look-up table from rows on left
197    // For each row r on right:
198    //    If any matching rows on left:
199    //        For each matching row l on left:
200    //            Emit (l, r)
201    //    Else if RightJoin:
202    //        Emit (null, r)
203
204    // Left Join
205    // ----------
206    // Make look-up table from rows on right
207    // For each row l on left:
208    //    If any matching rows on right:
209    //        For each matching row r on right:
210    //            Emit (l, r)
211    //    Else:
212    //        Emit (l, null)
213
214    // Outer Join
215    // ----------
216    // Perform Left Join procedure
217    // Perform Right Join procedure, but excluding rows in Inner Join
218
219    let config = Config::default();
220    let sep = ",";
221    let cap = max(left.len(), right.len());
222    let shared_join_key = if left_join_key == right_join_key {
223        Some(left_join_key)
224    } else {
225        None
226    };
227
228    // For the "other" table, create a map from value in `on` column to a list of the
229    // rows having that value.
230    let mut result: Vec<Value> = Vec::new();
231    let is_outer = matches!(join_type, JoinType::Outer);
232    let (this, this_join_key, other, other_keys, join_type) = match join_type {
233        JoinType::Left | JoinType::Outer => (
234            left,
235            left_join_key,
236            lookup_table(right, right_join_key, sep, cap, &config),
237            column_names(right),
238            // For Outer we do a Left pass and a Right pass; this is the Left
239            // pass.
240            JoinType::Left,
241        ),
242        JoinType::Inner | JoinType::Right => (
243            right,
244            right_join_key,
245            lookup_table(left, left_join_key, sep, cap, &config),
246            column_names(left),
247            join_type,
248        ),
249    };
250    join_rows(
251        &mut result,
252        this,
253        this_join_key,
254        other,
255        other_keys,
256        shared_join_key,
257        &join_type,
258        IncludeInner::Yes,
259        sep,
260        &config,
261        rename,
262        span,
263    );
264    if is_outer {
265        let (this, this_join_key, other, other_names, join_type) = (
266            right,
267            right_join_key,
268            lookup_table(left, left_join_key, sep, cap, &config),
269            column_names(left),
270            JoinType::Right,
271        );
272        join_rows(
273            &mut result,
274            this,
275            this_join_key,
276            other,
277            other_names,
278            shared_join_key,
279            &join_type,
280            IncludeInner::No,
281            sep,
282            &config,
283            rename,
284            span,
285        );
286    }
287    Value::list(result, span)
288}
289
290// Join rows of `this` (a nushell table) to rows of `other` (a lookup-table
291// containing rows of a nushell table).
292#[allow(clippy::too_many_arguments)]
293fn join_rows(
294    result: &mut Vec<Value>,
295    this: &[Value],
296    this_join_key: &str,
297    other: HashMap<String, Vec<&Record>>,
298    other_keys: Vec<&String>,
299    shared_join_key: Option<&str>,
300    join_type: &JoinType,
301    include_inner: IncludeInner,
302    sep: &str,
303    config: &Config,
304    rename: &RightColumnRename,
305    span: Span,
306) {
307    if !this
308        .iter()
309        .any(|this_record| match this_record.as_record() {
310            Ok(record) => record.contains(this_join_key),
311            Err(_) => false,
312        })
313    {
314        // `this` table does not contain the join column; do nothing
315        return;
316    }
317    for this_row in this {
318        if let Value::Record {
319            val: this_record, ..
320        } = this_row
321        {
322            if let Some(this_valkey) = this_record.get(this_join_key)
323                && let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config))
324            {
325                if let IncludeInner::Yes = include_inner {
326                    for other_record in other_rows {
327                        // `other` table contains rows matching `this` row on the join column
328                        let record = match join_type {
329                            JoinType::Inner | JoinType::Right => merge_records(
330                                other_record, // `other` (lookup) is the left input table
331                                this_record,
332                                shared_join_key,
333                                rename,
334                            ),
335                            JoinType::Left => merge_records(
336                                this_record, // `this` is the left input table
337                                other_record,
338                                shared_join_key,
339                                rename,
340                            ),
341                            _ => panic!("not implemented"),
342                        };
343                        result.push(Value::record(record, span))
344                    }
345                }
346                continue;
347            }
348            if !matches!(join_type, JoinType::Inner) {
349                // Either `this` row is missing a value for the join column or
350                // `other` table did not contain any rows matching
351                // `this` row on the join column; emit a single joined
352                // row with null values for columns not present
353                let other_record = other_keys
354                    .iter()
355                    .map(|&key| {
356                        let val = if Some(key.as_ref()) == shared_join_key {
357                            this_record
358                                .get(key)
359                                .cloned()
360                                .unwrap_or_else(|| Value::nothing(span))
361                        } else {
362                            Value::nothing(span)
363                        };
364
365                        (key.clone(), val)
366                    })
367                    .collect();
368
369                let record = match join_type {
370                    JoinType::Inner | JoinType::Right => {
371                        merge_records(&other_record, this_record, shared_join_key, rename)
372                    }
373                    JoinType::Left => {
374                        merge_records(this_record, &other_record, shared_join_key, rename)
375                    }
376                    _ => panic!("not implemented"),
377                };
378
379                result.push(Value::record(record, span))
380            }
381        };
382    }
383}
384
385// Return column names (i.e. ordered keys from the first row; we assume that
386// these are the same for all rows).
387fn column_names(table: &[Value]) -> Vec<&String> {
388    table
389        .iter()
390        .find_map(|val| match val {
391            Value::Record { val, .. } => Some(val.columns().collect()),
392            _ => None,
393        })
394        .unwrap_or_default()
395}
396
397// Create a map from value in `on` column to a list of the rows having that
398// value.
399fn lookup_table<'a>(
400    rows: &'a [Value],
401    on: &str,
402    sep: &str,
403    cap: usize,
404    config: &Config,
405) -> HashMap<String, Vec<&'a Record>> {
406    let mut map = HashMap::<String, Vec<&'a Record>>::with_capacity(cap);
407    for row in rows {
408        if let Value::Record { val: record, .. } = row
409            && let Some(val) = record.get(on)
410        {
411            let valkey = val.to_expanded_string(sep, config);
412            map.entry(valkey).or_default().push(record);
413        };
414    }
415    map
416}
417
418// Merge `left` and `right` records, renaming keys in `right` where they clash
419// with keys in `left`. If `shared_key` is supplied then it is the name of a key
420// that should not be renamed (its values are guaranteed to be equal).
421fn merge_records(
422    left: &Record,
423    right: &Record,
424    shared_key: Option<&str>,
425    rename: &RightColumnRename,
426) -> Record {
427    let cap = max(left.len(), right.len());
428    let mut seen = HashSet::with_capacity(cap);
429    let mut record = Record::with_capacity(cap);
430    for (k, v) in left {
431        record.push(k.clone(), v.clone());
432        seen.insert(k.clone());
433    }
434
435    for (k, v) in right {
436        let k_shared = shared_key == Some(k.as_str());
437        // Do not output shared join key twice
438        if k_shared && seen.contains(k) {
439            continue;
440        }
441
442        let mut out_key = if rename.prefix.is_some() || rename.suffix.is_some() {
443            format!(
444                "{}{}{}",
445                rename.prefix.as_deref().unwrap_or(""),
446                k,
447                rename.suffix.as_deref().unwrap_or("")
448            )
449        } else if seen.contains(k) {
450            format!("{k}_")
451        } else {
452            k.clone()
453        };
454
455        // Ensure the output key is truly unique. If not, keep appending "_" until it is.
456        while seen.contains(&out_key) {
457            out_key.push('_');
458        }
459
460        record.push(out_key.clone(), v.clone());
461        seen.insert(out_key);
462    }
463    record
464}
465
#[cfg(test)]
mod test {
    use super::Join;

    /// Runs the examples declared in `Join::examples` through the
    /// `nu_test_support` example harness.
    #[test]
    fn test_examples() -> nu_test_support::Result {
        nu_test_support::test().examples(Join)
    }
}