Skip to main content

krishiv_sql/lakehouse/
merge.rs

1//! MERGE INTO dispatch (R18 S5, ADR-18.2).
2
3use std::sync::Arc;
4
5use arrow::array::Int64Array;
6use arrow::datatypes::{DataType, Field, Schema};
7use arrow::record_batch::RecordBatch;
8use regex::Regex;
9use std::sync::LazyLock;
10
11use datafusion::prelude::SessionContext;
12
13use crate::SqlError;
14use crate::SqlResult;
15
16/// Match `alias.col = alias.col` in the ON clause, capturing alias and col for both sides.
17static KEY_COL_RE: LazyLock<Option<Regex>> = LazyLock::new(|| {
18    Regex::new(
19        r"(?i)((?:\w+|`[^`]+`))\.((?:\w+|`[^`]+`))\s*=\s*((?:\w+|`[^`]+`))\.((?:\w+|`[^`]+`))",
20    )
21    .ok()
22});
23
24static MERGE_RE: LazyLock<Option<Regex>> = LazyLock::new(|| {
25    Regex::new(
26        r"(?is)^\s*MERGE\s+INTO\s+([`\w.:/-]+)\s+USING\s+([`\w.]+)\s+ON\s+(.+?)(?:\s+WHEN\s+MATCHED\s+THEN\s+UPDATE\s+SET\s+.+?)?(?:\s+WHEN\s+NOT\s+MATCHED\s+THEN\s+INSERT\s*(?:\([^)]*\))?\s*(?:VALUES\s*\([^)]*\)|\*)?)?\s*$",
27    )
28    .ok()
29});
30
/// Row counters describing the outcome of a MERGE.
///
/// `dry_run_merge` returns this type; `execute_merge_sql` surfaces the same
/// three counters to the caller as a single-row batch.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct MergeResult {
    /// Number of rows counted as inserted.
    pub rows_inserted: u64,
    /// Number of rows counted as updated.
    pub rows_updated: u64,
    /// Number of rows counted as deleted.
    pub rows_deleted: u64,
}
38
39/// Target table format is not Delta or Iceberg.
40#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
41#[error("MERGE INTO is only supported for delta: and iceberg: targets (got {target})")]
42pub struct MergeTargetUnsupportedError {
43    pub target: String,
44}
45
46/// Parse and execute a MERGE INTO statement when matched.
47pub async fn execute_merge_sql(ctx: &SessionContext, sql: &str) -> SqlResult<Vec<RecordBatch>> {
48    let caps = MERGE_RE
49        .as_ref()
50        .ok_or_else(|| SqlError::DataFusion {
51            message: "MERGE regex failed to compile".into(),
52        })?
53        .captures(sql)
54        .ok_or_else(|| SqlError::Unsupported {
55            feature: "MERGE INTO syntax".into(),
56        })?;
57    let target = caps[1].trim_matches('`').to_string();
58    let source_table = caps[2].trim_matches('`').to_string();
59    let on_clause = caps[3].trim();
60    let has_matched = caps
61        .get(4)
62        .and_then(|m| {
63            let s = m.as_str().trim();
64            if s.is_empty() { None } else { Some(s) }
65        })
66        .is_some();
67    let has_not_matched = caps
68        .get(5)
69        .and_then(|m| {
70            let s = m.as_str().trim();
71            if s.is_empty() { None } else { Some(s) }
72        })
73        .is_some();
74    if !has_matched && !has_not_matched {
75        return Err(SqlError::Unsupported {
76            feature: "MERGE INTO requires at least one WHEN MATCHED or WHEN NOT MATCHED clause"
77                .into(),
78        });
79    }
80
81    let merge_key: String = KEY_COL_RE
82        .as_ref()
83        .ok_or_else(|| SqlError::DataFusion { message: "KEY_COL regex failed to compile".into() })?
84        .captures(on_clause)
85        .ok_or_else(|| SqlError::Unsupported {
86            feature:
87                "MERGE ON clause must contain a qualified column equality (e.g. target.col = source.col)"
88                    .into(),
89        })
90        .map(|caps| {
91            let _left_alias = caps[1].trim_matches('`');
92            let left_col = caps[2].trim_matches('`');
93            let right_alias = caps[3].trim_matches('`');
94            let right_col = caps[4].trim_matches('`');
95            // Pick the side whose alias does NOT match the source table to get the target col.
96            let source_lower = source_table.to_lowercase();
97            if right_alias.to_lowercase() == source_lower {
98                left_col.to_string()
99            } else {
100                right_col.to_string()
101            }
102        })?;
103    let merge_key = merge_key.as_str();
104
105    let source_df = ctx
106        .table(&source_table)
107        .await
108        .map_err(|e| SqlError::DataFusion {
109            message: e.to_string(),
110        })?;
111    let source_batches = source_df
112        .collect()
113        .await
114        .map_err(|e| SqlError::DataFusion {
115            message: e.to_string(),
116        })?;
117
118    let metrics = if let Some(path) = target
119        .strip_prefix("delta:`")
120        .and_then(|p| p.strip_suffix('`'))
121    {
122        krishiv_connectors::lakehouse::merge_delta(path, source_batches, merge_key, true, true)
123            .await
124            .map_err(|e| SqlError::DataFusion {
125                message: e.to_string(),
126            })?
127    } else if let Some(path) = target.strip_prefix("delta.") {
128        krishiv_connectors::lakehouse::merge_delta(path, source_batches, merge_key, true, true)
129            .await
130            .map_err(|e| SqlError::DataFusion {
131                message: e.to_string(),
132            })?
133    } else if target.starts_with("iceberg:") {
134        let r = dry_run_merge(ctx, &target, source_batches, merge_key).await?;
135        krishiv_connectors::lakehouse::MergeDeltaResult {
136            rows_inserted: r.rows_inserted,
137            rows_updated: r.rows_updated,
138            rows_deleted: r.rows_deleted,
139        }
140    } else {
141        return Err(SqlError::DataFusion {
142            message: MergeTargetUnsupportedError { target }.to_string(),
143        });
144    };
145
146    Ok(vec![merge_result_batch(metrics)?])
147}
148
149/// **Dry-run** merge for Iceberg in-memory tables.
150///
151/// This function simulates a MERGE INTO by computing how many rows would be
152/// inserted vs updated, but does **not** write the merged result back to the
153/// target table. It returns [`MergeResult`] metrics only.
154///
155/// A real MERGE INTO would need to join source + target on the key column,
156/// apply WHEN MATCHED / WHEN NOT MATCHED logic, and write the merged output
157/// back to the table. This is deferred to R2 when Iceberg write support lands.
158async fn dry_run_merge(
159    ctx: &SessionContext,
160    target: &str,
161    source_batches: Vec<RecordBatch>,
162    merge_key: &str,
163) -> SqlResult<MergeResult> {
164    use arrow::compute::concat_batches;
165    use arrow::util::display::{ArrayFormatter, FormatOptions};
166    use std::collections::HashSet;
167
168    if source_batches.is_empty() {
169        return Ok(MergeResult::default());
170    }
171
172    let source_schema = source_batches
173        .first()
174        .ok_or_else(|| SqlError::DataFusion {
175            message: "empty source batches".into(),
176        })?
177        .schema();
178    let source_batch =
179        concat_batches(&source_schema, &source_batches).map_err(|e| SqlError::DataFusion {
180            message: e.to_string(),
181        })?;
182
183    let inserted: u64 = source_batches.iter().map(|b| b.num_rows() as u64).sum();
184    let fmt_opts = FormatOptions::default();
185
186    // Extract source key values into a hash set.
187    let key_idx = source_schema
188        .index_of(merge_key)
189        .map_err(|_| SqlError::Unsupported {
190            feature: format!("merge key column '{merge_key}' not found in source schema"),
191        })?;
192    let source_keys: HashSet<String> = {
193        let f = ArrayFormatter::try_new(source_batch.column(key_idx), &fmt_opts).map_err(|e| {
194            SqlError::DataFusion {
195                message: e.to_string(),
196            }
197        })?;
198        (0..source_batch.num_rows())
199            .map(|i| f.value(i).to_string())
200            .collect()
201    };
202
203    // Only load the target table when we have source keys to match against.
204    let updated = if source_keys.is_empty() {
205        0
206    } else {
207        let table = target.trim_start_matches("iceberg:");
208        let existing = ctx
209            .table(table)
210            .await
211            .map_err(|e| SqlError::DataFusion {
212                message: e.to_string(),
213            })?
214            .collect()
215            .await
216            .map_err(|e| SqlError::DataFusion {
217                message: e.to_string(),
218            })?;
219
220        if existing.is_empty() {
221            0
222        } else {
223            let existing_schema = existing
224                .first()
225                .ok_or_else(|| SqlError::DataFusion {
226                    message: "empty existing batches".into(),
227                })?
228                .schema();
229            let tb =
230                concat_batches(&existing_schema, &existing).map_err(|e| SqlError::DataFusion {
231                    message: e.to_string(),
232                })?;
233            let target_key_idx =
234                tb.schema()
235                    .index_of(merge_key)
236                    .map_err(|_| SqlError::Unsupported {
237                        feature: format!(
238                            "merge key column '{merge_key}' not found in target schema"
239                        ),
240                    })?;
241            let target_keys: Vec<String> = {
242                let f =
243                    ArrayFormatter::try_new(tb.column(target_key_idx), &fmt_opts).map_err(|e| {
244                        SqlError::DataFusion {
245                            message: e.to_string(),
246                        }
247                    })?;
248                (0..tb.num_rows()).map(|i| f.value(i).to_string()).collect()
249            };
250            target_keys
251                .iter()
252                .filter(|k| source_keys.contains(*k))
253                .count() as u64
254        }
255    };
256    // ---- end !Send scope ----
257
258    Ok(MergeResult {
259        rows_inserted: inserted.saturating_sub(updated),
260        rows_updated: updated,
261        rows_deleted: 0,
262    })
263}
264
265fn merge_result_batch(
266    result: krishiv_connectors::lakehouse::MergeDeltaResult,
267) -> SqlResult<RecordBatch> {
268    merge_metrics_batch(
269        result.rows_inserted,
270        result.rows_updated,
271        result.rows_deleted,
272    )
273}
274
275fn merge_metrics_batch(inserted: u64, updated: u64, deleted: u64) -> SqlResult<RecordBatch> {
276    let schema = Arc::new(Schema::new(vec![
277        Field::new("rows_inserted", DataType::Int64, false),
278        Field::new("rows_updated", DataType::Int64, false),
279        Field::new("rows_deleted", DataType::Int64, false),
280    ]));
281    RecordBatch::try_new(
282        schema,
283        vec![
284            Arc::new(Int64Array::from(vec![inserted as i64])),
285            Arc::new(Int64Array::from(vec![updated as i64])),
286            Arc::new(Int64Array::from(vec![deleted as i64])),
287        ],
288    )
289    .map_err(|e| SqlError::DataFusion {
290        message: format!("merge metrics batch: {e}"),
291    })
292}
293
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Whether the statement pattern accepts `sql`.
    fn statement_matches(sql: &str) -> bool {
        MERGE_RE.as_ref().unwrap().is_match(sql)
    }

    /// Text of capture group `group` when the key pattern is applied to `on`.
    fn key_group(on: &str, group: usize) -> Option<&str> {
        let caps = KEY_COL_RE.as_ref().unwrap().captures(on).unwrap();
        caps.get(group).map(|m| m.as_str())
    }

    /// Two-column (`id: Int64`, `name: Utf8`) batch used by the dry-run test.
    fn id_name_batch(ids: Vec<i64>, names: Vec<&str>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn merge_regex_matches_basic_statement() {
        assert!(statement_matches(
            "MERGE INTO delta.`/tmp/t` USING staging ON target.id = source.id \
             WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *"
        ));
    }

    #[test]
    fn merge_regex_matches_matched_only() {
        assert!(statement_matches(
            "MERGE INTO delta.`/tmp/t` USING staging ON target.id = source.id \
             WHEN MATCHED THEN UPDATE SET *"
        ));
    }

    #[test]
    fn merge_regex_matches_not_matched_only() {
        assert!(statement_matches(
            "MERGE INTO delta.`/tmp/t` USING staging ON target.id = source.id \
             WHEN NOT MATCHED THEN INSERT *"
        ));
    }

    #[test]
    fn merge_key_column_extraction() {
        // Groups are (left_alias, left_col, right_alias, right_col).
        let on = "target.id = source.id";
        assert_eq!(key_group(on, 1), Some("target"));
        assert_eq!(key_group(on, 2), Some("id"));
    }

    #[test]
    fn merge_key_column_extraction_reversed() {
        // With the sides swapped, the aliases land in the opposite groups.
        let on = "source.id = target.id";
        assert_eq!(key_group(on, 1), Some("source"));
        assert_eq!(key_group(on, 3), Some("target"));
    }

    #[test]
    fn merge_key_extracts_first_column_from_compound() {
        // Only the first equality of a compound condition is captured.
        let on = "target.id = source.id AND target.date = source.date";
        assert_eq!(key_group(on, 2), Some("id"));
    }

    /// C9 regression: iceberg in-memory merge must return correct metrics
    /// (updated for matching keys, inserted for new keys) and must NOT
    /// report all rows as inserted (the full-table-replace bug).
    #[tokio::test]
    async fn iceberg_merge_returns_correct_row_counts() {
        let ctx = SessionContext::new();

        // Target holds (1, "alice") and (2, "bob").
        ctx.register_batch("target_t", id_name_batch(vec![1, 2], vec!["alice", "bob"]))
            .unwrap();

        // Source holds (1, "alice-updated") and (3, "charlie"): id=1 matches, id=3 is new.
        let source = id_name_batch(vec![1, 3], vec!["alice-updated", "charlie"]);

        let result = dry_run_merge(&ctx, "iceberg:target_t", vec![source], "id")
            .await
            .unwrap();

        assert_eq!(result.rows_updated, 1, "id=1 matches target → updated");
        assert_eq!(result.rows_inserted, 1, "id=3 is new → inserted");
        assert_eq!(result.rows_deleted, 0);
    }
}