use nu_engine::CallExt;
use nu_protocol::ast::Call;
use nu_protocol::engine::{Command, EngineState, Stack};
use nu_protocol::{
record, Category, Config, Example, PipelineData, Record, ShellError, Signature, Span,
SyntaxShape, Type, Value,
};
use std::cmp::max;
use std::collections::{HashMap, HashSet};
/// The `join` command: combines the piped-in (left) table with a second (right)
/// table by matching the values of a key column on each side, SQL-style.
#[derive(Clone)]
pub struct Join;
/// Which rows survive a join, selected by the mutually exclusive
/// `--inner` / `--left` / `--right` / `--outer` switches.
enum JoinType {
    /// Only rows whose key appears in both tables (the default when no switch is given).
    Inner,
    /// Rows of the left (input) table that carry the join column are kept even without
    /// a match; the right table's columns are padded with `nothing` (a shared join
    /// column keeps the row's own value).
    Left,
    /// Rows of the right table that carry the join column are kept even without a
    /// match; the left table's columns are padded with `nothing` (a shared join
    /// column keeps the row's own value).
    Right,
    /// Unmatched rows of both tables are kept. `join` executes this as a `Left` pass
    /// followed by a `Right` pass that only adds right rows without a match.
    Outer,
}
/// Whether `join_rows` emits the merged rows for keys that DID find a match in the
/// other table.
///
/// `No` is used for the second (right-hand) pass of an outer join, whose matched
/// pairs were already emitted by the first (left-hand) pass.
enum IncludeInner {
    No,
    Yes,
}
impl Command for Join {
    fn name(&self) -> &str {
        "join"
    }

    fn signature(&self) -> Signature {
        Signature::build("join")
            .required(
                "right-table",
                SyntaxShape::List(Box::new(SyntaxShape::Any)),
                "The right table in the join.",
            )
            .required(
                "left-on",
                SyntaxShape::String,
                "Name of column in input (left) table to join on.",
            )
            .optional(
                "right-on",
                SyntaxShape::String,
                "Name of column in right table to join on. Defaults to same column as left table.",
            )
            .switch("inner", "Inner join (default)", Some('i'))
            .switch("left", "Left-outer join", Some('l'))
            .switch("right", "Right-outer join", Some('r'))
            .switch("outer", "Outer join", Some('o'))
            .input_output_types(vec![(Type::Table(vec![]), Type::Table(vec![]))])
            .category(Category::Filters)
    }

    fn usage(&self) -> &str {
        "Join two tables."
    }

    fn search_terms(&self) -> Vec<&str> {
        vec!["sql"]
    }

    /// Collect the input table, read the right table, key columns and join flavour,
    /// and return the joined table with the input's pipeline metadata preserved.
    ///
    /// # Errors
    /// - Propagates errors from reading arguments or flags (including conflicting
    ///   join switches, see `join_type`).
    /// - Returns `ShellError::UnsupportedInput` when the input or `right-table` is not
    ///   a list, or a key argument is not a string.
    fn run(
        &self,
        engine_state: &EngineState,
        stack: &mut Stack,
        call: &Call,
        input: PipelineData,
    ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
        let metadata = input.metadata();
        let table_2: Value = call.req(engine_state, stack, 0)?;
        let l_on: Value = call.req(engine_state, stack, 1)?;
        // `right-on` is optional and defaults to the same column name as `left-on`.
        let r_on: Value = call
            .opt(engine_state, stack, 2)?
            .unwrap_or_else(|| l_on.clone());
        let span = call.head;
        let join_type = join_type(engine_state, stack, call)?;
        let collected_input = input.into_value(span);

        match (&collected_input, &table_2, &l_on, &r_on) {
            (
                Value::List { vals: rows_1, .. },
                Value::List { vals: rows_2, .. },
                Value::String { val: l_on, .. },
                Value::String { val: r_on, .. },
            ) => {
                let result = join(rows_1, rows_2, l_on, r_on, join_type, span);
                Ok(PipelineData::Value(result, metadata))
            }
            // Report the *types* that were received, mirroring the expected-shape
            // message. Dumping the whole input value could produce a huge error.
            _ => Err(ShellError::UnsupportedInput {
                msg: "(PipelineData<table>, table, string, string)".into(),
                input: format!(
                    "({:?}, {:?}, {:?}, {:?})",
                    collected_input.get_type(),
                    table_2.get_type(),
                    l_on.get_type(),
                    r_on.get_type(),
                ),
                msg_span: span,
                input_span: span,
            }),
        }
    }

