csv_managed/
join.rs

1use std::{collections::HashMap, path::PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use encoding_rs::Encoding;
5use log::info;
6
7use crate::{
8    data::parse_typed_value,
9    io_utils,
10    schema::{self, ColumnType, Schema},
11};
12
/// Selects which rows a join emits besides the matched left/right pairs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JoinKind {
    /// Emit only pairs whose keys match. This is the default kind.
    #[default]
    Inner,
    /// Matched pairs, plus every left row that matched nothing, padded with empty right columns.
    Left,
    /// Matched pairs, plus every right row that matched nothing; its key values are copied
    /// into the left key columns and the remaining left columns are left empty.
    Right,
    /// Matched pairs plus the unmatched rows of both sides (`Left` and `Right` combined).
    Full,
}
21
22#[derive(Debug, Clone)]
23pub struct JoinArgs {
24    pub left: PathBuf,
25    pub right: PathBuf,
26    pub output: Option<PathBuf>,
27    pub left_key: String,
28    pub right_key: String,
29    pub kind: JoinKind,
30    pub left_schema: Option<PathBuf>,
31    pub right_schema: Option<PathBuf>,
32    pub delimiter: Option<u8>,
33    pub left_encoding: Option<String>,
34    pub right_encoding: Option<String>,
35    pub output_encoding: Option<String>,
36}
37
/// Delimiter placed between the per-column parts of a composite join key.
///
/// U+001F is the ASCII "unit separator" control character, presumably chosen because it
/// rarely occurs in ordinary CSV text.
const KEY_SEPARATOR: &str = "\x1f";
39
40pub fn execute(args: &JoinArgs) -> Result<()> {
41    if args.left_key.is_empty() || args.right_key.is_empty() {
42        return Err(anyhow!("Join requires --left-key and --right-key"));
43    }
44    if io_utils::is_dash(&args.right) {
45        return Err(anyhow!(
46            "Right input cannot be stdin for join operations; provide a file path"
47        ));
48    }
49    if io_utils::is_dash(&args.left) && args.left_schema.is_none() {
50        return Err(anyhow!(
51            "Joining from stdin requires --left-schema (or --left-meta) to describe the schema"
52        ));
53    }
54
55    let left_keys = parse_key_list(&args.left_key)?;
56    let right_keys = parse_key_list(&args.right_key)?;
57    if left_keys.len() != right_keys.len() {
58        return Err(anyhow!(
59            "Left and right join keys must contain the same number of columns"
60        ));
61    }
62
63    let left_delimiter = io_utils::resolve_input_delimiter(&args.left, args.delimiter);
64    let right_delimiter = io_utils::resolve_input_delimiter(&args.right, args.delimiter);
65    let output_delimiter =
66        io_utils::resolve_output_delimiter(args.output.as_deref(), None, left_delimiter);
67    let left_encoding = io_utils::resolve_encoding(args.left_encoding.as_deref())?;
68    let right_encoding = io_utils::resolve_encoding(args.right_encoding.as_deref())?;
69    let output_encoding = io_utils::resolve_encoding(args.output_encoding.as_deref())?;
70
71    let left_schema = load_schema(
72        &args.left,
73        args.left_schema.as_ref(),
74        left_delimiter,
75        left_encoding,
76    )?;
77    let right_schema = load_schema(
78        &args.right,
79        args.right_schema.as_ref(),
80        right_delimiter,
81        right_encoding,
82    )?;
83
84    let left_indices = column_indices(&left_schema, &left_keys)?;
85    let right_indices = column_indices(&right_schema, &right_keys)?;
86    validate_key_types(&left_schema, &right_schema, &left_indices, &right_indices)?;
87
88    let left_expects_headers = left_schema.expects_headers();
89    let right_expects_headers = right_schema.expects_headers();
90
91    let mut left_reader =
92        io_utils::open_csv_reader_from_path(&args.left, left_delimiter, left_expects_headers)?;
93    let mut right_reader =
94        io_utils::open_csv_reader_from_path(&args.right, right_delimiter, right_expects_headers)?;
95
96    let left_headers = if left_expects_headers {
97        let headers = io_utils::reader_headers(&mut left_reader, left_encoding)?;
98        left_schema
99            .validate_headers(&headers)
100            .with_context(|| format!("Validating left headers for {:?}", args.left))?;
101        headers
102    } else {
103        left_schema.headers()
104    };
105
106    let right_headers = if right_expects_headers {
107        let headers = io_utils::reader_headers(&mut right_reader, right_encoding)?;
108        right_schema
109            .validate_headers(&headers)
110            .with_context(|| format!("Validating right headers for {:?}", args.right))?;
111        headers
112    } else {
113        right_schema.headers()
114    };
115
116    let mut right_lookup = build_right_lookup(
117        &mut right_reader,
118        &right_schema,
119        &right_indices,
120        right_encoding,
121    )?;
122
123    let (output_headers, right_columns) =
124        build_output_headers(&left_headers, &right_headers, &right_indices);
125
126    let mut writer =
127        io_utils::open_csv_writer(args.output.as_deref(), output_delimiter, output_encoding)?;
128    writer
129        .write_record(&output_headers)
130        .context("Writing joined headers")?;
131
132    let mut output_rows = 0usize;
133    let mut matched_rows = 0usize;
134    let include_unmatched_left = matches!(args.kind, JoinKind::Left | JoinKind::Full);
135    let include_unmatched_right = matches!(args.kind, JoinKind::Right | JoinKind::Full);
136    let key_pairs: Vec<(usize, usize)> = left_indices
137        .iter()
138        .cloned()
139        .zip(right_indices.iter().cloned())
140        .collect();
141
142    for (row_idx, record) in left_reader.byte_records().enumerate() {
143        let record = record.with_context(|| format!("Reading left row {}", row_idx + 2))?;
144        let mut decoded = io_utils::decode_record(&record, left_encoding)?;
145        if left_schema.has_transformations() {
146            left_schema
147                .apply_transformations_to_row(&mut decoded)
148                .with_context(|| {
149                    format!(
150                        "Applying datatype mappings to left row {} in {:?}",
151                        row_idx + 2,
152                        args.left
153                    )
154                })?;
155        }
156        left_schema.apply_replacements_to_row(&mut decoded);
157        let key = build_key(&decoded, &left_schema, &left_indices)?;
158        let mut matched_any = false;
159        if let Some(bucket) = right_lookup.get_mut(&key) {
160            for entry in bucket.iter_mut() {
161                matched_any = true;
162                entry.matched = true;
163                matched_rows += 1;
164                let mut combined = decoded.clone();
165                combined.extend(
166                    right_columns
167                        .iter()
168                        .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
169                );
170                writer
171                    .write_record(&combined)
172                    .context("Writing joined row")?;
173                output_rows += 1;
174            }
175        }
176
177        if !matched_any && include_unmatched_left {
178            let mut combined = decoded.clone();
179            combined.extend(right_columns.iter().map(|_| String::new()));
180            writer
181                .write_record(&combined)
182                .context("Writing left outer row")?;
183            output_rows += 1;
184        }
185    }
186
187    if include_unmatched_right {
188        for bucket in right_lookup.values() {
189            for entry in bucket.iter() {
190                if entry.matched {
191                    continue;
192                }
193                let mut left_part = vec![String::new(); left_headers.len()];
194                for (left_idx, right_idx) in &key_pairs {
195                    let value = entry.record.get(*right_idx).cloned().unwrap_or_default();
196                    left_part[*left_idx] = value;
197                }
198                let mut combined = left_part;
199                combined.extend(
200                    right_columns
201                        .iter()
202                        .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
203                );
204                writer
205                    .write_record(&combined)
206                    .context("Writing right outer row")?;
207                output_rows += 1;
208            }
209        }
210    }
211
212    writer.flush().context("Flushing join output")?;
213    info!("Join complete: {output_rows} output row(s), {matched_rows} matched row(s)");
214    Ok(())
215}
216
217fn parse_key_list(value: &str) -> Result<Vec<String>> {
218    let parts = value
219        .split(',')
220        .map(|s| s.trim())
221        .filter(|s| !s.is_empty())
222        .map(|s| s.to_string())
223        .collect::<Vec<_>>();
224    if parts.is_empty() {
225        Err(anyhow!("Join key list cannot be empty"))
226    } else {
227        Ok(parts)
228    }
229}
230
231fn load_schema(
232    path: &PathBuf,
233    schema_path: Option<&PathBuf>,
234    delimiter: u8,
235    encoding: &'static Encoding,
236) -> Result<Schema> {
237    if let Some(schema_path) = schema_path {
238        Schema::load(schema_path).with_context(|| format!("Loading schema from {schema_path:?}"))
239    } else {
240        schema::infer_schema(path, 0, delimiter, encoding, None)
241            .with_context(|| format!("Inferring schema from {path:?}"))
242    }
243}
244
245fn column_indices(schema: &Schema, columns: &[String]) -> Result<Vec<usize>> {
246    columns
247        .iter()
248        .map(|name| {
249            schema
250                .column_index(name)
251                .ok_or_else(|| anyhow!("Column '{name}' not found in schema"))
252        })
253        .collect()
254}
255
256fn validate_key_types(
257    left_schema: &Schema,
258    right_schema: &Schema,
259    left_indices: &[usize],
260    right_indices: &[usize],
261) -> Result<()> {
262    for (l_idx, r_idx) in left_indices.iter().zip(right_indices.iter()) {
263        let left_type = &left_schema.columns[*l_idx].datatype;
264        let right_type = &right_schema.columns[*r_idx].datatype;
265        if !same_type(left_type, right_type) {
266            return Err(anyhow!(
267                "Type mismatch for join keys: left {left_type:?} vs right {right_type:?}"
268            ));
269        }
270    }
271    Ok(())
272}
273
274fn same_type(left: &ColumnType, right: &ColumnType) -> bool {
275    match (left, right) {
276        (ColumnType::Integer, ColumnType::Float) | (ColumnType::Float, ColumnType::Integer) => true,
277        _ => left == right,
278    }
279}
280
/// One decoded right-hand row held in the join lookup.
struct RightRow {
    /// Decoded cells, after the schema's transformations and replacements.
    record: Vec<String>,
    /// Set once any left row has matched this row. Rows still `false` after the left pass
    /// are the ones emitted by right/full joins.
    matched: bool,
}
285
286fn build_right_lookup(
287    reader: &mut csv::Reader<Box<dyn std::io::Read>>,
288    schema: &Schema,
289    key_indices: &[usize],
290    encoding: &'static Encoding,
291) -> Result<HashMap<String, Vec<RightRow>>> {
292    let mut map: HashMap<String, Vec<RightRow>> = HashMap::new();
293    for (row_idx, record) in reader.byte_records().enumerate() {
294        let record = record.with_context(|| format!("Reading right row {}", row_idx + 2))?;
295        let mut decoded = io_utils::decode_record(&record, encoding)?;
296        if schema.has_transformations() {
297            schema
298                .apply_transformations_to_row(&mut decoded)
299                .with_context(|| {
300                    format!("Applying datatype mappings to right row {}", row_idx + 2)
301                })?;
302        }
303        schema.apply_replacements_to_row(&mut decoded);
304        let key = build_key(&decoded, schema, key_indices)?;
305        map.entry(key).or_default().push(RightRow {
306            record: decoded,
307            matched: false,
308        });
309    }
310    Ok(map)
311}
312
313fn build_key(record: &[String], schema: &Schema, key_indices: &[usize]) -> Result<String> {
314    let mut parts = Vec::with_capacity(key_indices.len());
315    for idx in key_indices {
316        let column = &schema.columns[*idx];
317        let raw = record.get(*idx).map(|s| s.as_str()).unwrap_or("");
318        let normalized = column.normalize_value(raw);
319        let parsed = parse_typed_value(normalized.as_ref(), &column.datatype)
320            .with_context(|| format!("Parsing join key for column '{}'", column.name))?;
321        if let Some(value) = parsed {
322            parts.push(value.as_display());
323        } else {
324            parts.push(String::new());
325        }
326    }
327    Ok(parts.join(KEY_SEPARATOR))
328}
329
/// Computes the joined header row and which right columns feed it.
///
/// The output starts with every left header. Each right header whose index is not in
/// `right_key_indices` is then appended. If a name is already taken, it is replaced by the
/// first free `right_<name>_<n>`, with `n` counting up from 1.
///
/// Returns `(headers, right_columns)`, where `right_columns` lists, in output order, the
/// right-row indices that fill the appended headers.
fn build_output_headers(
    left_headers: &[String],
    right_headers: &[String],
    right_key_indices: &[usize],
) -> (Vec<String>, Vec<usize>) {
    use std::collections::HashSet;

    let mut taken: HashSet<String> = left_headers.iter().cloned().collect();
    let mut merged: Vec<String> = left_headers.to_vec();
    let mut kept_columns: Vec<usize> = Vec::new();

    let non_key_columns = right_headers
        .iter()
        .enumerate()
        .filter(|(position, _)| !right_key_indices.contains(position));

    for (position, header) in non_key_columns {
        let unique = if taken.contains(header) {
            let mut suffix = 1usize;
            let mut attempt = format!("right_{header}_{suffix}");
            while taken.contains(&attempt) {
                suffix += 1;
                attempt = format!("right_{header}_{suffix}");
            }
            attempt
        } else {
            header.clone()
        };
        taken.insert(unique.clone());
        merged.push(unique);
        kept_columns.push(position);
    }

    (merged, kept_columns)
}