use nu_engine::command_prelude::*;
use nu_protocol::Config;
use std::{
    cmp::max,
    collections::{HashMap, HashSet},
};
/// The `join` command: joins the input (left) table with a right table
/// on a column from each.
#[derive(Clone)]
pub struct Join;
/// Which kind of join to perform, selected by the `--inner`, `--left`,
/// `--right` and `--outer` flags (see `join_type`).
enum JoinType {
    /// Only rows that have a match in both tables (the default).
    Inner,
    /// Left rows with a join-column value; unmatched ones are padded with
    /// nothing-valued columns from the right table.
    Left,
    /// Right rows with a join-column value; unmatched ones are padded with
    /// nothing-valued columns from the left table.
    Right,
    /// A left join followed by the right rows that found no match; `join`
    /// splits this into a `Left` pass and a `Right` pass.
    Outer,
}
/// Whether a `join_rows` pass emits matched (inner) rows. `No` is used for
/// the second pass of an outer join, which should only add unmatched rows.
enum IncludeInner {
    No,
    Yes,
}
impl Command for Join {
    fn name(&self) -> &str {
        "join"
    }
    fn signature(&self) -> Signature {
        Signature::build("join")
            .required(
                "right-table",
                SyntaxShape::List(Box::new(SyntaxShape::Any)),
                "The right table in the join.",
            )
            .required(
                "left-on",
                SyntaxShape::String,
                "Name of column in input (left) table to join on.",
            )
            .optional(
                "right-on",
                SyntaxShape::String,
                "Name of column in right table to join on. Defaults to same column as left table.",
            )
            .switch("inner", "Inner join (default)", Some('i'))
            .switch("left", "Left-outer join", Some('l'))
            .switch("right", "Right-outer join", Some('r'))
            .switch("outer", "Outer join", Some('o'))
            .input_output_types(vec![(Type::table(), Type::table())])
            .category(Category::Filters)
    }
    fn usage(&self) -> &str {
        "Join two tables."
    }
    fn search_terms(&self) -> Vec<&str> {
        vec!["sql"]
    }
    /// Collect the input, read the right table and join columns, and join them.
    ///
    /// # Errors
    ///
    /// Returns `ShellError::UnsupportedInput` when more than one join-type flag is
    /// given, or when the input / right table are not lists or the column names
    /// are not strings. Argument-evaluation errors are propagated.
    fn run(
        &self,
        engine_state: &EngineState,
        stack: &mut Stack,
        call: &Call,
        input: PipelineData,
    ) -> Result<nu_protocol::PipelineData, nu_protocol::ShellError> {
        let metadata = input.metadata();
        let table_2: Value = call.req(engine_state, stack, 0)?;
        let l_on: Value = call.req(engine_state, stack, 1)?;
        // `right-on` is optional and falls back to the left column name.
        let r_on: Value = call
            .opt(engine_state, stack, 2)?
            .unwrap_or_else(|| l_on.clone());
        let span = call.head;
        let join_type = join_type(engine_state, stack, call)?;
        let collected_input = input.into_value(span);
        match (&collected_input, &table_2, &l_on, &r_on) {
            (
                Value::List { vals: rows_1, .. },
                Value::List { vals: rows_2, .. },
                Value::String { val: l_on, .. },
                Value::String { val: r_on, .. },
            ) => {
                let result = join(rows_1, rows_2, l_on, r_on, join_type, span);
                Ok(PipelineData::Value(result, metadata))
            }
            _ => Err(ShellError::UnsupportedInput {
                msg: "(PipelineData<table>, table, string, string)".into(),
                // Report the *types* that were received, mirroring the shape of
                // `msg`, rather than dumping the (possibly huge) input value.
                input: format!(
                    "({:?}, {:?}, {:?}, {:?})",
                    collected_input.get_type(),
                    table_2.get_type(),
                    l_on.get_type(),
                    r_on.get_type(),
                ),
                msg_span: span,
                input_span: span,
            }),
        }
    }
    fn examples(&self) -> Vec<Example> {
        vec![Example {
            description: "Join two tables",
            example: "[{a: 1 b: 2}] | join [{a: 1 c: 3}] a",
            result: Some(Value::test_list(vec![Value::test_record(record! {
                "a" => Value::test_int(1), "b" => Value::test_int(2), "c" => Value::test_int(3),
            })])),
        }]
    }
}
/// Decide which join to run from the command's flags.
///
/// No flags, or `--inner` alone, selects an inner join. Exactly one of
/// `--left`, `--right` or `--outer` (without `--inner`) selects that join.
/// Any other combination is rejected.
fn join_type(
    engine_state: &EngineState,
    stack: &mut Stack,
    call: &Call,
) -> Result<JoinType, nu_protocol::ShellError> {
    let inner = call.has_flag(engine_state, stack, "inner")?;
    let left = call.has_flag(engine_state, stack, "left")?;
    let right = call.has_flag(engine_state, stack, "right")?;
    let outer = call.has_flag(engine_state, stack, "outer")?;

    let selected = match (left, right, outer) {
        (false, false, false) => Some(JoinType::Inner),
        // `--inner` combined with any other join flag is ambiguous.
        _ if inner => None,
        (true, false, false) => Some(JoinType::Left),
        (false, true, false) => Some(JoinType::Right),
        (false, false, true) => Some(JoinType::Outer),
        _ => None,
    };

    selected.ok_or_else(|| ShellError::UnsupportedInput {
        msg: "Choose one of: --inner, --left, --right, --outer".into(),
        input: "".into(),
        msg_span: call.head,
        input_span: call.head,
    })
}
/// Join the `left` and `right` tables on `left_join_key` / `right_join_key`.
///
/// Rows are matched by comparing join-column values rendered with
/// `to_expanded_string` under a default `Config`.
/// - Inner and right joins walk the right table against a lookup built from the left.
/// - Left joins walk the left table against a lookup built from the right.
/// - An outer join is a left join followed by a second pass that appends the
///   right rows with no match on the left.
///
/// When both key names are equal, the join column is not duplicated in the
/// merged rows.
fn join(
    left: &[Value],
    right: &[Value],
    left_join_key: &str,
    right_join_key: &str,
    join_type: JoinType,
    span: Span,
) -> Value {
    let config = Config::default();
    let sep = ",";
    let cap = max(left.len(), right.len());
    let shared_join_key = (left_join_key == right_join_key).then_some(left_join_key);
    let append_unmatched_right = matches!(join_type, JoinType::Outer);
    let mut result: Vec<Value> = Vec::new();

    match join_type {
        // An outer join starts life as a left join; the right-only rows are
        // appended by the extra pass below.
        JoinType::Left | JoinType::Outer => join_rows(
            &mut result,
            left,
            left_join_key,
            lookup_table(right, right_join_key, sep, cap, &config),
            column_names(right),
            shared_join_key,
            &JoinType::Left,
            IncludeInner::Yes,
            sep,
            &config,
            span,
        ),
        JoinType::Inner | JoinType::Right => join_rows(
            &mut result,
            right,
            right_join_key,
            lookup_table(left, left_join_key, sep, cap, &config),
            column_names(left),
            shared_join_key,
            &join_type,
            IncludeInner::Yes,
            sep,
            &config,
            span,
        ),
    }

    if append_unmatched_right {
        join_rows(
            &mut result,
            right,
            right_join_key,
            lookup_table(left, left_join_key, sep, cap, &config),
            column_names(left),
            shared_join_key,
            &JoinType::Right,
            IncludeInner::No,
            sep,
            &config,
            span,
        );
    }

    Value::list(result, span)
}
/// Append joined rows for each record of `this` to `result`.
///
/// Each record in `this` is looked up in `other` (built by `lookup_table`)
/// by the stringified value of its `this_join_key` column. Non-record rows
/// and records without that column are skipped.
///
/// - Matched rows produce one merged row per partner in `other`, but only
///   when `include_inner` is `IncludeInner::Yes`.
/// - Unmatched rows produce a single merged row for every join type except
///   `Inner`. The missing side is padded with `other_keys` columns set to
///   nothing; the shared join key, if any, is copied from the `this` row.
///
/// `join_type` fixes column order. For `Inner`/`Right`, `this` is the right
/// table, so the `other` (left) columns come first. For `Left`, `this` is the
/// left table.
///
/// # Panics
///
/// Panics if `join_type` is `JoinType::Outer` and a row has to be merged;
/// `join` always splits outer joins into a `Left` and a `Right` pass.
// The join parameters are threaded through unchanged from `join`; bundling
// them into a struct would add ceremony without clarifying anything.
#[allow(clippy::too_many_arguments)]
fn join_rows(
    result: &mut Vec<Value>,
    this: &[Value],
    this_join_key: &str,
    other: HashMap<String, Vec<&Record>>,
    other_keys: Vec<&String>,
    shared_join_key: Option<&str>,
    join_type: &JoinType,
    include_inner: IncludeInner,
    sep: &str,
    config: &Config,
    span: Span,
) {
    for this_row in this {
        if let Value::Record {
            val: this_record, ..
        } = this_row
        {
            if let Some(this_valkey) = this_record.get(this_join_key) {
                if let Some(other_rows) = other.get(&this_valkey.to_expanded_string(sep, config)) {
                    if matches!(include_inner, IncludeInner::Yes) {
                        for other_record in other_rows {
                            let record = match join_type {
                                // `other` (the lookup) was built from the left table.
                                JoinType::Inner | JoinType::Right => {
                                    merge_records(other_record, this_record, shared_join_key)
                                }
                                // `this` is the left table.
                                JoinType::Left => {
                                    merge_records(this_record, other_record, shared_join_key)
                                }
                                JoinType::Outer => unreachable!(
                                    "outer joins are split into left and right passes by `join`"
                                ),
                            };
                            result.push(Value::record(record, span))
                        }
                    }
                } else if !matches!(join_type, JoinType::Inner) {
                    // No partner in `other`: emit one row whose `other` columns
                    // are nothing, except the shared join key, which keeps the
                    // value from this row.
                    let other_record = other_keys
                        .iter()
                        .map(|&key| {
                            let val = if Some(key.as_ref()) == shared_join_key {
                                this_record
                                    .get(key)
                                    .cloned()
                                    .unwrap_or_else(|| Value::nothing(span))
                            } else {
                                Value::nothing(span)
                            };
                            (key.clone(), val)
                        })
                        .collect();
                    let record = match join_type {
                        JoinType::Inner | JoinType::Right => {
                            merge_records(&other_record, this_record, shared_join_key)
                        }
                        JoinType::Left => {
                            merge_records(this_record, &other_record, shared_join_key)
                        }
                        JoinType::Outer => unreachable!(
                            "outer joins are split into left and right passes by `join`"
                        ),
                    };
                    result.push(Value::record(record, span))
                }
            }
            // else: the row has no value in the join column and cannot be matched.
        }
    }
}
/// Column names of the first record row in `table`.
///
/// Non-record rows are ignored. Returns an empty list when `table` contains
/// no record at all.
fn column_names(table: &[Value]) -> Vec<&String> {
    for row in table {
        if let Value::Record { val, .. } = row {
            return val.columns().collect();
        }
    }
    Vec::new()
}
/// Index the record rows of `rows` by the stringified value of their `on` column.
///
/// Keys are produced with `to_expanded_string(sep, config)`. Rows that are
/// not records, or that lack the `on` column, are left out. Records sharing a
/// key are kept in input order. `cap` is only the initial capacity of the map.
fn lookup_table<'a>(
    rows: &'a [Value],
    on: &str,
    sep: &str,
    cap: usize,
    config: &Config,
) -> HashMap<String, Vec<&'a Record>> {
    let mut buckets: HashMap<String, Vec<&'a Record>> = HashMap::with_capacity(cap);
    let keyed_rows = rows.iter().filter_map(|row| match row {
        Value::Record { val: record, .. } => record
            .get(on)
            .map(|key| (key.to_expanded_string(sep, config), record)),
        _ => None,
    });
    for (key, record) in keyed_rows {
        buckets.entry(key).or_default().push(record);
    }
    buckets
}
/// Merge two records: all columns of `left` first, then the columns of `right`.
///
/// A right-hand column whose name already appears on the left is renamed
/// with a trailing `_`. The exception is the `shared_key` join column: it is
/// emitted only once, and the left copy is kept.
fn merge_records(left: &Record, right: &Record, shared_key: Option<&str>) -> Record {
    // `seen` only ever holds left-hand names. The merged record can hold every
    // column of both sides, so size it for the sum rather than the larger side.
    // Sizing for the larger side forced a reallocation on nearly every merge.
    let mut seen = HashSet::with_capacity(left.len());
    let mut record = Record::with_capacity(left.len() + right.len());
    for (k, v) in left {
        record.push(k.clone(), v.clone());
        seen.insert(k);
    }
    for (k, v) in right {
        let k_seen = seen.contains(k);
        let k_shared = shared_key == Some(k.as_str());
        if k_seen && k_shared {
            // The join column was already taken from the left record.
            continue;
        }
        let name = if k_seen { format!("{k}_") } else { k.clone() };
        record.push(name, v.clone());
    }
    record
}
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `Join`'s declared examples via the crate-level `test_examples` helper.
    #[test]
    fn test_examples() {
        // Called by path: a module-level `use` would clash with this fn's name.
        crate::test_examples(Join)
    }
}