1use std::{collections::HashMap, path::PathBuf};
2
3use anyhow::{Context, Result, anyhow};
4use encoding_rs::Encoding;
5use log::info;
6
7use crate::{
8 cli::{JoinArgs, JoinKind},
9 data::parse_typed_value,
10 io_utils,
11 schema::{self, ColumnType, Schema},
12};
13
/// Separator placed between the per-column components of a composite join key.
///
/// This is U+001F (the ASCII "unit separator" control character), presumably chosen because
/// it rarely occurs in ordinary CSV cell values.
const KEY_SEPARATOR: &str = "\x1f";
15
16pub fn execute(args: &JoinArgs) -> Result<()> {
17 if args.left_key.is_empty() || args.right_key.is_empty() {
18 return Err(anyhow!("Join requires --left-key and --right-key"));
19 }
20 if io_utils::is_dash(&args.right) {
21 return Err(anyhow!(
22 "Right input cannot be stdin for join operations; provide a file path"
23 ));
24 }
25 if io_utils::is_dash(&args.left) && args.left_schema.is_none() {
26 return Err(anyhow!(
27 "Joining from stdin requires --left-schema (or --left-meta) to describe the schema"
28 ));
29 }
30
31 let left_keys = parse_key_list(&args.left_key)?;
32 let right_keys = parse_key_list(&args.right_key)?;
33 if left_keys.len() != right_keys.len() {
34 return Err(anyhow!(
35 "Left and right join keys must contain the same number of columns"
36 ));
37 }
38
39 let left_delimiter = io_utils::resolve_input_delimiter(&args.left, args.delimiter);
40 let right_delimiter = io_utils::resolve_input_delimiter(&args.right, args.delimiter);
41 let output_delimiter =
42 io_utils::resolve_output_delimiter(args.output.as_deref(), None, left_delimiter);
43 let left_encoding = io_utils::resolve_encoding(args.left_encoding.as_deref())?;
44 let right_encoding = io_utils::resolve_encoding(args.right_encoding.as_deref())?;
45 let output_encoding = io_utils::resolve_encoding(args.output_encoding.as_deref())?;
46
47 let left_schema = load_schema(
48 &args.left,
49 args.left_schema.as_ref(),
50 left_delimiter,
51 left_encoding,
52 )?;
53 let right_schema = load_schema(
54 &args.right,
55 args.right_schema.as_ref(),
56 right_delimiter,
57 right_encoding,
58 )?;
59
60 let left_indices = column_indices(&left_schema, &left_keys)?;
61 let right_indices = column_indices(&right_schema, &right_keys)?;
62 validate_key_types(&left_schema, &right_schema, &left_indices, &right_indices)?;
63
64 let mut left_reader = io_utils::open_csv_reader_from_path(&args.left, left_delimiter, true)?;
65 let mut right_reader = io_utils::open_csv_reader_from_path(&args.right, right_delimiter, true)?;
66
67 let left_headers = io_utils::reader_headers(&mut left_reader, left_encoding)?;
68 let right_headers = io_utils::reader_headers(&mut right_reader, right_encoding)?;
69 left_schema
70 .validate_headers(&left_headers)
71 .with_context(|| format!("Validating left headers for {:?}", args.left))?;
72 right_schema
73 .validate_headers(&right_headers)
74 .with_context(|| format!("Validating right headers for {:?}", args.right))?;
75
76 let mut right_lookup = build_right_lookup(
77 &mut right_reader,
78 &right_schema,
79 &right_indices,
80 right_encoding,
81 )?;
82
83 let (output_headers, right_columns) =
84 build_output_headers(&left_headers, &right_headers, &right_indices);
85
86 let mut writer =
87 io_utils::open_csv_writer(args.output.as_deref(), output_delimiter, output_encoding)?;
88 writer
89 .write_record(&output_headers)
90 .context("Writing joined headers")?;
91
92 let mut output_rows = 0usize;
93 let mut matched_rows = 0usize;
94 let include_unmatched_left = matches!(args.kind, JoinKind::Left | JoinKind::Full);
95 let include_unmatched_right = matches!(args.kind, JoinKind::Right | JoinKind::Full);
96 let key_pairs: Vec<(usize, usize)> = left_indices
97 .iter()
98 .cloned()
99 .zip(right_indices.iter().cloned())
100 .collect();
101
102 for (row_idx, record) in left_reader.byte_records().enumerate() {
103 let record = record.with_context(|| format!("Reading left row {}", row_idx + 2))?;
104 let mut decoded = io_utils::decode_record(&record, left_encoding)?;
105 left_schema.apply_replacements_to_row(&mut decoded);
106 let key = build_key(&decoded, &left_schema, &left_indices)?;
107 let mut matched_any = false;
108 if let Some(bucket) = right_lookup.get_mut(&key) {
109 for entry in bucket.iter_mut() {
110 matched_any = true;
111 entry.matched = true;
112 matched_rows += 1;
113 let mut combined = decoded.clone();
114 combined.extend(
115 right_columns
116 .iter()
117 .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
118 );
119 writer
120 .write_record(&combined)
121 .context("Writing joined row")?;
122 output_rows += 1;
123 }
124 }
125
126 if !matched_any && include_unmatched_left {
127 let mut combined = decoded.clone();
128 combined.extend(right_columns.iter().map(|_| String::new()));
129 writer
130 .write_record(&combined)
131 .context("Writing left outer row")?;
132 output_rows += 1;
133 }
134 }
135
136 if include_unmatched_right {
137 for bucket in right_lookup.values() {
138 for entry in bucket.iter() {
139 if entry.matched {
140 continue;
141 }
142 let mut left_part = vec![String::new(); left_headers.len()];
143 for (left_idx, right_idx) in &key_pairs {
144 let value = entry.record.get(*right_idx).cloned().unwrap_or_default();
145 left_part[*left_idx] = value;
146 }
147 let mut combined = left_part;
148 combined.extend(
149 right_columns
150 .iter()
151 .map(|idx| entry.record.get(*idx).cloned().unwrap_or_default()),
152 );
153 writer
154 .write_record(&combined)
155 .context("Writing right outer row")?;
156 output_rows += 1;
157 }
158 }
159 }
160
161 writer.flush().context("Flushing join output")?;
162 info!("Join complete: {output_rows} output row(s), {matched_rows} matched row(s)");
163 Ok(())
164}
165
166fn parse_key_list(value: &str) -> Result<Vec<String>> {
167 let parts = value
168 .split(',')
169 .map(|s| s.trim())
170 .filter(|s| !s.is_empty())
171 .map(|s| s.to_string())
172 .collect::<Vec<_>>();
173 if parts.is_empty() {
174 Err(anyhow!("Join key list cannot be empty"))
175 } else {
176 Ok(parts)
177 }
178}
179
180fn load_schema(
181 path: &PathBuf,
182 schema_path: Option<&PathBuf>,
183 delimiter: u8,
184 encoding: &'static Encoding,
185) -> Result<Schema> {
186 if let Some(schema_path) = schema_path {
187 Schema::load(schema_path).with_context(|| format!("Loading schema from {schema_path:?}"))
188 } else {
189 schema::infer_schema(path, 0, delimiter, encoding)
190 .with_context(|| format!("Inferring schema from {path:?}"))
191 }
192}
193
194fn column_indices(schema: &Schema, columns: &[String]) -> Result<Vec<usize>> {
195 columns
196 .iter()
197 .map(|name| {
198 schema
199 .column_index(name)
200 .ok_or_else(|| anyhow!("Column '{name}' not found in schema"))
201 })
202 .collect()
203}
204
205fn validate_key_types(
206 left_schema: &Schema,
207 right_schema: &Schema,
208 left_indices: &[usize],
209 right_indices: &[usize],
210) -> Result<()> {
211 for (l_idx, r_idx) in left_indices.iter().zip(right_indices.iter()) {
212 let left_type = &left_schema.columns[*l_idx].datatype;
213 let right_type = &right_schema.columns[*r_idx].datatype;
214 if !same_type(left_type, right_type) {
215 return Err(anyhow!(
216 "Type mismatch for join keys: left {left_type:?} vs right {right_type:?}"
217 ));
218 }
219 }
220 Ok(())
221}
222
223fn same_type(left: &ColumnType, right: &ColumnType) -> bool {
224 match (left, right) {
225 (ColumnType::Integer, ColumnType::Float) | (ColumnType::Float, ColumnType::Integer) => true,
226 _ => left == right,
227 }
228}
229
/// One decoded right-hand row, kept in memory while the left input is streamed.
struct RightRow {
    /// Becomes `true` once any left row joins with this row. Rows still `false` after the
    /// left pass are the candidates for right/full outer output.
    matched: bool,
    /// Decoded field values, after the right schema's replacements have been applied.
    record: Vec<String>,
}
234
235fn build_right_lookup(
236 reader: &mut csv::Reader<Box<dyn std::io::Read>>,
237 schema: &Schema,
238 key_indices: &[usize],
239 encoding: &'static Encoding,
240) -> Result<HashMap<String, Vec<RightRow>>> {
241 let mut map: HashMap<String, Vec<RightRow>> = HashMap::new();
242 for (row_idx, record) in reader.byte_records().enumerate() {
243 let record = record.with_context(|| format!("Reading right row {}", row_idx + 2))?;
244 let mut decoded = io_utils::decode_record(&record, encoding)?;
245 schema.apply_replacements_to_row(&mut decoded);
246 let key = build_key(&decoded, schema, key_indices)?;
247 map.entry(key).or_default().push(RightRow {
248 record: decoded,
249 matched: false,
250 });
251 }
252 Ok(map)
253}
254
255fn build_key(record: &[String], schema: &Schema, key_indices: &[usize]) -> Result<String> {
256 let mut parts = Vec::with_capacity(key_indices.len());
257 for idx in key_indices {
258 let column = &schema.columns[*idx];
259 let raw = record.get(*idx).map(|s| s.as_str()).unwrap_or("");
260 let normalized = column.normalize_value(raw);
261 let parsed = parse_typed_value(normalized.as_ref(), &column.datatype)
262 .with_context(|| format!("Parsing join key for column '{}'", column.name))?;
263 if let Some(value) = parsed {
264 parts.push(value.as_display());
265 } else {
266 parts.push(String::new());
267 }
268 }
269 Ok(parts.join(KEY_SEPARATOR))
270}
271
/// Builds the joined header row and selects which right-hand columns to emit.
///
/// All left headers come first, unchanged. Every right header whose index is not listed in
/// `right_key_indices` follows in its original order. If its name is already taken, it is
/// renamed to the first free `right_<name>_<n>`, with `n` counting up from 1. A name counts
/// as taken if a left header or an earlier emitted right header already uses it.
///
/// Returns the header row together with the indices of the right columns that were kept,
/// in output order.
fn build_output_headers(
    left_headers: &[String],
    right_headers: &[String],
    right_key_indices: &[usize],
) -> (Vec<String>, Vec<usize>) {
    use std::collections::HashSet;

    let mut used: HashSet<String> = left_headers.iter().cloned().collect();
    let mut output = left_headers.to_vec();
    let mut kept = Vec::new();

    let non_key_columns = right_headers
        .iter()
        .enumerate()
        .filter(|(position, _)| !right_key_indices.contains(position));

    for (position, header) in non_key_columns {
        let unique = if used.contains(header) {
            (1usize..)
                .map(|n| format!("right_{header}_{n}"))
                .find(|alias| !used.contains(alias))
                .expect("an unbounded counter always yields an unused alias")
        } else {
            header.clone()
        };
        used.insert(unique.clone());
        output.push(unique);
        kept.push(position);
    }

    (output, kept)
}