1use nu_engine::command_prelude::*;
2use nu_protocol::Config;
3use std::{
4 cmp::max,
5 collections::{HashMap, HashSet},
6};
7
/// The `join` command: joins the pipeline input table (the left table) with a
/// table given as an argument (the right table) on a key column.
#[derive(Clone, Debug)]
pub struct Join;
10
/// Which rows a join keeps, selected by the `--inner/--left/--right/--outer` switches.
///
/// The small, field-less enum is `Copy` so it can be passed around freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum JoinType {
    /// Only rows whose key appears in both tables.
    Inner,
    /// Every left row; unmatched ones are padded with nulls for right columns.
    Left,
    /// Every right row; unmatched ones are padded with nulls for left columns.
    Right,
    /// Every row from both tables, padding whichever side is missing.
    Outer,
}
17
/// Whether a `join_rows` pass should emit rows for keys matched in both tables.
///
/// The second pass of an outer join uses `No` so matched pairs, already produced
/// by the first pass, are not emitted twice.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum IncludeInner {
    No,
    Yes,
}
22
/// Optional renaming applied to columns that come from the right-hand table.
///
/// Populated from the `--prefix` and `--suffix` flags; `None` means the
/// corresponding flag was not given.
#[derive(Default, Debug)]
struct RightColumnRename {
    /// Text placed before a renamed right-table column name.
    prefix: Option<String>,
    /// Text placed after a renamed right-table column name.
    suffix: Option<String>,
}
28
29impl Command for Join {
30 fn name(&self) -> &str {
31 "join"
32 }
33
34 fn signature(&self) -> Signature {
35 Signature::build("join")
36 .required(
37 "right-table",
38 SyntaxShape::Table([].into()),
39 "The right table in the join.",
40 )
41 .required(
42 "left-on",
43 SyntaxShape::String,
44 "Name of column in input (left) table to join on.",
45 )
46 .optional(
47 "right-on",
48 SyntaxShape::String,
49 "Name of column in right table to join on. Defaults to same column as left table.",
50 )
51 .named(
52 "prefix",
53 SyntaxShape::String,
54 "Prefix columns from the right table with this string (excluding the shared join key).",
55 Some('p'),
56 )
57 .named(
58 "suffix",
59 SyntaxShape::String,
60 "Suffix columns from the right table with this string (excluding the shared join key).",
61 Some('s'),
62 )
63 .switch("inner", "Inner join (default).", Some('i'))
64 .switch("left", "Left-outer join.", Some('l'))
65 .switch("right", "Right-outer join.", Some('r'))
66 .switch("outer", "Outer join.", Some('o'))
67 .input_output_types(vec![(Type::table(), Type::table())])
68 .category(Category::Filters)
69 }
70
71 fn description(&self) -> &str {
72 "Join two tables."
73 }
74
75 fn search_terms(&self) -> Vec<&str> {
76 vec!["sql"]
77 }
78
79 fn run(
80 &self,
81 engine_state: &EngineState,
82 stack: &mut Stack,
83 call: &Call,
84 input: PipelineData,
85 ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
86 let metadata = input.metadata();
87 let table_2: Value = call.req(engine_state, stack, 0)?;
88 let l_on: Value = call.req(engine_state, stack, 1)?;
89 let r_on: Value = call
90 .opt(engine_state, stack, 2)?
91 .unwrap_or_else(|| l_on.clone());
92 let span = call.head;
93 let join_type = join_type(engine_state, stack, call)?;
94 let rename = RightColumnRename {
95 prefix: call.get_flag(engine_state, stack, "prefix")?,
96 suffix: call.get_flag(engine_state, stack, "suffix")?,
97 };
98
99 let collected_input = input.into_value(span)?;
101
102 match (&collected_input, &table_2, &l_on, &r_on) {
103 (
104 Value::List { vals: rows_1, .. },
105 Value::List { vals: rows_2, .. },
106 Value::String { val: l_on, .. },
107 Value::String { val: r_on, .. },
108 ) => {
109 let result = join(rows_1, rows_2, l_on, r_on, join_type, &rename, span);
110 Ok(PipelineData::value(result, metadata))
111 }
112 _ => Err(ShellError::UnsupportedInput {
113 msg: "(PipelineData<table>, table, string, string)".into(),
114 input: format!(
115 "({:?}, {:?}, {:?} {:?})",
116 collected_input,
117 table_2.get_type(),
118 l_on.get_type(),
119 r_on.get_type(),
120 ),
121 msg_span: span,
122 input_span: span,
123 }),
124 }
125 }
126
127 fn examples(&self) -> Vec<Example<'_>> {
128 vec![
129 Example {
130 description: "Join two tables",
131 example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
132 result: Some(Value::test_list(vec![Value::test_record(record! {
133 "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
134 })])),
135 },
136 Example {
137 description: "Join multiple tables with distinct suffixes for the right table's columns",
138 example: "[{id: 1 x: 10}] | join --suffix _a [{id: 1 x: 20}] id | join --suffix _b [{id: 1 x: 30}] id",
139 result: Some(Value::test_list(vec![Value::test_record(record! {
140 "id" => Value::test_int(1),
141 "x" => Value::test_int(10),
142 "x_a" => Value::test_int(20),
143 "x_b" => Value::test_int(30),
144 })])),
145 },
146 Example {
147 description: "Join multiple tables with a prefix for the right table's columns",
148 example: "[{id: 1 x: 10}] | join --prefix r_ [{id: 1 x: 20}] id",
149 result: Some(Value::test_list(vec![Value::test_record(record! {
150 "id" => Value::test_int(1),
151 "x" => Value::test_int(10),
152 "r_x" => Value::test_int(20),
153 })])),
154 },
155 ]
156 }
157}
158
159fn join_type(
160 engine_state: &EngineState,
161 stack: &mut Stack,
162 call: &Call,
163) -> Result<JoinType, nu_protocol::ShellError> {
164 match (
165 call.has_flag(engine_state, stack, "inner")?,
166 call.has_flag(engine_state, stack, "left")?,
167 call.has_flag(engine_state, stack, "right")?,
168 call.has_flag(engine_state, stack, "outer")?,
169 ) {
170 (_, false, false, false) => Ok(JoinType::Inner),
171 (false, true, false, false) => Ok(JoinType::Left),
172 (false, false, true, false) => Ok(JoinType::Right),
173 (false, false, false, true) => Ok(JoinType::Outer),
174 _ => Err(ShellError::UnsupportedInput {
175 msg: "Choose one of: --inner, --left, --right, --outer".into(),
176 input: "".into(),
177 msg_span: call.head,
178 input_span: call.head,
179 }),
180 }
181}
182
183fn join(
184 left: &[Value],
185 right: &[Value],
186 left_join_key: &str,
187 right_join_key: &str,
188 join_type: JoinType,
189 rename: &RightColumnRename,
190 span: Span,
191) -> Value {
192 let config = Config::default();
218 let sep = ",";
219 let cap = max(left.len(), right.len());
220 let shared_join_key = if left_join_key == right_join_key {
221 Some(left_join_key)
222 } else {
223 None
224 };
225
226 let mut result: Vec<Value> = Vec::new();
229 let is_outer = matches!(join_type, JoinType::Outer);
230 let (this, this_join_key, other, other_keys, join_type) = match join_type {
231 JoinType::Left | JoinType::Outer => (
232 left,
233 left_join_key,
234 lookup_table(right, right_join_key, sep, cap, &config),
235 column_names(right),
236 JoinType::Left,
239 ),
240 JoinType::Inner | JoinType::Right => (
241 right,
242 right_join_key,
243 lookup_table(left, left_join_key, sep, cap, &config),
244 column_names(left),
245 join_type,
246 ),
247 };
248 join_rows(
249 &mut result,
250 this,
251 this_join_key,
252 other,
253 other_keys,
254 shared_join_key,
255 &join_type,
256 IncludeInner::Yes,
257 sep,
258 &config,
259 rename,
260 span,
261 );
262 if is_outer {
263 let (this, this_join_key, other, other_names, join_type) = (
264 right,
265 right_join_key,
266 lookup_table(left, left_join_key, sep, cap, &config),
267 column_names(left),
268 JoinType::Right,
269 );
270 join_rows(
271 &mut result,
272 this,
273 this_join_key,
274 other,
275 other_names,
276 shared_join_key,
277 &join_type,
278 IncludeInner::No,
279 sep,
280 &config,
281 rename,
282 span,
283 );
284 }
285 Value::list(result, span)
286}
287
288#[allow(clippy::too_many_arguments)]
291fn join_rows(
292 result: &mut Vec<Value>,
293 this: &[Value],
294 this_join_key: &str,
295 other: HashMap<String, Vec<&Record>>,
296 other_keys: Vec<&String>,
297 shared_join_key: Option<&str>,
298 join_type: &JoinType,
299 include_inner: IncludeInner,
300 sep: &str,
301 config: &Config,
302 rename: &RightColumnRename,
303 span: Span,
304) {
305 if !this
306 .iter()
307 .any(|this_record| match this_record.as_record() {
308 Ok(record) => record.contains(this_join_key),
309 Err(_) => false,
310 })
311 {
312 return;
314 }
315 for this_row in this {
316 if let Value::Record {
317 val: this_record, ..
318 } = this_row
319 {
320 if let Some(this_valkey) = this_record.get(this_join_key)
321 && let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config))
322 {
323 if let IncludeInner::Yes = include_inner {
324 for other_record in other_rows {
325 let record = match join_type {
327 JoinType::Inner | JoinType::Right => merge_records(
328 other_record, this_record,
330 shared_join_key,
331 rename,
332 ),
333 JoinType::Left => merge_records(
334 this_record, other_record,
336 shared_join_key,
337 rename,
338 ),
339 _ => panic!("not implemented"),
340 };
341 result.push(Value::record(record, span))
342 }
343 }
344 continue;
345 }
346 if !matches!(join_type, JoinType::Inner) {
347 let other_record = other_keys
352 .iter()
353 .map(|&key| {
354 let val = if Some(key.as_ref()) == shared_join_key {
355 this_record
356 .get(key)
357 .cloned()
358 .unwrap_or_else(|| Value::nothing(span))
359 } else {
360 Value::nothing(span)
361 };
362
363 (key.clone(), val)
364 })
365 .collect();
366
367 let record = match join_type {
368 JoinType::Inner | JoinType::Right => {
369 merge_records(&other_record, this_record, shared_join_key, rename)
370 }
371 JoinType::Left => {
372 merge_records(this_record, &other_record, shared_join_key, rename)
373 }
374 _ => panic!("not implemented"),
375 };
376
377 result.push(Value::record(record, span))
378 }
379 };
380 }
381}
382
383fn column_names(table: &[Value]) -> Vec<&String> {
386 table
387 .iter()
388 .find_map(|val| match val {
389 Value::Record { val, .. } => Some(val.columns().collect()),
390 _ => None,
391 })
392 .unwrap_or_default()
393}
394
395fn lookup_table<'a>(
398 rows: &'a [Value],
399 on: &str,
400 sep: &str,
401 cap: usize,
402 config: &Config,
403) -> HashMap<String, Vec<&'a Record>> {
404 let mut map = HashMap::<String, Vec<&'a Record>>::with_capacity(cap);
405 for row in rows {
406 if let Value::Record { val: record, .. } = row
407 && let Some(val) = record.get(on)
408 {
409 let valkey = val.to_expanded_string(sep, config);
410 map.entry(valkey).or_default().push(record);
411 };
412 }
413 map
414}
415
416fn merge_records(
420 left: &Record,
421 right: &Record,
422 shared_key: Option<&str>,
423 rename: &RightColumnRename,
424) -> Record {
425 let cap = max(left.len(), right.len());
426 let mut seen = HashSet::with_capacity(cap);
427 let mut record = Record::with_capacity(cap);
428 for (k, v) in left {
429 record.push(k.clone(), v.clone());
430 seen.insert(k.clone());
431 }
432
433 for (k, v) in right {
434 let k_shared = shared_key == Some(k.as_str());
435 if k_shared && seen.contains(k) {
437 continue;
438 }
439
440 let mut out_key = if rename.prefix.is_some() || rename.suffix.is_some() {
441 format!(
442 "{}{}{}",
443 rename.prefix.as_deref().unwrap_or(""),
444 k,
445 rename.suffix.as_deref().unwrap_or("")
446 )
447 } else if seen.contains(k) {
448 format!("{k}_")
449 } else {
450 k.clone()
451 };
452
453 while seen.contains(&out_key) {
455 out_key.push('_');
456 }
457
458 record.push(out_key.clone(), v.clone());
459 seen.insert(out_key);
460 }
461 record
462}
463
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the command's declared examples through the crate's shared
    /// `test_examples` harness.
    #[test]
    fn test_examples() {
        crate::test_examples(Join);
    }
}
474}