1use nu_engine::command_prelude::*;
2use nu_protocol::Config;
3use std::{
4 cmp::max,
5 collections::{HashMap, HashSet},
6};
7
/// The `join` command: joins the input (left) table with a right table on a
/// key column. See `Command::signature` below for the accepted arguments.
#[derive(Clone)]
pub struct Join;
10
/// Which rows a join emits; selected by the `--inner`, `--left`, `--right`
/// and `--outer` switches.
enum JoinType {
    /// Only rows whose join key has a match in the other table (the default).
    Inner,
    /// Every left row; right columns are filled with `nothing` when unmatched.
    Left,
    /// Every right row; left columns are filled with `nothing` when unmatched.
    Right,
    /// All left rows (as in `Left`) plus the right rows that had no match.
    Outer,
}
17
/// Whether `join_rows` emits merged rows for keys that match in the other table.
///
/// `No` is used for the second pass of an outer join, which should only add
/// the right-table rows that did not match anything.
enum IncludeInner {
    /// Skip matched rows; only unmatched rows may be emitted.
    No,
    /// Emit one merged row per match.
    Yes,
}
22
/// Renaming applied to right-table column names when they are merged into an
/// output row. Populated from the `--prefix` and `--suffix` flags.
#[derive(Debug, Default)]
struct RightColumnRename {
    /// Prepended to right-table column names (`--prefix`).
    prefix: Option<String>,
    /// Appended to right-table column names (`--suffix`).
    suffix: Option<String>,
}
28
29impl Command for Join {
30 fn name(&self) -> &str {
31 "join"
32 }
33
34 fn signature(&self) -> Signature {
35 Signature::build("join")
36 .required(
37 "right-table",
38 SyntaxShape::Table([].into()),
39 "The right table in the join.",
40 )
41 .required(
42 "left-on",
43 SyntaxShape::String,
44 "Name of column in input (left) table to join on.",
45 )
46 .optional(
47 "right-on",
48 SyntaxShape::String,
49 "Name of column in right table to join on. Defaults to same column as left table.",
50 )
51 .named(
52 "prefix",
53 SyntaxShape::String,
54 "Prefix columns from the right table with this string (excluding the shared join key).",
55 Some('p'),
56 )
57 .named(
58 "suffix",
59 SyntaxShape::String,
60 "Suffix columns from the right table with this string (excluding the shared join key).",
61 Some('s'),
62 )
63 .switch("inner", "Inner join (default).", Some('i'))
64 .switch("left", "Left-outer join.", Some('l'))
65 .switch("right", "Right-outer join.", Some('r'))
66 .switch("outer", "Outer join.", Some('o'))
67 .input_output_types(vec![(Type::table(), Type::table())])
68 .category(Category::Filters)
69 }
70
71 fn description(&self) -> &str {
72 "Join two tables."
73 }
74
75 fn search_terms(&self) -> Vec<&str> {
76 vec!["sql"]
77 }
78
79 fn run(
80 &self,
81 engine_state: &EngineState,
82 stack: &mut Stack,
83 call: &Call,
84 input: PipelineData,
85 ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
86 let mut input = input.into_stream_or_original(engine_state);
87
88 let metadata = input.take_metadata();
89 let table_2: Value = call.req(engine_state, stack, 0)?;
90 let l_on: Value = call.req(engine_state, stack, 1)?;
91 let r_on: Value = call
92 .opt(engine_state, stack, 2)?
93 .unwrap_or_else(|| l_on.clone());
94 let span = call.head;
95 let join_type = join_type(engine_state, stack, call)?;
96 let rename = RightColumnRename {
97 prefix: call.get_flag(engine_state, stack, "prefix")?,
98 suffix: call.get_flag(engine_state, stack, "suffix")?,
99 };
100
101 let collected_input = input.into_value(span)?;
103
104 match (&collected_input, &table_2, &l_on, &r_on) {
105 (
106 Value::List { vals: rows_1, .. },
107 Value::List { vals: rows_2, .. },
108 Value::String { val: l_on, .. },
109 Value::String { val: r_on, .. },
110 ) => {
111 let result = join(rows_1, rows_2, l_on, r_on, join_type, &rename, span);
112 Ok(PipelineData::value(result, metadata))
113 }
114 _ => Err(ShellError::UnsupportedInput {
115 msg: "(PipelineData<table>, table, string, string)".into(),
116 input: format!(
117 "({:?}, {:?}, {:?} {:?})",
118 collected_input,
119 table_2.get_type(),
120 l_on.get_type(),
121 r_on.get_type(),
122 ),
123 msg_span: span,
124 input_span: span,
125 }),
126 }
127 }
128
129 fn examples(&self) -> Vec<Example<'_>> {
130 vec![
131 Example {
132 description: "Join two tables",
133 example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
134 result: Some(Value::test_list(vec![Value::test_record(record! {
135 "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
136 })])),
137 },
138 Example {
139 description: "Join multiple tables with distinct suffixes for the right table's columns",
140 example: "[{id: 1 x: 10}] | join --suffix _a [{id: 1 x: 20}] id | join --suffix _b [{id: 1 x: 30}] id",
141 result: Some(Value::test_list(vec![Value::test_record(record! {
142 "id" => Value::test_int(1),
143 "x" => Value::test_int(10),
144 "x_a" => Value::test_int(20),
145 "x_b" => Value::test_int(30),
146 })])),
147 },
148 Example {
149 description: "Join multiple tables with a prefix for the right table's columns",
150 example: "[{id: 1 x: 10}] | join --prefix r_ [{id: 1 x: 20}] id",
151 result: Some(Value::test_list(vec![Value::test_record(record! {
152 "id" => Value::test_int(1),
153 "x" => Value::test_int(10),
154 "r_x" => Value::test_int(20),
155 })])),
156 },
157 ]
158 }
159}
160
161fn join_type(
162 engine_state: &EngineState,
163 stack: &mut Stack,
164 call: &Call,
165) -> Result<JoinType, nu_protocol::ShellError> {
166 match (
167 call.has_flag(engine_state, stack, "inner")?,
168 call.has_flag(engine_state, stack, "left")?,
169 call.has_flag(engine_state, stack, "right")?,
170 call.has_flag(engine_state, stack, "outer")?,
171 ) {
172 (_, false, false, false) => Ok(JoinType::Inner),
173 (false, true, false, false) => Ok(JoinType::Left),
174 (false, false, true, false) => Ok(JoinType::Right),
175 (false, false, false, true) => Ok(JoinType::Outer),
176 _ => Err(ShellError::UnsupportedInput {
177 msg: "Choose one of: --inner, --left, --right, --outer".into(),
178 input: "".into(),
179 msg_span: call.head,
180 input_span: call.head,
181 }),
182 }
183}
184
185fn join(
186 left: &[Value],
187 right: &[Value],
188 left_join_key: &str,
189 right_join_key: &str,
190 join_type: JoinType,
191 rename: &RightColumnRename,
192 span: Span,
193) -> Value {
194 let config = Config::default();
220 let sep = ",";
221 let cap = max(left.len(), right.len());
222 let shared_join_key = if left_join_key == right_join_key {
223 Some(left_join_key)
224 } else {
225 None
226 };
227
228 let mut result: Vec<Value> = Vec::new();
231 let is_outer = matches!(join_type, JoinType::Outer);
232 let (this, this_join_key, other, other_keys, join_type) = match join_type {
233 JoinType::Left | JoinType::Outer => (
234 left,
235 left_join_key,
236 lookup_table(right, right_join_key, sep, cap, &config),
237 column_names(right),
238 JoinType::Left,
241 ),
242 JoinType::Inner | JoinType::Right => (
243 right,
244 right_join_key,
245 lookup_table(left, left_join_key, sep, cap, &config),
246 column_names(left),
247 join_type,
248 ),
249 };
250 join_rows(
251 &mut result,
252 this,
253 this_join_key,
254 other,
255 other_keys,
256 shared_join_key,
257 &join_type,
258 IncludeInner::Yes,
259 sep,
260 &config,
261 rename,
262 span,
263 );
264 if is_outer {
265 let (this, this_join_key, other, other_names, join_type) = (
266 right,
267 right_join_key,
268 lookup_table(left, left_join_key, sep, cap, &config),
269 column_names(left),
270 JoinType::Right,
271 );
272 join_rows(
273 &mut result,
274 this,
275 this_join_key,
276 other,
277 other_names,
278 shared_join_key,
279 &join_type,
280 IncludeInner::No,
281 sep,
282 &config,
283 rename,
284 span,
285 );
286 }
287 Value::list(result, span)
288}
289
290#[allow(clippy::too_many_arguments)]
293fn join_rows(
294 result: &mut Vec<Value>,
295 this: &[Value],
296 this_join_key: &str,
297 other: HashMap<String, Vec<&Record>>,
298 other_keys: Vec<&String>,
299 shared_join_key: Option<&str>,
300 join_type: &JoinType,
301 include_inner: IncludeInner,
302 sep: &str,
303 config: &Config,
304 rename: &RightColumnRename,
305 span: Span,
306) {
307 if !this
308 .iter()
309 .any(|this_record| match this_record.as_record() {
310 Ok(record) => record.contains(this_join_key),
311 Err(_) => false,
312 })
313 {
314 return;
316 }
317 for this_row in this {
318 if let Value::Record {
319 val: this_record, ..
320 } = this_row
321 {
322 if let Some(this_valkey) = this_record.get(this_join_key)
323 && let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config))
324 {
325 if let IncludeInner::Yes = include_inner {
326 for other_record in other_rows {
327 let record = match join_type {
329 JoinType::Inner | JoinType::Right => merge_records(
330 other_record, this_record,
332 shared_join_key,
333 rename,
334 ),
335 JoinType::Left => merge_records(
336 this_record, other_record,
338 shared_join_key,
339 rename,
340 ),
341 _ => panic!("not implemented"),
342 };
343 result.push(Value::record(record, span))
344 }
345 }
346 continue;
347 }
348 if !matches!(join_type, JoinType::Inner) {
349 let other_record = other_keys
354 .iter()
355 .map(|&key| {
356 let val = if Some(key.as_ref()) == shared_join_key {
357 this_record
358 .get(key)
359 .cloned()
360 .unwrap_or_else(|| Value::nothing(span))
361 } else {
362 Value::nothing(span)
363 };
364
365 (key.clone(), val)
366 })
367 .collect();
368
369 let record = match join_type {
370 JoinType::Inner | JoinType::Right => {
371 merge_records(&other_record, this_record, shared_join_key, rename)
372 }
373 JoinType::Left => {
374 merge_records(this_record, &other_record, shared_join_key, rename)
375 }
376 _ => panic!("not implemented"),
377 };
378
379 result.push(Value::record(record, span))
380 }
381 };
382 }
383}
384
385fn column_names(table: &[Value]) -> Vec<&String> {
388 table
389 .iter()
390 .find_map(|val| match val {
391 Value::Record { val, .. } => Some(val.columns().collect()),
392 _ => None,
393 })
394 .unwrap_or_default()
395}
396
397fn lookup_table<'a>(
400 rows: &'a [Value],
401 on: &str,
402 sep: &str,
403 cap: usize,
404 config: &Config,
405) -> HashMap<String, Vec<&'a Record>> {
406 let mut map = HashMap::<String, Vec<&'a Record>>::with_capacity(cap);
407 for row in rows {
408 if let Value::Record { val: record, .. } = row
409 && let Some(val) = record.get(on)
410 {
411 let valkey = val.to_expanded_string(sep, config);
412 map.entry(valkey).or_default().push(record);
413 };
414 }
415 map
416}
417
418fn merge_records(
422 left: &Record,
423 right: &Record,
424 shared_key: Option<&str>,
425 rename: &RightColumnRename,
426) -> Record {
427 let cap = max(left.len(), right.len());
428 let mut seen = HashSet::with_capacity(cap);
429 let mut record = Record::with_capacity(cap);
430 for (k, v) in left {
431 record.push(k.clone(), v.clone());
432 seen.insert(k.clone());
433 }
434
435 for (k, v) in right {
436 let k_shared = shared_key == Some(k.as_str());
437 if k_shared && seen.contains(k) {
439 continue;
440 }
441
442 let mut out_key = if rename.prefix.is_some() || rename.suffix.is_some() {
443 format!(
444 "{}{}{}",
445 rename.prefix.as_deref().unwrap_or(""),
446 k,
447 rename.suffix.as_deref().unwrap_or("")
448 )
449 } else if seen.contains(k) {
450 format!("{k}_")
451 } else {
452 k.clone()
453 };
454
455 while seen.contains(&out_key) {
457 out_key.push('_');
458 }
459
460 record.push(out_key.clone(), v.clone());
461 seen.insert(out_key);
462 }
463 record
464}
465
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises the examples declared in `Join::examples` through
    /// `nu_test_support` (presumably checking each against its `result` — confirm).
    #[test]
    fn test_examples() -> nu_test_support::Result {
        nu_test_support::test().examples(Join)
    }
}