1use std::sync::Arc;
4
5use arrow::array::Int64Array;
6use arrow::datatypes::{DataType, Field, Schema};
7use arrow::record_batch::RecordBatch;
8use regex::Regex;
9use std::sync::LazyLock;
10
11use datafusion::prelude::SessionContext;
12
13use crate::SqlError;
14use crate::SqlResult;
15
/// Lazily compiled pattern for one qualified column equality, `left.col = right.col`.
///
/// Each of the four identifiers may be a bare word (`\w+`) or a backtick-quoted name; they are
/// exposed as capture groups 1-4 in order (left qualifier, left column, right qualifier, right
/// column) with any backticks still attached. The pattern is unanchored, so `captures` reports
/// the first equality found in the input.
///
/// A compilation failure is stored as `None` (via `.ok()`) instead of panicking; users of this
/// static must handle the `None` case.
static KEY_COL_RE: LazyLock<Option<Regex>> = LazyLock::new(|| {
    Regex::new(
        r"(?i)((?:\w+|`[^`]+`))\.((?:\w+|`[^`]+`))\s*=\s*((?:\w+|`[^`]+`))\.((?:\w+|`[^`]+`))",
    )
    .ok()
});
23
24static MERGE_RE: LazyLock<Option<Regex>> = LazyLock::new(|| {
25 Regex::new(
26 r"(?is)^\s*MERGE\s+INTO\s+([`\w.:/-]+)\s+USING\s+([`\w.]+)\s+ON\s+(.+?)(?:\s+WHEN\s+MATCHED\s+THEN\s+UPDATE\s+SET\s+.+?)?(?:\s+WHEN\s+NOT\s+MATCHED\s+THEN\s+INSERT\s*(?:\([^)]*\))?\s*(?:VALUES\s*\([^)]*\)|\*)?)?\s*$",
27 )
28 .ok()
29});
30
/// Row counts reported for a merge operation.
///
/// `Default` yields all-zero counts.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MergeResult {
    /// Number of rows reported as inserted.
    pub rows_inserted: u64,
    /// Number of rows reported as updated.
    pub rows_updated: u64,
    /// Number of rows reported as deleted.
    pub rows_deleted: u64,
}
38
/// Error describing a `MERGE INTO` target that is neither a `delta:` nor an `iceberg:` target.
///
/// Its `Display` output (generated by `thiserror`) embeds the offending target string.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("MERGE INTO is only supported for delta: and iceberg: targets (got {target})")]
pub struct MergeTargetUnsupportedError {
    /// The target text that was rejected.
    pub target: String,
}
45
46pub async fn execute_merge_sql(ctx: &SessionContext, sql: &str) -> SqlResult<Vec<RecordBatch>> {
48 let caps = MERGE_RE
49 .as_ref()
50 .ok_or_else(|| SqlError::DataFusion {
51 message: "MERGE regex failed to compile".into(),
52 })?
53 .captures(sql)
54 .ok_or_else(|| SqlError::Unsupported {
55 feature: "MERGE INTO syntax".into(),
56 })?;
57 let target = caps[1].trim_matches('`').to_string();
58 let source_table = caps[2].trim_matches('`').to_string();
59 let on_clause = caps[3].trim();
60 let has_matched = caps
61 .get(4)
62 .and_then(|m| {
63 let s = m.as_str().trim();
64 if s.is_empty() { None } else { Some(s) }
65 })
66 .is_some();
67 let has_not_matched = caps
68 .get(5)
69 .and_then(|m| {
70 let s = m.as_str().trim();
71 if s.is_empty() { None } else { Some(s) }
72 })
73 .is_some();
74 if !has_matched && !has_not_matched {
75 return Err(SqlError::Unsupported {
76 feature: "MERGE INTO requires at least one WHEN MATCHED or WHEN NOT MATCHED clause"
77 .into(),
78 });
79 }
80
81 let merge_key: String = KEY_COL_RE
82 .as_ref()
83 .ok_or_else(|| SqlError::DataFusion { message: "KEY_COL regex failed to compile".into() })?
84 .captures(on_clause)
85 .ok_or_else(|| SqlError::Unsupported {
86 feature:
87 "MERGE ON clause must contain a qualified column equality (e.g. target.col = source.col)"
88 .into(),
89 })
90 .map(|caps| {
91 let _left_alias = caps[1].trim_matches('`');
92 let left_col = caps[2].trim_matches('`');
93 let right_alias = caps[3].trim_matches('`');
94 let right_col = caps[4].trim_matches('`');
95 let source_lower = source_table.to_lowercase();
97 if right_alias.to_lowercase() == source_lower {
98 left_col.to_string()
99 } else {
100 right_col.to_string()
101 }
102 })?;
103 let merge_key = merge_key.as_str();
104
105 let source_df = ctx
106 .table(&source_table)
107 .await
108 .map_err(|e| SqlError::DataFusion {
109 message: e.to_string(),
110 })?;
111 let source_batches = source_df
112 .collect()
113 .await
114 .map_err(|e| SqlError::DataFusion {
115 message: e.to_string(),
116 })?;
117
118 let metrics = if let Some(path) = target
119 .strip_prefix("delta:`")
120 .and_then(|p| p.strip_suffix('`'))
121 {
122 krishiv_connectors::lakehouse::merge_delta(path, source_batches, merge_key, true, true)
123 .await
124 .map_err(|e| SqlError::DataFusion {
125 message: e.to_string(),
126 })?
127 } else if let Some(path) = target.strip_prefix("delta.") {
128 krishiv_connectors::lakehouse::merge_delta(path, source_batches, merge_key, true, true)
129 .await
130 .map_err(|e| SqlError::DataFusion {
131 message: e.to_string(),
132 })?
133 } else if target.starts_with("iceberg:") {
134 let r = dry_run_merge(ctx, &target, source_batches, merge_key).await?;
135 krishiv_connectors::lakehouse::MergeDeltaResult {
136 rows_inserted: r.rows_inserted,
137 rows_updated: r.rows_updated,
138 rows_deleted: r.rows_deleted,
139 }
140 } else {
141 return Err(SqlError::DataFusion {
142 message: MergeTargetUnsupportedError { target }.to_string(),
143 });
144 };
145
146 Ok(vec![merge_result_batch(metrics)?])
147}
148
149async fn dry_run_merge(
159 ctx: &SessionContext,
160 target: &str,
161 source_batches: Vec<RecordBatch>,
162 merge_key: &str,
163) -> SqlResult<MergeResult> {
164 use arrow::compute::concat_batches;
165 use arrow::util::display::{ArrayFormatter, FormatOptions};
166 use std::collections::HashSet;
167
168 if source_batches.is_empty() {
169 return Ok(MergeResult::default());
170 }
171
172 let source_schema = source_batches
173 .first()
174 .ok_or_else(|| SqlError::DataFusion {
175 message: "empty source batches".into(),
176 })?
177 .schema();
178 let source_batch =
179 concat_batches(&source_schema, &source_batches).map_err(|e| SqlError::DataFusion {
180 message: e.to_string(),
181 })?;
182
183 let inserted: u64 = source_batches.iter().map(|b| b.num_rows() as u64).sum();
184 let fmt_opts = FormatOptions::default();
185
186 let key_idx = source_schema
188 .index_of(merge_key)
189 .map_err(|_| SqlError::Unsupported {
190 feature: format!("merge key column '{merge_key}' not found in source schema"),
191 })?;
192 let source_keys: HashSet<String> = {
193 let f = ArrayFormatter::try_new(source_batch.column(key_idx), &fmt_opts).map_err(|e| {
194 SqlError::DataFusion {
195 message: e.to_string(),
196 }
197 })?;
198 (0..source_batch.num_rows())
199 .map(|i| f.value(i).to_string())
200 .collect()
201 };
202
203 let updated = if source_keys.is_empty() {
205 0
206 } else {
207 let table = target.trim_start_matches("iceberg:");
208 let existing = ctx
209 .table(table)
210 .await
211 .map_err(|e| SqlError::DataFusion {
212 message: e.to_string(),
213 })?
214 .collect()
215 .await
216 .map_err(|e| SqlError::DataFusion {
217 message: e.to_string(),
218 })?;
219
220 if existing.is_empty() {
221 0
222 } else {
223 let existing_schema = existing
224 .first()
225 .ok_or_else(|| SqlError::DataFusion {
226 message: "empty existing batches".into(),
227 })?
228 .schema();
229 let tb =
230 concat_batches(&existing_schema, &existing).map_err(|e| SqlError::DataFusion {
231 message: e.to_string(),
232 })?;
233 let target_key_idx =
234 tb.schema()
235 .index_of(merge_key)
236 .map_err(|_| SqlError::Unsupported {
237 feature: format!(
238 "merge key column '{merge_key}' not found in target schema"
239 ),
240 })?;
241 let target_keys: Vec<String> = {
242 let f =
243 ArrayFormatter::try_new(tb.column(target_key_idx), &fmt_opts).map_err(|e| {
244 SqlError::DataFusion {
245 message: e.to_string(),
246 }
247 })?;
248 (0..tb.num_rows()).map(|i| f.value(i).to_string()).collect()
249 };
250 target_keys
251 .iter()
252 .filter(|k| source_keys.contains(*k))
253 .count() as u64
254 }
255 };
256 Ok(MergeResult {
259 rows_inserted: inserted.saturating_sub(updated),
260 rows_updated: updated,
261 rows_deleted: 0,
262 })
263}
264
265fn merge_result_batch(
266 result: krishiv_connectors::lakehouse::MergeDeltaResult,
267) -> SqlResult<RecordBatch> {
268 merge_metrics_batch(
269 result.rows_inserted,
270 result.rows_updated,
271 result.rows_deleted,
272 )
273}
274
275fn merge_metrics_batch(inserted: u64, updated: u64, deleted: u64) -> SqlResult<RecordBatch> {
276 let schema = Arc::new(Schema::new(vec![
277 Field::new("rows_inserted", DataType::Int64, false),
278 Field::new("rows_updated", DataType::Int64, false),
279 Field::new("rows_deleted", DataType::Int64, false),
280 ]));
281 RecordBatch::try_new(
282 schema,
283 vec![
284 Arc::new(Int64Array::from(vec![inserted as i64])),
285 Arc::new(Int64Array::from(vec![updated as i64])),
286 Arc::new(Int64Array::from(vec![deleted as i64])),
287 ],
288 )
289 .map_err(|e| SqlError::DataFusion {
290 message: format!("merge metrics batch: {e}"),
291 })
292}
293
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Reports whether `MERGE_RE` accepts `sql`.
    fn merge_re_accepts(sql: &str) -> bool {
        MERGE_RE.as_ref().unwrap().is_match(sql)
    }

    /// Text of capture group `idx` after matching `KEY_COL_RE` against `on_clause`.
    fn key_col_group(on_clause: &str, idx: usize) -> Option<&str> {
        KEY_COL_RE
            .as_ref()
            .unwrap()
            .captures(on_clause)
            .unwrap()
            .get(idx)
            .map(|m| m.as_str())
    }

    #[test]
    fn merge_regex_matches_basic_statement() {
        let statement = "MERGE INTO delta.`/tmp/t` USING staging ON target.id = source.id \
                         WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *";
        assert!(merge_re_accepts(statement));
    }

    #[test]
    fn merge_regex_matches_matched_only() {
        let statement = "MERGE INTO delta.`/tmp/t` USING staging ON target.id = source.id \
                         WHEN MATCHED THEN UPDATE SET *";
        assert!(merge_re_accepts(statement));
    }

    #[test]
    fn merge_regex_matches_not_matched_only() {
        let statement = "MERGE INTO delta.`/tmp/t` USING staging ON target.id = source.id \
                         WHEN NOT MATCHED THEN INSERT *";
        assert!(merge_re_accepts(statement));
    }

    #[test]
    fn merge_key_column_extraction() {
        let condition = "target.id = source.id";
        assert_eq!(key_col_group(condition, 1), Some("target"));
        assert_eq!(key_col_group(condition, 2), Some("id"));
    }

    #[test]
    fn merge_key_column_extraction_reversed() {
        let condition = "source.id = target.id";
        assert_eq!(key_col_group(condition, 1), Some("source"));
        assert_eq!(key_col_group(condition, 3), Some("target"));
    }

    #[test]
    fn merge_key_extracts_first_column_from_compound() {
        let condition = "target.id = source.id AND target.date = source.date";
        assert_eq!(key_col_group(condition, 2), Some("id"));
    }

    /// Registers a two-row target, merges a two-row source that overlaps on one id, and checks
    /// the counts reported by `dry_run_merge`.
    #[tokio::test]
    async fn iceberg_merge_returns_correct_row_counts() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let make_batch = |ids: Vec<i64>, names: Vec<&str>| {
            RecordBatch::try_new(
                Arc::clone(&schema),
                vec![
                    Arc::new(Int64Array::from(ids)),
                    Arc::new(StringArray::from(names)),
                ],
            )
            .unwrap()
        };

        let ctx = SessionContext::new();
        let target_rows = make_batch(vec![1, 2], vec!["alice", "bob"]);
        ctx.register_batch("target_t", target_rows).unwrap();

        let source_rows = make_batch(vec![1, 3], vec!["alice-updated", "charlie"]);

        let outcome = dry_run_merge(&ctx, "iceberg:target_t", vec![source_rows], "id")
            .await
            .unwrap();

        assert_eq!(outcome.rows_updated, 1, "id=1 matches target → updated");
        assert_eq!(outcome.rows_inserted, 1, "id=3 is new → inserted");
        assert_eq!(outcome.rows_deleted, 0);
    }
}