1use nu_engine::command_prelude::*;
2use nu_protocol::Config;
3use std::{
4 cmp::max,
5 collections::{HashMap, HashSet},
6};
7
/// The `join` filter command: combines the input (left) table with a second
/// (right) table on a key column.
#[derive(Clone)]
pub struct Join;
10
/// Which rows of the two tables end up in the joined output.
enum JoinType {
    /// Inner join (the default): only rows that match on the join key.
    Inner,
    /// Left-outer join: every left row, matched or not.
    Left,
    /// Right-outer join: every right row, matched or not.
    Right,
    /// Outer join: every row of both tables, matched or not.
    Outer,
}
17
/// Whether a pass of `join_rows` emits the rows that matched on the join key.
enum IncludeInner {
    /// Skip matched rows; only unmatched rows are emitted.
    No,
    /// Emit matched rows as well.
    Yes,
}
22
23impl Command for Join {
24 fn name(&self) -> &str {
25 "join"
26 }
27
28 fn signature(&self) -> Signature {
29 Signature::build("join")
30 .required(
31 "right-table",
32 SyntaxShape::Table([].into()),
33 "The right table in the join.",
34 )
35 .required(
36 "left-on",
37 SyntaxShape::String,
38 "Name of column in input (left) table to join on.",
39 )
40 .optional(
41 "right-on",
42 SyntaxShape::String,
43 "Name of column in right table to join on. Defaults to same column as left table.",
44 )
45 .switch("inner", "Inner join (default)", Some('i'))
46 .switch("left", "Left-outer join", Some('l'))
47 .switch("right", "Right-outer join", Some('r'))
48 .switch("outer", "Outer join", Some('o'))
49 .input_output_types(vec![(Type::table(), Type::table())])
50 .category(Category::Filters)
51 }
52
53 fn description(&self) -> &str {
54 "Join two tables."
55 }
56
57 fn search_terms(&self) -> Vec<&str> {
58 vec!["sql"]
59 }
60
61 fn run(
62 &self,
63 engine_state: &EngineState,
64 stack: &mut Stack,
65 call: &Call,
66 input: PipelineData,
67 ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
68 let metadata = input.metadata();
69 let table_2: Value = call.req(engine_state, stack, 0)?;
70 let l_on: Value = call.req(engine_state, stack, 1)?;
71 let r_on: Value = call
72 .opt(engine_state, stack, 2)?
73 .unwrap_or_else(|| l_on.clone());
74 let span = call.head;
75 let join_type = join_type(engine_state, stack, call)?;
76
77 let collected_input = input.into_value(span)?;
79
80 match (&collected_input, &table_2, &l_on, &r_on) {
81 (
82 Value::List { vals: rows_1, .. },
83 Value::List { vals: rows_2, .. },
84 Value::String { val: l_on, .. },
85 Value::String { val: r_on, .. },
86 ) => {
87 let result = join(rows_1, rows_2, l_on, r_on, join_type, span);
88 Ok(PipelineData::value(result, metadata))
89 }
90 _ => Err(ShellError::UnsupportedInput {
91 msg: "(PipelineData<table>, table, string, string)".into(),
92 input: format!(
93 "({:?}, {:?}, {:?} {:?})",
94 collected_input,
95 table_2.get_type(),
96 l_on.get_type(),
97 r_on.get_type(),
98 ),
99 msg_span: span,
100 input_span: span,
101 }),
102 }
103 }
104
105 fn examples(&self) -> Vec<Example> {
106 vec![Example {
107 description: "Join two tables",
108 example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
109 result: Some(Value::test_list(vec![Value::test_record(record! {
110 "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
111 })])),
112 }]
113 }
114}
115
116fn join_type(
117 engine_state: &EngineState,
118 stack: &mut Stack,
119 call: &Call,
120) -> Result<JoinType, nu_protocol::ShellError> {
121 match (
122 call.has_flag(engine_state, stack, "inner")?,
123 call.has_flag(engine_state, stack, "left")?,
124 call.has_flag(engine_state, stack, "right")?,
125 call.has_flag(engine_state, stack, "outer")?,
126 ) {
127 (_, false, false, false) => Ok(JoinType::Inner),
128 (false, true, false, false) => Ok(JoinType::Left),
129 (false, false, true, false) => Ok(JoinType::Right),
130 (false, false, false, true) => Ok(JoinType::Outer),
131 _ => Err(ShellError::UnsupportedInput {
132 msg: "Choose one of: --inner, --left, --right, --outer".into(),
133 input: "".into(),
134 msg_span: call.head,
135 input_span: call.head,
136 }),
137 }
138}
139
140fn join(
141 left: &[Value],
142 right: &[Value],
143 left_join_key: &str,
144 right_join_key: &str,
145 join_type: JoinType,
146 span: Span,
147) -> Value {
148 let config = Config::default();
174 let sep = ",";
175 let cap = max(left.len(), right.len());
176 let shared_join_key = if left_join_key == right_join_key {
177 Some(left_join_key)
178 } else {
179 None
180 };
181
182 let mut result: Vec<Value> = Vec::new();
185 let is_outer = matches!(join_type, JoinType::Outer);
186 let (this, this_join_key, other, other_keys, join_type) = match join_type {
187 JoinType::Left | JoinType::Outer => (
188 left,
189 left_join_key,
190 lookup_table(right, right_join_key, sep, cap, &config),
191 column_names(right),
192 JoinType::Left,
195 ),
196 JoinType::Inner | JoinType::Right => (
197 right,
198 right_join_key,
199 lookup_table(left, left_join_key, sep, cap, &config),
200 column_names(left),
201 join_type,
202 ),
203 };
204 join_rows(
205 &mut result,
206 this,
207 this_join_key,
208 other,
209 other_keys,
210 shared_join_key,
211 &join_type,
212 IncludeInner::Yes,
213 sep,
214 &config,
215 span,
216 );
217 if is_outer {
218 let (this, this_join_key, other, other_names, join_type) = (
219 right,
220 right_join_key,
221 lookup_table(left, left_join_key, sep, cap, &config),
222 column_names(left),
223 JoinType::Right,
224 );
225 join_rows(
226 &mut result,
227 this,
228 this_join_key,
229 other,
230 other_names,
231 shared_join_key,
232 &join_type,
233 IncludeInner::No,
234 sep,
235 &config,
236 span,
237 );
238 }
239 Value::list(result, span)
240}
241
242#[allow(clippy::too_many_arguments)]
245fn join_rows(
246 result: &mut Vec<Value>,
247 this: &[Value],
248 this_join_key: &str,
249 other: HashMap<String, Vec<&Record>>,
250 other_keys: Vec<&String>,
251 shared_join_key: Option<&str>,
252 join_type: &JoinType,
253 include_inner: IncludeInner,
254 sep: &str,
255 config: &Config,
256 span: Span,
257) {
258 if !this
259 .iter()
260 .any(|this_record| match this_record.as_record() {
261 Ok(record) => record.contains(this_join_key),
262 Err(_) => false,
263 })
264 {
265 return;
267 }
268 for this_row in this {
269 if let Value::Record {
270 val: this_record, ..
271 } = this_row
272 {
273 if let Some(this_valkey) = this_record.get(this_join_key) {
274 if let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config)) {
275 if matches!(include_inner, IncludeInner::Yes) {
276 for other_record in other_rows {
277 let record = match join_type {
279 JoinType::Inner | JoinType::Right => merge_records(
280 other_record, this_record,
282 shared_join_key,
283 ),
284 JoinType::Left => merge_records(
285 this_record, other_record,
287 shared_join_key,
288 ),
289 _ => panic!("not implemented"),
290 };
291 result.push(Value::record(record, span))
292 }
293 }
294 continue;
295 }
296 }
297 if !matches!(join_type, JoinType::Inner) {
298 let other_record = other_keys
303 .iter()
304 .map(|&key| {
305 let val = if Some(key.as_ref()) == shared_join_key {
306 this_record
307 .get(key)
308 .cloned()
309 .unwrap_or_else(|| Value::nothing(span))
310 } else {
311 Value::nothing(span)
312 };
313
314 (key.clone(), val)
315 })
316 .collect();
317
318 let record = match join_type {
319 JoinType::Inner | JoinType::Right => {
320 merge_records(&other_record, this_record, shared_join_key)
321 }
322 JoinType::Left => merge_records(this_record, &other_record, shared_join_key),
323 _ => panic!("not implemented"),
324 };
325
326 result.push(Value::record(record, span))
327 }
328 };
329 }
330}
331
332fn column_names(table: &[Value]) -> Vec<&String> {
335 table
336 .iter()
337 .find_map(|val| match val {
338 Value::Record { val, .. } => Some(val.columns().collect()),
339 _ => None,
340 })
341 .unwrap_or_default()
342}
343
344fn lookup_table<'a>(
347 rows: &'a [Value],
348 on: &str,
349 sep: &str,
350 cap: usize,
351 config: &Config,
352) -> HashMap<String, Vec<&'a Record>> {
353 let mut map = HashMap::<String, Vec<&'a Record>>::with_capacity(cap);
354 for row in rows {
355 if let Value::Record { val: record, .. } = row {
356 if let Some(val) = record.get(on) {
357 let valkey = val.to_expanded_string(sep, config);
358 map.entry(valkey).or_default().push(record);
359 }
360 };
361 }
362 map
363}
364
365fn merge_records(left: &Record, right: &Record, shared_key: Option<&str>) -> Record {
369 let cap = max(left.len(), right.len());
370 let mut seen = HashSet::with_capacity(cap);
371 let mut record = Record::with_capacity(cap);
372 for (k, v) in left {
373 record.push(k.clone(), v.clone());
374 seen.insert(k);
375 }
376
377 for (k, v) in right {
378 let k_seen = seen.contains(k);
379 let k_shared = shared_key == Some(k.as_str());
380 if !(k_seen && k_shared) {
382 record.push(if k_seen { format!("{k}_") } else { k.clone() }, v.clone());
383 }
384 }
385 record
386}
387
#[cfg(test)]
mod test {
    use super::Join;

    /// Runs the examples declared by the `join` command through the crate's
    /// example checker.
    #[test]
    fn test_examples() {
        // Called by path: importing `test_examples` at module level would clash
        // with this test function's own name.
        crate::test_examples(Join {})
    }
}
398}