    fn examples(&self) -> Vec<Example> {
        vec![Example {
            description: "Join two tables",
            example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
            result: Some(Value::test_list(vec![Value::test_record(record! {
                "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
            })])),
        }]
    }
}
/// Resolve the join flavour from the `--inner` / `--left` / `--right` / `--outer` switches.
///
/// With no switch, or with only `--inner`, the result is `JoinType::Inner`. Combining
/// `--left`, `--right` or `--outer` with any other switch is rejected.
///
/// # Errors
/// Propagates any error raised while reading a flag, and returns
/// `ShellError::UnsupportedInput` when the switches conflict.
fn join_type(
    engine_state: &EngineState,
    stack: &mut Stack,
    call: &Call,
) -> Result<JoinType, nu_protocol::ShellError> {
    // Read every switch up front, in a fixed order, so flag errors surface predictably.
    let inner = call.has_flag(engine_state, stack, "inner")?;
    let left = call.has_flag(engine_state, stack, "left")?;
    let right = call.has_flag(engine_state, stack, "right")?;
    let outer = call.has_flag(engine_state, stack, "outer")?;

    let chosen = match (left, right, outer) {
        // `--inner` is the default, so it is accepted with or without the switch.
        (false, false, false) => Some(JoinType::Inner),
        (true, false, false) if !inner => Some(JoinType::Left),
        (false, true, false) if !inner => Some(JoinType::Right),
        (false, false, true) if !inner => Some(JoinType::Outer),
        _ => None,
    };

    chosen.ok_or_else(|| ShellError::UnsupportedInput {
        msg: "Choose one of: --inner, --left, --right, --outer".into(),
        input: "".into(),
        msg_span: call.head,
        input_span: call.head,
    })
}
/// Join `left` (the piped-in table) with `right`, matching `left_join_key` against
/// `right_join_key`.
///
/// Key values are compared by their expanded string form (separator `","`, default
/// `Config`), so two values match when they render identically. Left joins iterate
/// `left`; inner and right joins iterate `right`; an outer join runs a left pass and
/// then a right pass that only adds right rows with no match. In every merged row
/// the columns of `left` come first. Returns the joined rows as a list `Value`.
fn join(
    left: &[Value],
    right: &[Value],
    left_join_key: &str,
    right_join_key: &str,
    join_type: JoinType,
    span: Span,
) -> Value {
    let config = Config::default();
    let sep = ",";
    let cap = max(left.len(), right.len());
    // When both sides join on the same column name, that column is emitted only once.
    let shared_join_key = Some(left_join_key).filter(|&key| key == right_join_key);
    let run_right_pass = matches!(join_type, JoinType::Outer);

    // First pass: choose which table drives the iteration and which one is indexed
    // for lookups. An outer join starts out as a left join.
    let (driver, driver_key, probed, probed_key, pass_type) = match join_type {
        JoinType::Left | JoinType::Outer => {
            (left, left_join_key, right, right_join_key, JoinType::Left)
        }
        JoinType::Inner => (right, right_join_key, left, left_join_key, JoinType::Inner),
        JoinType::Right => (right, right_join_key, left, left_join_key, JoinType::Right),
    };

    let mut rows: Vec<Value> = Vec::new();
    join_rows(
        &mut rows,
        driver,
        driver_key,
        lookup_table(probed, probed_key, sep, cap, &config),
        column_names(probed),
        shared_join_key,
        &pass_type,
        IncludeInner::Yes,
        sep,
        &config,
        span,
    );

    if run_right_pass {
        // Second half of an outer join: add right rows without a left match, without
        // re-emitting the matched pairs produced by the first pass.
        join_rows(
            &mut rows,
            right,
            right_join_key,
            lookup_table(left, left_join_key, sep, cap, &config),
            column_names(left),
            shared_join_key,
            &JoinType::Right,
            IncludeInner::No,
            sep,
            &config,
            span,
        );
    }

    Value::list(rows, span)
}
/// Merge a record from the driving table with a record from the other table so that
/// the user's left (input) table always contributes the leading columns.
///
/// `join` drives `Left` passes from the left table and `Inner`/`Right` passes from
/// the right table, so the argument order to `merge_records` is swapped accordingly.
fn merge_in_column_order(
    join_type: &JoinType,
    this_record: &Record,
    other_record: &Record,
    shared_join_key: Option<&str>,
) -> Record {
    match join_type {
        JoinType::Inner | JoinType::Right => {
            merge_records(other_record, this_record, shared_join_key)
        }
        JoinType::Left => merge_records(this_record, other_record, shared_join_key),
        // `join` lowers an outer join into a `Left` pass followed by a `Right` pass,
        // so this variant never reaches row-level merging.
        JoinType::Outer => unreachable!("outer joins are split into left and right passes"),
    }
}

