1use std::{collections::HashMap, path::PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use encoding_rs::Encoding;
5use log::info;
6
7use crate::{
8 data::parse_typed_value,
9 io_utils,
10 schema::{self, ColumnType, Schema},
11};
12
/// Selects which unmatched rows a join keeps in addition to the matched ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JoinKind {
    /// Emit only rows whose key is present on both sides (the default).
    #[default]
    Inner,
    /// Also emit left rows that found no match on the right.
    Left,
    /// Also emit right rows that were never matched by a left row.
    Right,
    /// Also emit unmatched rows from both sides.
    Full,
}
21
/// Arguments consumed by [`execute`] to join two delimited files.
#[derive(Debug, Clone)]
pub struct JoinArgs {
    /// Left input. May be the stdin "dash" path, in which case `left_schema` is required.
    pub left: PathBuf,
    /// Right input. Must be a real file path; the stdin "dash" path is rejected.
    pub right: PathBuf,
    /// Destination handed to the CSV writer.
    // NOTE(review): `None` presumably means stdout - confirm against `io_utils::open_csv_writer`.
    pub output: Option<PathBuf>,
    /// Comma-separated list of key column names on the left side.
    pub left_key: String,
    /// Comma-separated list of key column names on the right side; must name as many
    /// columns as `left_key`.
    pub right_key: String,
    /// Which unmatched rows to keep.
    pub kind: JoinKind,
    /// Schema file for the left input; when absent the schema is inferred from the input.
    pub left_schema: Option<PathBuf>,
    /// Schema file for the right input; when absent the schema is inferred from the input.
    pub right_schema: Option<PathBuf>,
    /// Delimiter override passed to delimiter resolution for both inputs.
    pub delimiter: Option<u8>,
    /// Encoding label for the left input, resolved through `io_utils::resolve_encoding`.
    pub left_encoding: Option<String>,
    /// Encoding label for the right input, resolved through `io_utils::resolve_encoding`.
    pub right_encoding: Option<String>,
    /// Encoding label for the output, resolved through `io_utils::resolve_encoding`.
    pub output_encoding: Option<String>,
}
37
/// Separator placed between the per-column parts of a composite join key.
///
/// U+001F is the ASCII "unit separator" control character.
const KEY_SEPARATOR: &str = "\u{1f}";
39
40pub fn execute(args: &JoinArgs) -> Result<()> {
41 if args.left_key.is_empty() || args.right_key.is_empty() {
42 return Err(anyhow!("Join requires --left-key and --right-key"));
43 }
44 if io_utils::is_dash(&args.right) {
45 return Err(anyhow!(
46 "Right input cannot be stdin for join operations; provide a file path"
47 ));
48 }
49 if io_utils::is_dash(&args.left) && args.left_schema.is_none() {
50 return Err(anyhow!(
51 "Joining from stdin requires --left-schema (or --left-meta) to describe the schema"
52 ));
53 }
54
55 let left_keys = parse_key_list(&args.left_key)?;
56 let right_keys = parse_key_list(&args.right_key)?;
57 if left_keys.len() != right_keys.len() {
58 return Err(anyhow!(
59 "Left and right join keys must contain the same number of columns"
60 ));
61 }
62
63 let left_delimiter = io_utils::resolve_input_delimiter(&args.left, args.delimiter);
64 let right_delimiter = io_utils::resolve_input_delimiter(&args.right, args.delimiter);
65 let output_delimiter =
66 io_utils::resolve_output_delimiter(args.output.as_deref(), None, left_delimiter);
67 let left_encoding = io_utils::resolve_encoding(args.left_encoding.as_deref())?;
68 let right_encoding = io_utils::resolve_encoding(args.right_encoding.as_deref())?;
69 let output_encoding = io_utils::resolve_encoding(args.output_encoding.as_deref())?;
70
71 let left_schema = load_schema(
72 &args.left,
73 args.left_schema.as_ref(),
74 left_delimiter,
75 left_encoding,
76 )?;
77 let right_schema = load_schema(
78 &args.right,
79 args.right_schema.as_ref(),
80 right_delimiter,
81 right_encoding,
82 )?;
83
84 let left_indices = column_indices(&left_schema, &left_keys)?;
85 let right_indices = column_indices(&right_schema, &right_keys)?;
86 validate_key_types(&left_schema, &right_schema, &left_indices, &right_indices)?;
87
88 let left_expects_headers = left_schema.expects_headers();
89 let right_expects_headers = right_schema.expects_headers();
90
91 let mut left_reader =
92 io_utils::open_csv_reader_from_path(&args.left, left_delimiter, left_expects_headers)?;
93 let mut right_reader =
94 io_utils::open_csv_reader_from_path(&args.right, right_delimiter, right_expects_headers)?;
95
96 let left_headers = if left_expects_headers {
97 let headers = io_utils::reader_headers(&mut left_reader, left_encoding)?;
98 left_schema
99 .validate_headers(&headers)
100 .with_context(|| format!("Validating left headers for {:?}", args.left))?;
101 headers
102 } else {
103 left_schema.headers()
104 };
105
106 let right_headers = if right_expects_headers {
107 let headers = io_utils::reader_headers(&mut right_reader, right_encoding)?;
108 right_schema
109 .validate_headers(&headers)
110 .with_context(|| format!("Validating right headers for {:?}", args.right))?;
111 headers
112 } else {
113 right_schema.headers()
114 };
115
116 let mut right_lookup = build_right_lookup(
117 &mut right_reader,
118 &right_schema,
119 &right_indices,
120 right_encoding,
121 )?;
122
123 let (output_headers, right_columns) =
124 build_output_headers(&left_headers, &right_headers, &right_indices);
125
126 let mut writer =
127 io_utils::open_csv_writer(args.output.as_deref(), output_delimiter, output_encoding)?;
128 writer
129 .write_record(&output_headers)
130 .context("Writing joined headers")?;
131
132 let mut output_rows = 0usize;
133 let mut matched_rows = 0usize;
134 let include_unmatched_left = matches!(args.kind, JoinKind::Left | JoinKind::Full);
135 let include_unmatched_right = matches!(args.kind, JoinKind::Right | JoinKind::Full);
136 let key_pairs: Vec<(usize, usize)> = left_indices
137 .iter()
138 .cloned()
139 .zip(right_indices.iter().cloned())
140 .collect();
141
142 for (row_idx, record) in left_reader.byte_records().enumerate() {
143 let record = record.with_context(|| format!("Reading left row {}", row_idx + 2))?;
144 let mut decoded = io_utils::decode_record(&record, left_encoding)?;
145 if left_schema.has_transformations() {
146 left_schema
147 .apply_transformations_to_row(&mut decoded)
148 .with_context(|| {
149 format!(
150 "Applying datatype mappings to left row {} in {:?}",
151 row_idx + 2,
152 args.left
153 )
154 })?;
155 }
156 left_schema.apply_replacements_to_row(&mut decoded);
157 let key = build_key(&decoded, &left_schema, &left_indices)?;
158 let mut matched_any = false;
159 if let Some(bucket) = right_lookup.get_mut(&key) {
160 for entry in bucket.iter_mut() {
161 matched_any = true;
162 entry.matched = true;
163 matched_rows += 1;
164 let mut combined = decoded.clone();
165 combined.extend(
166 right_columns
167 .iter()
168 .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
169 );
170 writer
171 .write_record(&combined)
172 .context("Writing joined row")?;
173 output_rows += 1;
174 }
175 }
176
177 if !matched_any && include_unmatched_left {
178 let mut combined = decoded.clone();
179 combined.extend(right_columns.iter().map(|_| String::new()));
180 writer
181 .write_record(&combined)
182 .context("Writing left outer row")?;
183 output_rows += 1;
184 }
185 }
186
187 if include_unmatched_right {
188 for bucket in right_lookup.values() {
189 for entry in bucket.iter() {
190 if entry.matched {
191 continue;
192 }
193 let mut left_part = vec![String::new(); left_headers.len()];
194 for (left_idx, right_idx) in &key_pairs {
195 let value = entry.record.get(*right_idx).cloned().unwrap_or_default();
196 left_part[*left_idx] = value;
197 }
198 let mut combined = left_part;
199 combined.extend(
200 right_columns
201 .iter()
202 .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
203 );
204 writer
205 .write_record(&combined)
206 .context("Writing right outer row")?;
207 output_rows += 1;
208 }
209 }
210 }
211
212 writer.flush().context("Flushing join output")?;
213 info!("Join complete: {output_rows} output row(s), {matched_rows} matched row(s)");
214 Ok(())
215}
216
217fn parse_key_list(value: &str) -> Result<Vec<String>> {
218 let parts = value
219 .split(',')
220 .map(|s| s.trim())
221 .filter(|s| !s.is_empty())
222 .map(|s| s.to_string())
223 .collect::<Vec<_>>();
224 if parts.is_empty() {
225 Err(anyhow!("Join key list cannot be empty"))
226 } else {
227 Ok(parts)
228 }
229}
230
231fn load_schema(
232 path: &PathBuf,
233 schema_path: Option<&PathBuf>,
234 delimiter: u8,
235 encoding: &'static Encoding,
236) -> Result<Schema> {
237 if let Some(schema_path) = schema_path {
238 Schema::load(schema_path).with_context(|| format!("Loading schema from {schema_path:?}"))
239 } else {
240 schema::infer_schema(path, 0, delimiter, encoding, None)
241 .with_context(|| format!("Inferring schema from {path:?}"))
242 }
243}
244
245fn column_indices(schema: &Schema, columns: &[String]) -> Result<Vec<usize>> {
246 columns
247 .iter()
248 .map(|name| {
249 schema
250 .column_index(name)
251 .ok_or_else(|| anyhow!("Column '{name}' not found in schema"))
252 })
253 .collect()
254}
255
256fn validate_key_types(
257 left_schema: &Schema,
258 right_schema: &Schema,
259 left_indices: &[usize],
260 right_indices: &[usize],
261) -> Result<()> {
262 for (l_idx, r_idx) in left_indices.iter().zip(right_indices.iter()) {
263 let left_type = &left_schema.columns[*l_idx].datatype;
264 let right_type = &right_schema.columns[*r_idx].datatype;
265 if !same_type(left_type, right_type) {
266 return Err(anyhow!(
267 "Type mismatch for join keys: left {left_type:?} vs right {right_type:?}"
268 ));
269 }
270 }
271 Ok(())
272}
273
274fn same_type(left: &ColumnType, right: &ColumnType) -> bool {
275 match (left, right) {
276 (ColumnType::Integer, ColumnType::Float) | (ColumnType::Float, ColumnType::Integer) => true,
277 _ => left == right,
278 }
279}
280
/// One row of the right input held in the in-memory join lookup.
struct RightRow {
    /// Decoded field values of the row, after schema transformations and replacements.
    record: Vec<String>,
    /// Set once a left row has joined against this row; rows still `false` after the
    /// left input is exhausted are the "unmatched right" rows.
    matched: bool,
}
285
286fn build_right_lookup(
287 reader: &mut csv::Reader<Box<dyn std::io::Read>>,
288 schema: &Schema,
289 key_indices: &[usize],
290 encoding: &'static Encoding,
291) -> Result<HashMap<String, Vec<RightRow>>> {
292 let mut map: HashMap<String, Vec<RightRow>> = HashMap::new();
293 for (row_idx, record) in reader.byte_records().enumerate() {
294 let record = record.with_context(|| format!("Reading right row {}", row_idx + 2))?;
295 let mut decoded = io_utils::decode_record(&record, encoding)?;
296 if schema.has_transformations() {
297 schema
298 .apply_transformations_to_row(&mut decoded)
299 .with_context(|| {
300 format!("Applying datatype mappings to right row {}", row_idx + 2)
301 })?;
302 }
303 schema.apply_replacements_to_row(&mut decoded);
304 let key = build_key(&decoded, schema, key_indices)?;
305 map.entry(key).or_default().push(RightRow {
306 record: decoded,
307 matched: false,
308 });
309 }
310 Ok(map)
311}
312
313fn build_key(record: &[String], schema: &Schema, key_indices: &[usize]) -> Result<String> {
314 let mut parts = Vec::with_capacity(key_indices.len());
315 for idx in key_indices {
316 let column = &schema.columns[*idx];
317 let raw = record.get(*idx).map(|s| s.as_str()).unwrap_or("");
318 let normalized = column.normalize_value(raw);
319 let parsed = parse_typed_value(normalized.as_ref(), &column.datatype)
320 .with_context(|| format!("Parsing join key for column '{}'", column.name))?;
321 if let Some(value) = parsed {
322 parts.push(value.as_display());
323 } else {
324 parts.push(String::new());
325 }
326 }
327 Ok(parts.join(KEY_SEPARATOR))
328}
329
/// Computes the header row of the joined output and the right columns that feed it.
///
/// The output starts with every left header. Each right column that is not a join key is
/// appended in order. A right name already present in the output is renamed to
/// `right_<name>_<n>`, using the smallest `n >= 1` that is still free.
///
/// Returns `(output headers, indices of the right columns that were appended)`.
fn build_output_headers(
    left_headers: &[String],
    right_headers: &[String],
    right_key_indices: &[usize],
) -> (Vec<String>, Vec<usize>) {
    use std::collections::HashSet;

    let mut taken: HashSet<String> = left_headers.iter().cloned().collect();
    let mut output = left_headers.to_vec();
    let mut carried = Vec::new();

    let non_key_columns = right_headers
        .iter()
        .enumerate()
        .filter(|(position, _)| !right_key_indices.contains(position));

    for (position, original) in non_key_columns {
        let unique = if taken.contains(original) {
            // Probe numbered variants until one is unused.
            let mut suffix = 1usize;
            loop {
                let renamed = format!("right_{original}_{suffix}");
                if !taken.contains(&renamed) {
                    break renamed;
                }
                suffix += 1;
            }
        } else {
            original.clone()
        };

        taken.insert(unique.clone());
        output.push(unique);
        carried.push(position);
    }

    (output, carried)
}