Skip to main content

omnigraph/exec/
mutation.rs

1use super::*;
2
3use super::query::literal_to_sql;
4
5// ─── Mutation helpers ────────────────────────────────────────────────────────
6
7/// Resolve an IRExpr to a concrete Literal value at runtime.
8fn resolve_expr_value(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
9    match expr {
10        IRExpr::Literal(lit) => Ok(lit.clone()),
11        IRExpr::Param(name) => params
12            .get(name)
13            .cloned()
14            .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name))),
15        other => Err(OmniError::manifest(format!(
16            "unsupported expression in mutation: {:?}",
17            other
18        ))),
19    }
20}
21
22/// Create a single-element or N-element array from a Literal, matching the target DataType.
23fn literal_to_typed_array(
24    lit: &Literal,
25    data_type: &DataType,
26    num_rows: usize,
27) -> Result<ArrayRef> {
28    Ok(match (lit, data_type) {
29        (Literal::Null, _) => arrow_array::new_null_array(data_type, num_rows),
30        (Literal::String(s), DataType::Utf8) => {
31            Arc::new(StringArray::from(vec![s.as_str(); num_rows])) as ArrayRef
32        }
33        (Literal::Integer(n), DataType::Int32) => {
34            Arc::new(Int32Array::from(vec![*n as i32; num_rows]))
35        }
36        (Literal::Integer(n), DataType::Int64) => Arc::new(Int64Array::from(vec![*n; num_rows])),
37        (Literal::Integer(n), DataType::UInt32) => {
38            Arc::new(UInt32Array::from(vec![*n as u32; num_rows]))
39        }
40        (Literal::Integer(n), DataType::UInt64) => {
41            Arc::new(UInt64Array::from(vec![*n as u64; num_rows]))
42        }
43        (Literal::Float(f), DataType::Float32) => {
44            Arc::new(Float32Array::from(vec![*f as f32; num_rows]))
45        }
46        (Literal::Float(f), DataType::Float64) => Arc::new(Float64Array::from(vec![*f; num_rows])),
47        (Literal::Bool(b), DataType::Boolean) => Arc::new(BooleanArray::from(vec![*b; num_rows])),
48        (Literal::Date(s), DataType::Date32) => {
49            let days = crate::loader::parse_date32_literal(s)?;
50            Arc::new(Date32Array::from(vec![days; num_rows]))
51        }
52        (Literal::DateTime(s), DataType::Date64) => Arc::new(Date64Array::from(vec![
53            crate::loader::parse_date64_literal(s)?;
54            num_rows
55        ])),
56        (Literal::List(items), DataType::List(field)) => {
57            typed_list_literal_to_array(items, field.data_type(), num_rows)?
58        }
59        (Literal::List(items), DataType::FixedSizeList(field, dim))
60            if field.data_type() == &DataType::Float32 =>
61        {
62            if items.len() != *dim as usize {
63                return Err(OmniError::manifest(format!(
64                    "vector property expects {} dimensions, got {}",
65                    dim,
66                    items.len()
67                )));
68            }
69            let mut builder = FixedSizeListBuilder::with_capacity(
70                Float32Builder::with_capacity(num_rows * (*dim as usize)),
71                *dim,
72                num_rows,
73            )
74            .with_field(field.clone());
75            for _ in 0..num_rows {
76                for item in items {
77                    match item {
78                        Literal::Integer(value) => builder.values().append_value(*value as f32),
79                        Literal::Float(value) => builder.values().append_value(*value as f32),
80                        _ => {
81                            return Err(OmniError::manifest(
82                                "vector elements must be numeric".to_string(),
83                            ));
84                        }
85                    }
86                }
87                builder.append(true);
88            }
89            Arc::new(builder.finish())
90        }
91        _ => {
92            return Err(OmniError::manifest(format!(
93                "cannot convert {:?} to {:?}",
94                lit, data_type
95            )));
96        }
97    })
98}
99
100fn typed_list_literal_to_array(
101    items: &[Literal],
102    item_type: &DataType,
103    num_rows: usize,
104) -> Result<ArrayRef> {
105    match item_type {
106        DataType::Utf8 => {
107            let mut builder = ListBuilder::new(StringBuilder::new());
108            for _ in 0..num_rows {
109                for item in items {
110                    match item {
111                        Literal::String(value) => builder.values().append_value(value),
112                        _ => builder.values().append_null(),
113                    }
114                }
115                builder.append(true);
116            }
117            Ok(Arc::new(builder.finish()))
118        }
119        DataType::Boolean => {
120            let mut builder = ListBuilder::new(BooleanBuilder::new());
121            for _ in 0..num_rows {
122                for item in items {
123                    match item {
124                        Literal::Bool(value) => builder.values().append_value(*value),
125                        _ => builder.values().append_null(),
126                    }
127                }
128                builder.append(true);
129            }
130            Ok(Arc::new(builder.finish()))
131        }
132        DataType::Int32 => {
133            let mut builder = ListBuilder::new(Int32Builder::new());
134            for _ in 0..num_rows {
135                for item in items {
136                    match item {
137                        Literal::Integer(value) => {
138                            let value = i32::try_from(*value).map_err(|_| {
139                                OmniError::manifest(format!(
140                                    "list value {} exceeds Int32 range",
141                                    value
142                                ))
143                            })?;
144                            builder.values().append_value(value);
145                        }
146                        _ => builder.values().append_null(),
147                    }
148                }
149                builder.append(true);
150            }
151            Ok(Arc::new(builder.finish()))
152        }
153        DataType::Int64 => {
154            let mut builder = ListBuilder::new(Int64Builder::new());
155            for _ in 0..num_rows {
156                for item in items {
157                    match item {
158                        Literal::Integer(value) => builder.values().append_value(*value),
159                        _ => builder.values().append_null(),
160                    }
161                }
162                builder.append(true);
163            }
164            Ok(Arc::new(builder.finish()))
165        }
166        DataType::UInt32 => {
167            let mut builder = ListBuilder::new(UInt32Builder::new());
168            for _ in 0..num_rows {
169                for item in items {
170                    match item {
171                        Literal::Integer(value) => {
172                            let value = u32::try_from(*value).map_err(|_| {
173                                OmniError::manifest(format!(
174                                    "list value {} exceeds UInt32 range",
175                                    value
176                                ))
177                            })?;
178                            builder.values().append_value(value);
179                        }
180                        _ => builder.values().append_null(),
181                    }
182                }
183                builder.append(true);
184            }
185            Ok(Arc::new(builder.finish()))
186        }
187        DataType::UInt64 => {
188            let mut builder = ListBuilder::new(UInt64Builder::new());
189            for _ in 0..num_rows {
190                for item in items {
191                    match item {
192                        Literal::Integer(value) => {
193                            let value = u64::try_from(*value).map_err(|_| {
194                                OmniError::manifest(format!(
195                                    "list value {} exceeds UInt64 range",
196                                    value
197                                ))
198                            })?;
199                            builder.values().append_value(value);
200                        }
201                        _ => builder.values().append_null(),
202                    }
203                }
204                builder.append(true);
205            }
206            Ok(Arc::new(builder.finish()))
207        }
208        DataType::Float32 => {
209            let mut builder = ListBuilder::new(Float32Builder::new());
210            for _ in 0..num_rows {
211                for item in items {
212                    match item {
213                        Literal::Integer(value) => builder.values().append_value(*value as f32),
214                        Literal::Float(value) => builder.values().append_value(*value as f32),
215                        _ => builder.values().append_null(),
216                    }
217                }
218                builder.append(true);
219            }
220            Ok(Arc::new(builder.finish()))
221        }
222        DataType::Float64 => {
223            let mut builder = ListBuilder::new(Float64Builder::new());
224            for _ in 0..num_rows {
225                for item in items {
226                    match item {
227                        Literal::Integer(value) => builder.values().append_value(*value as f64),
228                        Literal::Float(value) => builder.values().append_value(*value),
229                        _ => builder.values().append_null(),
230                    }
231                }
232                builder.append(true);
233            }
234            Ok(Arc::new(builder.finish()))
235        }
236        DataType::Date32 => {
237            let mut builder = ListBuilder::new(Date32Builder::new());
238            for _ in 0..num_rows {
239                for item in items {
240                    match item {
241                        Literal::Date(value) => builder
242                            .values()
243                            .append_value(crate::loader::parse_date32_literal(value)?),
244                        _ => builder.values().append_null(),
245                    }
246                }
247                builder.append(true);
248            }
249            Ok(Arc::new(builder.finish()))
250        }
251        DataType::Date64 => {
252            let mut builder = ListBuilder::new(Date64Builder::new());
253            for _ in 0..num_rows {
254                for item in items {
255                    match item {
256                        Literal::DateTime(value) => builder
257                            .values()
258                            .append_value(crate::loader::parse_date64_literal(value)?),
259                        _ => builder.values().append_null(),
260                    }
261                }
262                builder.append(true);
263            }
264            Ok(Arc::new(builder.finish()))
265        }
266        other => Err(OmniError::manifest(format!(
267            "cannot convert list literal to {:?}",
268            other
269        ))),
270    }
271}
272
273/// Build a single-element blob array from a URI or base64 value string.
274fn build_blob_array_from_value(value: &str) -> Result<ArrayRef> {
275    let mut builder = BlobArrayBuilder::new(1);
276    crate::loader::append_blob_value(&mut builder, value)?;
277    builder
278        .finish()
279        .map_err(|e| OmniError::Lance(e.to_string()))
280}
281
282/// Build a null blob array with one element.
283fn build_null_blob_array() -> Result<ArrayRef> {
284    let mut builder = BlobArrayBuilder::new(1);
285    builder
286        .push_null()
287        .map_err(|e| OmniError::Lance(e.to_string()))?;
288    builder
289        .finish()
290        .map_err(|e| OmniError::Lance(e.to_string()))
291}
292
293/// Build a single-row RecordBatch from resolved assignments.
294fn build_insert_batch(
295    schema: &SchemaRef,
296    id: &str,
297    assignments: &HashMap<String, Literal>,
298    blob_properties: &HashSet<String>,
299) -> Result<RecordBatch> {
300    let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
301
302    for field in schema.fields() {
303        if field.name() == "id" {
304            columns.push(Arc::new(StringArray::from(vec![id])));
305        } else if blob_properties.contains(field.name()) {
306            if let Some(Literal::String(uri)) = assignments.get(field.name()) {
307                columns.push(build_blob_array_from_value(uri)?);
308            } else if field.is_nullable() {
309                columns.push(build_null_blob_array()?);
310            } else {
311                return Err(OmniError::manifest(format!(
312                    "missing required blob property '{}'",
313                    field.name()
314                )));
315            }
316        } else if field.name() == "src" {
317            let lit = assignments.get("from").ok_or_else(|| {
318                OmniError::manifest("missing required edge endpoint 'from'".to_string())
319            })?;
320            columns.push(literal_to_typed_array(lit, field.data_type(), 1)?);
321        } else if field.name() == "dst" {
322            let lit = assignments.get("to").ok_or_else(|| {
323                OmniError::manifest("missing required edge endpoint 'to'".to_string())
324            })?;
325            columns.push(literal_to_typed_array(lit, field.data_type(), 1)?);
326        } else if let Some(lit) = assignments.get(field.name()) {
327            columns.push(literal_to_typed_array(lit, field.data_type(), 1)?);
328        } else if field.is_nullable() {
329            columns.push(arrow_array::new_null_array(field.data_type(), 1));
330        } else {
331            return Err(OmniError::manifest(format!(
332                "missing required property '{}'",
333                field.name()
334            )));
335        }
336    }
337
338    RecordBatch::try_new(schema.clone(), columns).map_err(|e| OmniError::Lance(e.to_string()))
339}
340
341async fn validate_edge_insert_endpoints(
342    db: &Omnigraph,
343    edge_name: &str,
344    assignments: &HashMap<String, Literal>,
345) -> Result<()> {
346    let edge_type = db
347        .catalog()
348        .edge_types
349        .get(edge_name)
350        .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_name)))?;
351    let from = match assignments.get("from") {
352        Some(Literal::String(value)) => value.as_str(),
353        Some(other) => {
354            return Err(OmniError::manifest(format!(
355                "edge {} from endpoint must be a string id, got {}",
356                edge_name,
357                literal_to_sql(other)
358            )));
359        }
360        None => {
361            return Err(OmniError::manifest(format!(
362                "edge {} missing 'from' endpoint",
363                edge_name
364            )));
365        }
366    };
367    let to = match assignments.get("to") {
368        Some(Literal::String(value)) => value.as_str(),
369        Some(other) => {
370            return Err(OmniError::manifest(format!(
371                "edge {} to endpoint must be a string id, got {}",
372                edge_name,
373                literal_to_sql(other)
374            )));
375        }
376        None => {
377            return Err(OmniError::manifest(format!(
378                "edge {} missing 'to' endpoint",
379                edge_name
380            )));
381        }
382    };
383
384    ensure_node_id_exists(db, &edge_type.from_type, from, "src").await?;
385    ensure_node_id_exists(db, &edge_type.to_type, to, "dst").await?;
386    Ok(())
387}
388
389async fn ensure_node_id_exists(
390    db: &Omnigraph,
391    node_type: &str,
392    id: &str,
393    label: &str,
394) -> Result<()> {
395    let snapshot = db.snapshot();
396    let table_key = format!("node:{}", node_type);
397    let ds = snapshot.open(&table_key).await?;
398    let filter = format!("id = '{}'", id.replace('\'', "''"));
399    let exists = ds
400        .count_rows(Some(filter))
401        .await
402        .map_err(|e| OmniError::Lance(e.to_string()))?
403        > 0;
404    if exists {
405        Ok(())
406    } else {
407        Err(OmniError::manifest(format!(
408            "{} '{}' not found in {}",
409            label, id, node_type
410        )))
411    }
412}
413
414/// Convert an IRMutationPredicate to a Lance SQL filter string.
415fn predicate_to_sql(
416    predicate: &IRMutationPredicate,
417    params: &ParamMap,
418    is_edge: bool,
419) -> Result<String> {
420    let column = if is_edge {
421        match predicate.property.as_str() {
422            "from" => "src".to_string(),
423            "to" => "dst".to_string(),
424            other => other.to_string(),
425        }
426    } else {
427        predicate.property.clone()
428    };
429
430    let value = resolve_expr_value(&predicate.value, params)?;
431    let value_sql = literal_to_sql(&value);
432
433    let op = match predicate.op {
434        CompOp::Eq => "=",
435        CompOp::Ne => "!=",
436        CompOp::Gt => ">",
437        CompOp::Lt => "<",
438        CompOp::Ge => ">=",
439        CompOp::Le => "<=",
440        CompOp::Contains => {
441            return Err(OmniError::manifest(
442                "contains predicate not supported in mutations".to_string(),
443            ));
444        }
445    };
446
447    Ok(format!("{} {} {}", column, op, value_sql))
448}
449
450/// Replace specific columns in a RecordBatch with new literal values.
451/// Blob columns are excluded from the scan result, so assigned blob values are
452/// synthesized from the full table schema and included inline in the update
453/// batch. Unassigned blob columns are omitted so merge_insert leaves them
454/// untouched.
455fn apply_assignments(
456    full_schema: &SchemaRef,
457    batch: &RecordBatch,
458    assignments: &HashMap<String, Literal>,
459    blob_properties: &HashSet<String>,
460) -> Result<RecordBatch> {
461    let mut columns: Vec<ArrayRef> = Vec::with_capacity(full_schema.fields().len());
462    let mut out_fields: Vec<Field> = Vec::with_capacity(full_schema.fields().len());
463
464    for field in full_schema.fields().iter() {
465        if blob_properties.contains(field.name()) {
466            // Blob columns aren't in the scan result. If this blob has an
467            // assignment, build the blob array inline so the single
468            // merge_insert covers both scalar and blob updates. Unassigned
469            // blob columns are omitted — merge_insert only touches columns
470            // present in the batch.
471            if let Some(Literal::String(uri)) = assignments.get(field.name()) {
472                let mut builder = BlobArrayBuilder::new(batch.num_rows());
473                for _ in 0..batch.num_rows() {
474                    crate::loader::append_blob_value(&mut builder, uri)?;
475                }
476                let blob_field = lance::blob::blob_field(field.name(), true);
477                out_fields.push(blob_field);
478                columns.push(
479                    builder
480                        .finish()
481                        .map_err(|e| OmniError::Lance(e.to_string()))?,
482                );
483            }
484            // else: no assignment for this blob column — skip it
485        } else if let Some(lit) = assignments.get(field.name()) {
486            out_fields.push(field.as_ref().clone());
487            columns.push(literal_to_typed_array(
488                lit,
489                field.data_type(),
490                batch.num_rows(),
491            )?);
492        } else {
493            let col = batch.column_by_name(field.name()).ok_or_else(|| {
494                OmniError::Lance(format!(
495                    "column '{}' not found in scan result",
496                    field.name()
497                ))
498            })?;
499            out_fields.push(field.as_ref().clone());
500            columns.push(col.clone());
501        }
502    }
503
504    RecordBatch::try_new(Arc::new(Schema::new(out_fields)), columns)
505        .map_err(|e| OmniError::Lance(e.to_string()))
506}
507
508// ─── Mutation execution ──────────────────────────────────────────────────────
509
510impl Omnigraph {
511    pub async fn mutate(
512        &mut self,
513        branch: &str,
514        query_source: &str,
515        query_name: &str,
516        params: &ParamMap,
517    ) -> Result<MutationResult> {
518        self.mutate_as(branch, query_source, query_name, params, None)
519            .await
520    }
521
522    pub async fn mutate_as(
523        &mut self,
524        branch: &str,
525        query_source: &str,
526        query_name: &str,
527        params: &ParamMap,
528        actor_id: Option<&str>,
529    ) -> Result<MutationResult> {
530        let previous_actor = self.audit_actor_id.clone();
531        self.audit_actor_id = actor_id.map(str::to_string);
532        let result = self
533            .mutate_with_current_actor(branch, query_source, query_name, params)
534            .await;
535        self.audit_actor_id = previous_actor;
536        result
537    }
538
539    async fn mutate_with_current_actor(
540        &mut self,
541        branch: &str,
542        query_source: &str,
543        query_name: &str,
544        params: &ParamMap,
545    ) -> Result<MutationResult> {
546        self.ensure_schema_state_valid().await?;
547        let requested = Self::normalize_branch_name(branch)?;
548        let resolved_params = enrich_mutation_params(params)?;
549        let operation = format!(
550            "mutation:{}:branch={}",
551            query_name,
552            requested.as_deref().unwrap_or("main")
553        );
554
555        if requested.as_deref().is_some_and(is_internal_run_branch) {
556            return self
557                .execute_named_mutation_on_branch(
558                    requested.as_deref(),
559                    query_source,
560                    query_name,
561                    &resolved_params,
562                )
563                .await;
564        }
565
566        let target_branch = requested.clone().unwrap_or_else(|| "main".to_string());
567        let target_head_before = self.latest_branch_snapshot_id(&target_branch).await?;
568        let run = self
569            .begin_run(&target_branch, Some(operation.as_str()))
570            .await?;
571
572        let staged_result = match self
573            .execute_named_mutation_on_branch(
574                Some(run.run_branch.as_str()),
575                query_source,
576                query_name,
577                &resolved_params,
578            )
579            .await
580        {
581            Ok(result) => result,
582            Err(err) => {
583                let _ = self.fail_run(&run.run_id).await;
584                return Err(err);
585            }
586        };
587
588        let target_head_now = self.latest_branch_snapshot_id(&target_branch).await?;
589        if target_head_now.as_str() != target_head_before.as_str() {
590            let _ = self.fail_run(&run.run_id).await;
591            return Err(OmniError::manifest_conflict(format!(
592                "target branch '{}' advanced during transactional mutation; retry",
593                target_branch
594            )));
595        }
596
597        if let Err(err) = self.publish_run(&run.run_id).await {
598            let _ = self.fail_run(&run.run_id).await;
599            return Err(err);
600        }
601
602        Ok(staged_result)
603    }
604
605    async fn execute_named_mutation_on_branch(
606        &mut self,
607        branch: Option<&str>,
608        query_source: &str,
609        query_name: &str,
610        params: &ParamMap,
611    ) -> Result<MutationResult> {
612        let requested = match branch {
613            Some(branch) => Self::normalize_branch_name(branch)?,
614            None => None,
615        };
616        let current = self.active_branch().map(str::to_string);
617        if requested == current {
618            return self
619                .execute_named_mutation(query_source, query_name, params)
620                .await;
621        }
622
623        let previous = self
624            .swap_coordinator_for_branch(requested.as_deref())
625            .await?;
626        let result = self
627            .execute_named_mutation(query_source, query_name, params)
628            .await;
629        self.restore_coordinator(previous);
630        result
631    }
632
633    async fn execute_named_mutation(
634        &mut self,
635        query_source: &str,
636        query_name: &str,
637        params: &ParamMap,
638    ) -> Result<MutationResult> {
639        let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
640            .map_err(|e| OmniError::manifest(e.to_string()))?;
641
642        let checked = typecheck_query_decl(self.catalog(), &query_decl)?;
643        match checked {
644            CheckedQuery::Mutation(_) => {}
645            CheckedQuery::Read(_) => {
646                return Err(OmniError::manifest(
647                    "mutation execution called on a read query; use query instead".to_string(),
648                ));
649            }
650        }
651
652        let ir = lower_mutation_query(&query_decl)?;
653
654        let mut total = MutationResult::default();
655        for op in &ir.ops {
656            let result = match op {
657                MutationOpIR::Insert {
658                    type_name,
659                    assignments,
660                } => self.execute_insert(type_name, assignments, params).await?,
661                MutationOpIR::Update {
662                    type_name,
663                    assignments,
664                    predicate,
665                } => {
666                    self.execute_update(type_name, assignments, predicate, params)
667                        .await?
668                }
669                MutationOpIR::Delete {
670                    type_name,
671                    predicate,
672                } => self.execute_delete(type_name, predicate, params).await?,
673            };
674            total.affected_nodes += result.affected_nodes;
675            total.affected_edges += result.affected_edges;
676        }
677        Ok(total)
678    }
679
680    async fn execute_insert(
681        &mut self,
682        type_name: &str,
683        assignments: &[IRAssignment],
684        params: &ParamMap,
685    ) -> Result<MutationResult> {
686        let mut resolved: HashMap<String, Literal> = HashMap::new();
687        for a in assignments {
688            resolved.insert(a.property.clone(), resolve_expr_value(&a.value, params)?);
689        }
690
691        let is_node = self.catalog().node_types.contains_key(type_name);
692        let is_edge = self.catalog().edge_types.contains_key(type_name);
693
694        if is_node {
695            let node_type = &self.catalog().node_types[type_name];
696            let schema = node_type.arrow_schema.clone();
697            let blob_props = node_type.blob_properties.clone();
698            let id = if let Some(key_prop) = node_type.key_property() {
699                match resolved.get(key_prop) {
700                    Some(Literal::String(s)) => s.clone(),
701                    Some(other) => literal_to_sql(other).trim_matches('\'').to_string(),
702                    None => {
703                        return Err(OmniError::manifest(format!(
704                            "insert missing @key property '{}'",
705                            key_prop
706                        )));
707                    }
708                }
709            } else {
710                ulid::Ulid::new().to_string()
711            };
712
713            let batch = build_insert_batch(&schema, &id, &resolved, &blob_props)?;
714            crate::loader::validate_value_constraints(&batch, node_type)?;
715            let has_key = node_type.key_property().is_some();
716            let (state, table_branch) = if has_key {
717                self.upsert_batch(type_name, true, schema, batch).await?
718            } else {
719                self.append_batch(type_name, true, schema, batch).await?
720            };
721
722            let table_key = format!("node:{}", type_name);
723            self.commit_updates(&[crate::db::SubTableUpdate {
724                table_key,
725                table_version: state.version,
726                table_branch,
727                row_count: state.row_count,
728                version_metadata: state.version_metadata,
729            }])
730            .await?;
731
732            Ok(MutationResult {
733                affected_nodes: 1,
734                affected_edges: 0,
735            })
736        } else if is_edge {
737            let edge_type = &self.catalog().edge_types[type_name];
738            let schema = edge_type.arrow_schema.clone();
739            let blob_props = edge_type.blob_properties.clone();
740            let id = ulid::Ulid::new().to_string();
741
742            let batch = build_insert_batch(&schema, &id, &resolved, &blob_props)?;
743            validate_edge_insert_endpoints(self, type_name, &resolved).await?;
744            let (state, table_branch) = self.append_batch(type_name, false, schema, batch).await?;
745
746            let table_key = format!("edge:{}", type_name);
747            self.commit_updates(&[crate::db::SubTableUpdate {
748                table_key,
749                table_version: state.version,
750                table_branch,
751                row_count: state.row_count,
752                version_metadata: state.version_metadata,
753            }])
754            .await?;
755
756            self.invalidate_graph_index().await;
757
758            Ok(MutationResult {
759                affected_nodes: 0,
760                affected_edges: 1,
761            })
762        } else {
763            Err(OmniError::manifest(format!("unknown type '{}'", type_name)))
764        }
765    }
766
767    /// Append a batch to a sub-table, returning (new_version, row_count).
768    async fn append_batch(
769        &self,
770        type_name: &str,
771        is_node: bool,
772        _schema: SchemaRef,
773        batch: RecordBatch,
774    ) -> Result<(crate::table_store::TableState, Option<String>)> {
775        let table_key = if is_node {
776            format!("node:{}", type_name)
777        } else {
778            format!("edge:{}", type_name)
779        };
780        let (mut ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
781        let state = self
782            .table_store()
783            .append_batch(&full_path, &mut ds, batch)
784            .await?;
785        Ok((state, table_branch))
786    }
787
788    /// Upsert a batch into a sub-table using merge_insert keyed by "id".
789    /// Used for @key node types to enforce uniqueness.
790    async fn upsert_batch(
791        &self,
792        type_name: &str,
793        is_node: bool,
794        _schema: SchemaRef,
795        batch: RecordBatch,
796    ) -> Result<(crate::table_store::TableState, Option<String>)> {
797        let table_key = if is_node {
798            format!("node:{}", type_name)
799        } else {
800            format!("edge:{}", type_name)
801        };
802        let (ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
803        let state = self
804            .table_store()
805            .merge_insert_batch(
806                &full_path,
807                ds,
808                batch,
809                vec!["id".to_string()],
810                lance::dataset::WhenMatched::UpdateAll,
811                lance::dataset::WhenNotMatched::InsertAll,
812            )
813            .await?;
814        Ok((state, table_branch))
815    }
816
817    async fn execute_update(
818        &mut self,
819        type_name: &str,
820        assignments: &[IRAssignment],
821        predicate: &IRMutationPredicate,
822        params: &ParamMap,
823    ) -> Result<MutationResult> {
824        // Defense in depth: ensure this is a node type
825        if !self.catalog().node_types.contains_key(type_name) {
826            return Err(OmniError::manifest(format!(
827                "update is only supported for node types, not '{}'",
828                type_name
829            )));
830        }
831
832        // Reject updates to @key properties — identity is immutable
833        if let Some(key_prop) = self.catalog().node_types[type_name].key_property() {
834            if assignments.iter().any(|a| a.property == key_prop) {
835                return Err(OmniError::manifest(format!(
836                    "cannot update @key property '{}' — delete and re-insert instead",
837                    key_prop
838                )));
839            }
840        }
841
842        let pred_sql = predicate_to_sql(predicate, params, false)?;
843        let schema = self.catalog().node_types[type_name].arrow_schema.clone();
844        let blob_props = self.catalog().node_types[type_name].blob_properties.clone();
845
846        let table_key = format!("node:{}", type_name);
847        let (ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
848        let initial_version = ds.version().version;
849
850        let non_blob_cols: Vec<&str> = schema
851            .fields()
852            .iter()
853            .filter(|f| !blob_props.contains(f.name()))
854            .map(|f| f.name().as_str())
855            .collect();
856        let batches = self
857            .table_store()
858            .scan(
859                &ds,
860                (!blob_props.is_empty()).then_some(non_blob_cols.as_slice()),
861                Some(&pred_sql),
862                None,
863            )
864            .await?;
865
866        if batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0) {
867            return Ok(MutationResult {
868                affected_nodes: 0,
869                affected_edges: 0,
870            });
871        }
872
873        let matched = if batches.len() == 1 {
874            batches.into_iter().next().unwrap()
875        } else {
876            let s = batches[0].schema();
877            arrow_select::concat::concat_batches(&s, &batches)
878                .map_err(|e| OmniError::Lance(e.to_string()))?
879        };
880
881        let affected_count = matched.num_rows();
882
883        let mut resolved: HashMap<String, Literal> = HashMap::new();
884        for a in assignments {
885            resolved.insert(a.property.clone(), resolve_expr_value(&a.value, params)?);
886        }
887        let updated = apply_assignments(&schema, &matched, &resolved, &blob_props)?;
888        crate::loader::validate_value_constraints(&updated, &self.catalog().node_types[type_name])?;
889
890        // Re-open for merge_insert (scan consumed the dataset;
891        // version guard was already applied by open_for_mutation above)
892        let ds = self
893            .reopen_for_mutation(
894                &table_key,
895                &full_path,
896                table_branch.as_deref(),
897                initial_version,
898            )
899            .await?;
900        let update_state = self
901            .table_store()
902            .merge_insert_batch(
903                &full_path,
904                ds,
905                updated,
906                vec!["id".to_string()],
907                lance::dataset::WhenMatched::UpdateAll,
908                lance::dataset::WhenNotMatched::DoNothing,
909            )
910            .await?;
911
912        self.commit_updates(&[crate::db::SubTableUpdate {
913            table_key,
914            table_version: update_state.version,
915            table_branch,
916            row_count: update_state.row_count,
917            version_metadata: update_state.version_metadata,
918        }])
919        .await?;
920
921        Ok(MutationResult {
922            affected_nodes: affected_count,
923            affected_edges: 0,
924        })
925    }
926
927    async fn execute_delete(
928        &mut self,
929        type_name: &str,
930        predicate: &IRMutationPredicate,
931        params: &ParamMap,
932    ) -> Result<MutationResult> {
933        let is_node = self.catalog().node_types.contains_key(type_name);
934        if is_node {
935            self.execute_delete_node(type_name, predicate, params).await
936        } else {
937            self.execute_delete_edge(type_name, predicate, params).await
938        }
939    }
940
941    async fn execute_delete_node(
942        &mut self,
943        type_name: &str,
944        predicate: &IRMutationPredicate,
945        params: &ParamMap,
946    ) -> Result<MutationResult> {
947        let pred_sql = predicate_to_sql(predicate, params, false)?;
948
949        let table_key = format!("node:{}", type_name);
950        let (ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
951        let initial_version = ds.version().version;
952
953        // Scan matching IDs for cascade
954        let batches = self
955            .table_store()
956            .scan(&ds, Some(&["id"]), Some(&pred_sql), None)
957            .await?;
958
959        let deleted_ids: Vec<String> = batches
960            .iter()
961            .flat_map(|batch| {
962                let ids = batch
963                    .column(0)
964                    .as_any()
965                    .downcast_ref::<StringArray>()
966                    .unwrap();
967                (0..ids.len())
968                    .map(|i| ids.value(i).to_string())
969                    .collect::<Vec<_>>()
970            })
971            .collect();
972
973        if deleted_ids.is_empty() {
974            return Ok(MutationResult {
975                affected_nodes: 0,
976                affected_edges: 0,
977            });
978        }
979
980        let affected_nodes = deleted_ids.len();
981
982        // Delete nodes (re-open needed because the scan consumed the dataset;
983        // version guard was already applied by open_for_mutation above)
984        let mut ds = self
985            .reopen_for_mutation(
986                &table_key,
987                &full_path,
988                table_branch.as_deref(),
989                initial_version,
990            )
991            .await?;
992        let delete_state = self
993            .table_store()
994            .delete_where(&full_path, &mut ds, &pred_sql)
995            .await?;
996
997        let mut updates = vec![crate::db::SubTableUpdate {
998            table_key,
999            table_version: delete_state.version,
1000            table_branch: table_branch.clone(),
1001            row_count: delete_state.row_count,
1002            version_metadata: delete_state.version_metadata,
1003        }];
1004
1005        let mut affected_edges = 0usize;
1006        let escaped: Vec<String> = deleted_ids
1007            .iter()
1008            .map(|id| format!("'{}'", id.replace('\'', "''")))
1009            .collect();
1010        let id_list = escaped.join(", ");
1011
1012        let edge_info: Vec<(String, String, String)> = self
1013            .catalog()
1014            .edge_types
1015            .iter()
1016            .map(|(name, et)| (name.clone(), et.from_type.clone(), et.to_type.clone()))
1017            .collect();
1018
1019        for (edge_name, from_type, to_type) in &edge_info {
1020            let mut cascade_filters = Vec::new();
1021            if from_type == type_name {
1022                cascade_filters.push(format!("src IN ({})", id_list));
1023            }
1024            if to_type == type_name {
1025                cascade_filters.push(format!("dst IN ({})", id_list));
1026            }
1027            if cascade_filters.is_empty() {
1028                continue;
1029            }
1030
1031            let edge_table_key = format!("edge:{}", edge_name);
1032            let cascade_filter = cascade_filters.join(" OR ");
1033            let (mut edge_ds, edge_full_path, edge_table_branch) =
1034                self.open_for_mutation(&edge_table_key).await?;
1035
1036            let edge_delete = self
1037                .table_store()
1038                .delete_where(&edge_full_path, &mut edge_ds, &cascade_filter)
1039                .await?;
1040
1041            affected_edges += edge_delete.deleted_rows;
1042
1043            if edge_delete.deleted_rows > 0 {
1044                updates.push(crate::db::SubTableUpdate {
1045                    table_key: edge_table_key,
1046                    table_version: edge_delete.version,
1047                    table_branch: edge_table_branch,
1048                    row_count: edge_delete.row_count,
1049                    version_metadata: edge_delete.version_metadata,
1050                });
1051            }
1052        }
1053
1054        self.commit_updates(&updates).await?;
1055
1056        if affected_edges > 0 {
1057            self.invalidate_graph_index().await;
1058        }
1059
1060        Ok(MutationResult {
1061            affected_nodes,
1062            affected_edges,
1063        })
1064    }
1065
1066    async fn execute_delete_edge(
1067        &mut self,
1068        type_name: &str,
1069        predicate: &IRMutationPredicate,
1070        params: &ParamMap,
1071    ) -> Result<MutationResult> {
1072        let pred_sql = predicate_to_sql(predicate, params, true)?;
1073
1074        let table_key = format!("edge:{}", type_name);
1075        let (mut ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
1076
1077        let delete_state = self
1078            .table_store()
1079            .delete_where(&full_path, &mut ds, &pred_sql)
1080            .await?;
1081        let affected = delete_state.deleted_rows;
1082
1083        if affected > 0 {
1084            self.commit_updates(&[crate::db::SubTableUpdate {
1085                table_key,
1086                table_version: delete_state.version,
1087                table_branch,
1088                row_count: delete_state.row_count,
1089                version_metadata: delete_state.version_metadata,
1090            }])
1091            .await?;
1092
1093            self.invalidate_graph_index().await;
1094        }
1095
1096        Ok(MutationResult {
1097            affected_nodes: 0,
1098            affected_edges: affected,
1099        })
1100    }
1101}
1102
1103fn enrich_mutation_params(params: &ParamMap) -> Result<ParamMap> {
1104    let mut resolved = params.clone();
1105    if !resolved.contains_key(NOW_PARAM_NAME) {
1106        let now = OffsetDateTime::now_utc()
1107            .format(&Rfc3339)
1108            .map_err(|e| OmniError::manifest(format!("failed to format now(): {}", e)))?;
1109        resolved.insert(NOW_PARAM_NAME.to_string(), Literal::DateTime(now));
1110    }
1111    Ok(resolved)
1112}