/// Append to `result` the rows produced by probing `other` with every record of `this`.
///
/// - `this`: the driving table being iterated.
/// - `this_join_key`: the column of `this` whose value is looked up in `other`.
/// - `other`: lookup table over the opposite table (see `lookup_table`), keyed by the
///   stringified join value.
/// - `other_keys`: column names of the opposite table, used to pad unmatched rows.
/// - `shared_join_key`: `Some(key)` when both sides join on the same column name.
///   That column is emitted once and, for unmatched rows, filled from `this`.
/// - `join_type`: must be `Inner`, `Left` or `Right`. It decides column order and
///   whether unmatched rows are kept (`Inner` drops them).
/// - `include_inner`: when `No`, matched rows are skipped.
///
/// Rows of `this` that are not records, or that lack `this_join_key`, produce no output.
// The per-pass state computed in `join` is threaded straight through; bundling it
// into a struct would only move the argument list elsewhere.
#[allow(clippy::too_many_arguments)]
fn join_rows(
    result: &mut Vec<Value>,
    this: &[Value],
    this_join_key: &str,
    other: HashMap<String, Vec<&Record>>,
    other_keys: Vec<&String>,
    shared_join_key: Option<&str>,
    join_type: &JoinType,
    include_inner: IncludeInner,
    sep: &str,
    config: &Config,
    span: Span,
) {
    for this_row in this {
        let this_record = match this_row {
            Value::Record { val, .. } => val,
            _ => continue,
        };
        let this_valkey = match this_record.get(this_join_key) {
            Some(valkey) => valkey,
            None => continue,
        };

        match other.get(&this_valkey.to_expanded_string(sep, config)) {
            Some(other_rows) => {
                // The second pass of an outer join skips matches: the first pass
                // already emitted them.
                if matches!(include_inner, IncludeInner::Yes) {
                    result.extend(other_rows.iter().map(|other_record| {
                        let record = merge_in_column_order(
                            join_type,
                            this_record,
                            other_record,
                            shared_join_key,
                        );
                        Value::record(record, span)
                    }));
                }
            }
            // Inner joins drop rows without a partner.
            None if matches!(join_type, JoinType::Inner) => {}
            None => {
                // Pad the missing side with `nothing`, except for a shared join
                // column, which keeps this row's own value.
                let placeholder: Record = other_keys
                    .iter()
                    .map(|&key| {
                        let val = if Some(key.as_str()) == shared_join_key {
                            this_record
                                .get(key)
                                .cloned()
                                .unwrap_or_else(|| Value::nothing(span))
                        } else {
                            Value::nothing(span)
                        };
                        (key.clone(), val)
                    })
                    .collect();
                let record =
                    merge_in_column_order(join_type, this_record, &placeholder, shared_join_key);
                result.push(Value::record(record, span));
            }
        }
    }
}
/// Return the column names of the first record found in `table`.
///
/// Non-record rows are skipped. When `table` holds no record at all the result is
/// empty. Only that first record is consulted, so columns that appear only in later
/// rows are not listed.
fn column_names(table: &[Value]) -> Vec<&String> {
    for row in table {
        if let Value::Record { val, .. } = row {
            return val.columns().collect();
        }
    }
    Vec::new()
}
/// Index the records of `rows` by the stringified value of their `on` column.
///
/// Each key is the value's expanded string form (joined with `sep`, rendered with
/// `config`), and maps to every record carrying that value, in input order. Rows that
/// are not records, or that lack the `on` column, are left out. `cap` is the initial
/// map capacity.
fn lookup_table<'a>(
    rows: &'a [Value],
    on: &str,
    sep: &str,
    cap: usize,
    config: &Config,
) -> HashMap<String, Vec<&'a Record>> {
    let keyed_records = rows.iter().filter_map(|row| match row {
        Value::Record { val: record, .. } => record
            .get(on)
            .map(|val| (val.to_expanded_string(sep, config), record)),
        _ => None,
    });

    let mut index: HashMap<String, Vec<&'a Record>> = HashMap::with_capacity(cap);
    for (key, record) in keyed_records {
        index.entry(key).or_insert_with(Vec::new).push(record);
    }
    index
}
/// Merge two records: all of `left`'s columns first, then `right`'s.
///
/// A column of `right` whose name already occurs in `left` is renamed with a trailing
/// `_` (e.g. `b` becomes `b_`). The exception is `shared_key`: when it occurs on both
/// sides it is emitted once, with `left`'s value.
///
/// NOTE(review): if `left` already contains both `k` and `k_`, a right-hand `k` is
/// renamed to `k_` again, producing a duplicate column name. Confirm how `Record::push`
/// treats duplicates before relying on this.
fn merge_records(left: &Record, right: &Record, shared_key: Option<&str>) -> Record {
    // The merged record can hold every column of both inputs; reserve that up front
    // rather than `max(len)`, which forced a regrow whenever the columns were disjoint.
    let mut record = Record::with_capacity(left.len() + right.len());
    let mut seen = HashSet::with_capacity(left.len());
    for (k, v) in left {
        record.push(k.clone(), v.clone());
        seen.insert(k);
    }
    for (k, v) in right {
        let k_seen = seen.contains(k);
        // The shared join key's value was already taken from `left`.
        if k_seen && shared_key == Some(k.as_str()) {
            continue;
        }
        let name = if k_seen { format!("{}_", k) } else { k.clone() };
        record.push(name, v.clone());
    }
    record
}
#[cfg(test)]
mod test {
    use super::*;

    /// Runs every example declared in `Join::examples` through the crate's shared
    /// `test_examples` harness.
    #[test]
    fn test_examples() {
        // Imported locally: a module-level `use` would clash with this fn's name.
        use crate::test_examples;

        test_examples(Join)
    }
}