csv_managed/
join.rs

1use std::{collections::HashMap, path::PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use encoding_rs::Encoding;
5use log::info;
6
7use crate::{
8    cli::{JoinArgs, JoinKind},
9    data::parse_typed_value,
10    io_utils,
11    schema::{self, ColumnType, Schema},
12};
13
/// Separator placed between the per-column parts of a composite join key.
///
/// This is the control character U+001F; `build_key` joins the normalized key
/// parts with it to form the lookup string.
const KEY_SEPARATOR: &str = "\u{1f}";
15
16pub fn execute(args: &JoinArgs) -> Result<()> {
17    if args.left_key.is_empty() || args.right_key.is_empty() {
18        return Err(anyhow!("Join requires --left-key and --right-key"));
19    }
20    if io_utils::is_dash(&args.right) {
21        return Err(anyhow!(
22            "Right input cannot be stdin for join operations; provide a file path"
23        ));
24    }
25    if io_utils::is_dash(&args.left) && args.left_schema.is_none() {
26        return Err(anyhow!(
27            "Joining from stdin requires --left-schema (or --left-meta) to describe the schema"
28        ));
29    }
30
31    let left_keys = parse_key_list(&args.left_key)?;
32    let right_keys = parse_key_list(&args.right_key)?;
33    if left_keys.len() != right_keys.len() {
34        return Err(anyhow!(
35            "Left and right join keys must contain the same number of columns"
36        ));
37    }
38
39    let left_delimiter = io_utils::resolve_input_delimiter(&args.left, args.delimiter);
40    let right_delimiter = io_utils::resolve_input_delimiter(&args.right, args.delimiter);
41    let output_delimiter =
42        io_utils::resolve_output_delimiter(args.output.as_deref(), None, left_delimiter);
43    let left_encoding = io_utils::resolve_encoding(args.left_encoding.as_deref())?;
44    let right_encoding = io_utils::resolve_encoding(args.right_encoding.as_deref())?;
45    let output_encoding = io_utils::resolve_encoding(args.output_encoding.as_deref())?;
46
47    let left_schema = load_schema(
48        &args.left,
49        args.left_schema.as_ref(),
50        left_delimiter,
51        left_encoding,
52    )?;
53    let right_schema = load_schema(
54        &args.right,
55        args.right_schema.as_ref(),
56        right_delimiter,
57        right_encoding,
58    )?;
59
60    let left_indices = column_indices(&left_schema, &left_keys)?;
61    let right_indices = column_indices(&right_schema, &right_keys)?;
62    validate_key_types(&left_schema, &right_schema, &left_indices, &right_indices)?;
63
64    let mut left_reader = io_utils::open_csv_reader_from_path(&args.left, left_delimiter, true)?;
65    let mut right_reader = io_utils::open_csv_reader_from_path(&args.right, right_delimiter, true)?;
66
67    let left_headers = io_utils::reader_headers(&mut left_reader, left_encoding)?;
68    let right_headers = io_utils::reader_headers(&mut right_reader, right_encoding)?;
69    left_schema
70        .validate_headers(&left_headers)
71        .with_context(|| format!("Validating left headers for {:?}", args.left))?;
72    right_schema
73        .validate_headers(&right_headers)
74        .with_context(|| format!("Validating right headers for {:?}", args.right))?;
75
76    let mut right_lookup = build_right_lookup(
77        &mut right_reader,
78        &right_schema,
79        &right_indices,
80        right_encoding,
81    )?;
82
83    let (output_headers, right_columns) =
84        build_output_headers(&left_headers, &right_headers, &right_indices);
85
86    let mut writer =
87        io_utils::open_csv_writer(args.output.as_deref(), output_delimiter, output_encoding)?;
88    writer
89        .write_record(&output_headers)
90        .context("Writing joined headers")?;
91
92    let mut output_rows = 0usize;
93    let mut matched_rows = 0usize;
94    let include_unmatched_left = matches!(args.kind, JoinKind::Left | JoinKind::Full);
95    let include_unmatched_right = matches!(args.kind, JoinKind::Right | JoinKind::Full);
96    let key_pairs: Vec<(usize, usize)> = left_indices
97        .iter()
98        .cloned()
99        .zip(right_indices.iter().cloned())
100        .collect();
101
102    for (row_idx, record) in left_reader.byte_records().enumerate() {
103        let record = record.with_context(|| format!("Reading left row {}", row_idx + 2))?;
104        let mut decoded = io_utils::decode_record(&record, left_encoding)?;
105        left_schema.apply_replacements_to_row(&mut decoded);
106        let key = build_key(&decoded, &left_schema, &left_indices)?;
107        let mut matched_any = false;
108        if let Some(bucket) = right_lookup.get_mut(&key) {
109            for entry in bucket.iter_mut() {
110                matched_any = true;
111                entry.matched = true;
112                matched_rows += 1;
113                let mut combined = decoded.clone();
114                combined.extend(
115                    right_columns
116                        .iter()
117                        .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
118                );
119                writer
120                    .write_record(&combined)
121                    .context("Writing joined row")?;
122                output_rows += 1;
123            }
124        }
125
126        if !matched_any && include_unmatched_left {
127            let mut combined = decoded.clone();
128            combined.extend(right_columns.iter().map(|_| String::new()));
129            writer
130                .write_record(&combined)
131                .context("Writing left outer row")?;
132            output_rows += 1;
133        }
134    }
135
136    if include_unmatched_right {
137        for bucket in right_lookup.values() {
138            for entry in bucket.iter() {
139                if entry.matched {
140                    continue;
141                }
142                let mut left_part = vec![String::new(); left_headers.len()];
143                for (left_idx, right_idx) in &key_pairs {
144                    let value = entry.record.get(*right_idx).cloned().unwrap_or_default();
145                    left_part[*left_idx] = value;
146                }
147                let mut combined = left_part;
148                combined.extend(
149                    right_columns
150                        .iter()
151                        .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
152                );
153                writer
154                    .write_record(&combined)
155                    .context("Writing right outer row")?;
156                output_rows += 1;
157            }
158        }
159    }
160
161    writer.flush().context("Flushing join output")?;
162    info!("Join complete: {output_rows} output row(s), {matched_rows} matched row(s)");
163    Ok(())
164}
165
166fn parse_key_list(value: &str) -> Result<Vec<String>> {
167    let parts = value
168        .split(',')
169        .map(|s| s.trim())
170        .filter(|s| !s.is_empty())
171        .map(|s| s.to_string())
172        .collect::<Vec<_>>();
173    if parts.is_empty() {
174        Err(anyhow!("Join key list cannot be empty"))
175    } else {
176        Ok(parts)
177    }
178}
179
180fn load_schema(
181    path: &PathBuf,
182    schema_path: Option<&PathBuf>,
183    delimiter: u8,
184    encoding: &'static Encoding,
185) -> Result<Schema> {
186    if let Some(schema_path) = schema_path {
187        Schema::load(schema_path).with_context(|| format!("Loading schema from {schema_path:?}"))
188    } else {
189        schema::infer_schema(path, 0, delimiter, encoding)
190            .with_context(|| format!("Inferring schema from {path:?}"))
191    }
192}
193
194fn column_indices(schema: &Schema, columns: &[String]) -> Result<Vec<usize>> {
195    columns
196        .iter()
197        .map(|name| {
198            schema
199                .column_index(name)
200                .ok_or_else(|| anyhow!("Column '{name}' not found in schema"))
201        })
202        .collect()
203}
204
205fn validate_key_types(
206    left_schema: &Schema,
207    right_schema: &Schema,
208    left_indices: &[usize],
209    right_indices: &[usize],
210) -> Result<()> {
211    for (l_idx, r_idx) in left_indices.iter().zip(right_indices.iter()) {
212        let left_type = &left_schema.columns[*l_idx].datatype;
213        let right_type = &right_schema.columns[*r_idx].datatype;
214        if !same_type(left_type, right_type) {
215            return Err(anyhow!(
216                "Type mismatch for join keys: left {left_type:?} vs right {right_type:?}"
217            ));
218        }
219    }
220    Ok(())
221}
222
223fn same_type(left: &ColumnType, right: &ColumnType) -> bool {
224    match (left, right) {
225        (ColumnType::Integer, ColumnType::Float) | (ColumnType::Float, ColumnType::Integer) => true,
226        _ => left == right,
227    }
228}
229
/// One row of the right-hand join input, held in the in-memory lookup.
struct RightRow {
    /// Decoded field values of the row, in the right input's column order.
    record: Vec<String>,
    /// Set to `true` once a left row has been joined with this row; rows that
    /// stay `false` are the ones emitted by the right-outer pass.
    matched: bool,
}
234
235fn build_right_lookup(
236    reader: &mut csv::Reader<Box<dyn std::io::Read>>,
237    schema: &Schema,
238    key_indices: &[usize],
239    encoding: &'static Encoding,
240) -> Result<HashMap<String, Vec<RightRow>>> {
241    let mut map: HashMap<String, Vec<RightRow>> = HashMap::new();
242    for (row_idx, record) in reader.byte_records().enumerate() {
243        let record = record.with_context(|| format!("Reading right row {}", row_idx + 2))?;
244        let mut decoded = io_utils::decode_record(&record, encoding)?;
245        schema.apply_replacements_to_row(&mut decoded);
246        let key = build_key(&decoded, schema, key_indices)?;
247        map.entry(key).or_default().push(RightRow {
248            record: decoded,
249            matched: false,
250        });
251    }
252    Ok(map)
253}
254
255fn build_key(record: &[String], schema: &Schema, key_indices: &[usize]) -> Result<String> {
256    let mut parts = Vec::with_capacity(key_indices.len());
257    for idx in key_indices {
258        let column = &schema.columns[*idx];
259        let raw = record.get(*idx).map(|s| s.as_str()).unwrap_or("");
260        let normalized = column.normalize_value(raw);
261        let parsed = parse_typed_value(normalized.as_ref(), &column.datatype)
262            .with_context(|| format!("Parsing join key for column '{}'", column.name))?;
263        if let Some(value) = parsed {
264            parts.push(value.as_display());
265        } else {
266            parts.push(String::new());
267        }
268    }
269    Ok(parts.join(KEY_SEPARATOR))
270}
271
/// Computes the header row of the joined output.
///
/// The output starts with all left headers, followed by every right header
/// whose position is not listed in `right_key_indices`. A right header that
/// clashes with a name already in the output is renamed to
/// `right_<name>_<n>`, using the smallest `n >= 1` that is still free.
///
/// Returns the output headers together with the positions of the right-hand
/// columns that were kept, in output order.
fn build_output_headers(
    left_headers: &[String],
    right_headers: &[String],
    right_key_indices: &[usize],
) -> (Vec<String>, Vec<usize>) {
    use std::collections::HashSet;

    let mut output = Vec::with_capacity(left_headers.len() + right_headers.len());
    output.extend_from_slice(left_headers);
    let mut taken: HashSet<String> = left_headers.iter().cloned().collect();
    let mut kept_positions = Vec::new();

    let non_key_columns = right_headers
        .iter()
        .enumerate()
        .filter(|(position, _)| !right_key_indices.contains(position));

    for (position, base) in non_key_columns {
        let mut unique = base.clone();
        let mut counter = 1usize;
        // Keep trying numbered variants until the name is unused.
        while taken.contains(&unique) {
            unique = format!("right_{base}_{counter}");
            counter += 1;
        }
        taken.insert(unique.clone());
        output.push(unique);
        kept_positions.push(position);
    }

    (output, kept_positions)
}