1use super::*;
2
3use super::query::literal_to_sql;
4
5fn resolve_expr_value(expr: &IRExpr, params: &ParamMap) -> Result<Literal> {
9 match expr {
10 IRExpr::Literal(lit) => Ok(lit.clone()),
11 IRExpr::Param(name) => params
12 .get(name)
13 .cloned()
14 .ok_or_else(|| OmniError::manifest(format!("parameter '{}' not provided", name))),
15 other => Err(OmniError::manifest(format!(
16 "unsupported expression in mutation: {:?}",
17 other
18 ))),
19 }
20}
21
22fn literal_to_typed_array(
24 lit: &Literal,
25 data_type: &DataType,
26 num_rows: usize,
27) -> Result<ArrayRef> {
28 Ok(match (lit, data_type) {
29 (Literal::Null, _) => arrow_array::new_null_array(data_type, num_rows),
30 (Literal::String(s), DataType::Utf8) => {
31 Arc::new(StringArray::from(vec![s.as_str(); num_rows])) as ArrayRef
32 }
33 (Literal::Integer(n), DataType::Int32) => {
34 Arc::new(Int32Array::from(vec![*n as i32; num_rows]))
35 }
36 (Literal::Integer(n), DataType::Int64) => Arc::new(Int64Array::from(vec![*n; num_rows])),
37 (Literal::Integer(n), DataType::UInt32) => {
38 Arc::new(UInt32Array::from(vec![*n as u32; num_rows]))
39 }
40 (Literal::Integer(n), DataType::UInt64) => {
41 Arc::new(UInt64Array::from(vec![*n as u64; num_rows]))
42 }
43 (Literal::Float(f), DataType::Float32) => {
44 Arc::new(Float32Array::from(vec![*f as f32; num_rows]))
45 }
46 (Literal::Float(f), DataType::Float64) => Arc::new(Float64Array::from(vec![*f; num_rows])),
47 (Literal::Bool(b), DataType::Boolean) => Arc::new(BooleanArray::from(vec![*b; num_rows])),
48 (Literal::Date(s), DataType::Date32) => {
49 let days = crate::loader::parse_date32_literal(s)?;
50 Arc::new(Date32Array::from(vec![days; num_rows]))
51 }
52 (Literal::DateTime(s), DataType::Date64) => Arc::new(Date64Array::from(vec![
53 crate::loader::parse_date64_literal(s)?;
54 num_rows
55 ])),
56 (Literal::List(items), DataType::List(field)) => {
57 typed_list_literal_to_array(items, field.data_type(), num_rows)?
58 }
59 (Literal::List(items), DataType::FixedSizeList(field, dim))
60 if field.data_type() == &DataType::Float32 =>
61 {
62 if items.len() != *dim as usize {
63 return Err(OmniError::manifest(format!(
64 "vector property expects {} dimensions, got {}",
65 dim,
66 items.len()
67 )));
68 }
69 let mut builder = FixedSizeListBuilder::with_capacity(
70 Float32Builder::with_capacity(num_rows * (*dim as usize)),
71 *dim,
72 num_rows,
73 )
74 .with_field(field.clone());
75 for _ in 0..num_rows {
76 for item in items {
77 match item {
78 Literal::Integer(value) => builder.values().append_value(*value as f32),
79 Literal::Float(value) => builder.values().append_value(*value as f32),
80 _ => {
81 return Err(OmniError::manifest(
82 "vector elements must be numeric".to_string(),
83 ));
84 }
85 }
86 }
87 builder.append(true);
88 }
89 Arc::new(builder.finish())
90 }
91 _ => {
92 return Err(OmniError::manifest(format!(
93 "cannot convert {:?} to {:?}",
94 lit, data_type
95 )));
96 }
97 })
98}
99
100fn typed_list_literal_to_array(
101 items: &[Literal],
102 item_type: &DataType,
103 num_rows: usize,
104) -> Result<ArrayRef> {
105 match item_type {
106 DataType::Utf8 => {
107 let mut builder = ListBuilder::new(StringBuilder::new());
108 for _ in 0..num_rows {
109 for item in items {
110 match item {
111 Literal::String(value) => builder.values().append_value(value),
112 _ => builder.values().append_null(),
113 }
114 }
115 builder.append(true);
116 }
117 Ok(Arc::new(builder.finish()))
118 }
119 DataType::Boolean => {
120 let mut builder = ListBuilder::new(BooleanBuilder::new());
121 for _ in 0..num_rows {
122 for item in items {
123 match item {
124 Literal::Bool(value) => builder.values().append_value(*value),
125 _ => builder.values().append_null(),
126 }
127 }
128 builder.append(true);
129 }
130 Ok(Arc::new(builder.finish()))
131 }
132 DataType::Int32 => {
133 let mut builder = ListBuilder::new(Int32Builder::new());
134 for _ in 0..num_rows {
135 for item in items {
136 match item {
137 Literal::Integer(value) => {
138 let value = i32::try_from(*value).map_err(|_| {
139 OmniError::manifest(format!(
140 "list value {} exceeds Int32 range",
141 value
142 ))
143 })?;
144 builder.values().append_value(value);
145 }
146 _ => builder.values().append_null(),
147 }
148 }
149 builder.append(true);
150 }
151 Ok(Arc::new(builder.finish()))
152 }
153 DataType::Int64 => {
154 let mut builder = ListBuilder::new(Int64Builder::new());
155 for _ in 0..num_rows {
156 for item in items {
157 match item {
158 Literal::Integer(value) => builder.values().append_value(*value),
159 _ => builder.values().append_null(),
160 }
161 }
162 builder.append(true);
163 }
164 Ok(Arc::new(builder.finish()))
165 }
166 DataType::UInt32 => {
167 let mut builder = ListBuilder::new(UInt32Builder::new());
168 for _ in 0..num_rows {
169 for item in items {
170 match item {
171 Literal::Integer(value) => {
172 let value = u32::try_from(*value).map_err(|_| {
173 OmniError::manifest(format!(
174 "list value {} exceeds UInt32 range",
175 value
176 ))
177 })?;
178 builder.values().append_value(value);
179 }
180 _ => builder.values().append_null(),
181 }
182 }
183 builder.append(true);
184 }
185 Ok(Arc::new(builder.finish()))
186 }
187 DataType::UInt64 => {
188 let mut builder = ListBuilder::new(UInt64Builder::new());
189 for _ in 0..num_rows {
190 for item in items {
191 match item {
192 Literal::Integer(value) => {
193 let value = u64::try_from(*value).map_err(|_| {
194 OmniError::manifest(format!(
195 "list value {} exceeds UInt64 range",
196 value
197 ))
198 })?;
199 builder.values().append_value(value);
200 }
201 _ => builder.values().append_null(),
202 }
203 }
204 builder.append(true);
205 }
206 Ok(Arc::new(builder.finish()))
207 }
208 DataType::Float32 => {
209 let mut builder = ListBuilder::new(Float32Builder::new());
210 for _ in 0..num_rows {
211 for item in items {
212 match item {
213 Literal::Integer(value) => builder.values().append_value(*value as f32),
214 Literal::Float(value) => builder.values().append_value(*value as f32),
215 _ => builder.values().append_null(),
216 }
217 }
218 builder.append(true);
219 }
220 Ok(Arc::new(builder.finish()))
221 }
222 DataType::Float64 => {
223 let mut builder = ListBuilder::new(Float64Builder::new());
224 for _ in 0..num_rows {
225 for item in items {
226 match item {
227 Literal::Integer(value) => builder.values().append_value(*value as f64),
228 Literal::Float(value) => builder.values().append_value(*value),
229 _ => builder.values().append_null(),
230 }
231 }
232 builder.append(true);
233 }
234 Ok(Arc::new(builder.finish()))
235 }
236 DataType::Date32 => {
237 let mut builder = ListBuilder::new(Date32Builder::new());
238 for _ in 0..num_rows {
239 for item in items {
240 match item {
241 Literal::Date(value) => builder
242 .values()
243 .append_value(crate::loader::parse_date32_literal(value)?),
244 _ => builder.values().append_null(),
245 }
246 }
247 builder.append(true);
248 }
249 Ok(Arc::new(builder.finish()))
250 }
251 DataType::Date64 => {
252 let mut builder = ListBuilder::new(Date64Builder::new());
253 for _ in 0..num_rows {
254 for item in items {
255 match item {
256 Literal::DateTime(value) => builder
257 .values()
258 .append_value(crate::loader::parse_date64_literal(value)?),
259 _ => builder.values().append_null(),
260 }
261 }
262 builder.append(true);
263 }
264 Ok(Arc::new(builder.finish()))
265 }
266 other => Err(OmniError::manifest(format!(
267 "cannot convert list literal to {:?}",
268 other
269 ))),
270 }
271}
272
273fn build_blob_array_from_value(value: &str) -> Result<ArrayRef> {
275 let mut builder = BlobArrayBuilder::new(1);
276 crate::loader::append_blob_value(&mut builder, value)?;
277 builder
278 .finish()
279 .map_err(|e| OmniError::Lance(e.to_string()))
280}
281
282fn build_null_blob_array() -> Result<ArrayRef> {
284 let mut builder = BlobArrayBuilder::new(1);
285 builder
286 .push_null()
287 .map_err(|e| OmniError::Lance(e.to_string()))?;
288 builder
289 .finish()
290 .map_err(|e| OmniError::Lance(e.to_string()))
291}
292
293fn build_insert_batch(
295 schema: &SchemaRef,
296 id: &str,
297 assignments: &HashMap<String, Literal>,
298 blob_properties: &HashSet<String>,
299) -> Result<RecordBatch> {
300 let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());
301
302 for field in schema.fields() {
303 if field.name() == "id" {
304 columns.push(Arc::new(StringArray::from(vec![id])));
305 } else if blob_properties.contains(field.name()) {
306 if let Some(Literal::String(uri)) = assignments.get(field.name()) {
307 columns.push(build_blob_array_from_value(uri)?);
308 } else if field.is_nullable() {
309 columns.push(build_null_blob_array()?);
310 } else {
311 return Err(OmniError::manifest(format!(
312 "missing required blob property '{}'",
313 field.name()
314 )));
315 }
316 } else if field.name() == "src" {
317 let lit = assignments.get("from").ok_or_else(|| {
318 OmniError::manifest("missing required edge endpoint 'from'".to_string())
319 })?;
320 columns.push(literal_to_typed_array(lit, field.data_type(), 1)?);
321 } else if field.name() == "dst" {
322 let lit = assignments.get("to").ok_or_else(|| {
323 OmniError::manifest("missing required edge endpoint 'to'".to_string())
324 })?;
325 columns.push(literal_to_typed_array(lit, field.data_type(), 1)?);
326 } else if let Some(lit) = assignments.get(field.name()) {
327 columns.push(literal_to_typed_array(lit, field.data_type(), 1)?);
328 } else if field.is_nullable() {
329 columns.push(arrow_array::new_null_array(field.data_type(), 1));
330 } else {
331 return Err(OmniError::manifest(format!(
332 "missing required property '{}'",
333 field.name()
334 )));
335 }
336 }
337
338 RecordBatch::try_new(schema.clone(), columns).map_err(|e| OmniError::Lance(e.to_string()))
339}
340
341async fn validate_edge_insert_endpoints(
342 db: &Omnigraph,
343 edge_name: &str,
344 assignments: &HashMap<String, Literal>,
345) -> Result<()> {
346 let edge_type = db
347 .catalog()
348 .edge_types
349 .get(edge_name)
350 .ok_or_else(|| OmniError::manifest(format!("unknown edge type '{}'", edge_name)))?;
351 let from = match assignments.get("from") {
352 Some(Literal::String(value)) => value.as_str(),
353 Some(other) => {
354 return Err(OmniError::manifest(format!(
355 "edge {} from endpoint must be a string id, got {}",
356 edge_name,
357 literal_to_sql(other)
358 )));
359 }
360 None => {
361 return Err(OmniError::manifest(format!(
362 "edge {} missing 'from' endpoint",
363 edge_name
364 )));
365 }
366 };
367 let to = match assignments.get("to") {
368 Some(Literal::String(value)) => value.as_str(),
369 Some(other) => {
370 return Err(OmniError::manifest(format!(
371 "edge {} to endpoint must be a string id, got {}",
372 edge_name,
373 literal_to_sql(other)
374 )));
375 }
376 None => {
377 return Err(OmniError::manifest(format!(
378 "edge {} missing 'to' endpoint",
379 edge_name
380 )));
381 }
382 };
383
384 ensure_node_id_exists(db, &edge_type.from_type, from, "src").await?;
385 ensure_node_id_exists(db, &edge_type.to_type, to, "dst").await?;
386 Ok(())
387}
388
389async fn ensure_node_id_exists(
390 db: &Omnigraph,
391 node_type: &str,
392 id: &str,
393 label: &str,
394) -> Result<()> {
395 let snapshot = db.snapshot();
396 let table_key = format!("node:{}", node_type);
397 let ds = snapshot.open(&table_key).await?;
398 let filter = format!("id = '{}'", id.replace('\'', "''"));
399 let exists = ds
400 .count_rows(Some(filter))
401 .await
402 .map_err(|e| OmniError::Lance(e.to_string()))?
403 > 0;
404 if exists {
405 Ok(())
406 } else {
407 Err(OmniError::manifest(format!(
408 "{} '{}' not found in {}",
409 label, id, node_type
410 )))
411 }
412}
413
414fn predicate_to_sql(
416 predicate: &IRMutationPredicate,
417 params: &ParamMap,
418 is_edge: bool,
419) -> Result<String> {
420 let column = if is_edge {
421 match predicate.property.as_str() {
422 "from" => "src".to_string(),
423 "to" => "dst".to_string(),
424 other => other.to_string(),
425 }
426 } else {
427 predicate.property.clone()
428 };
429
430 let value = resolve_expr_value(&predicate.value, params)?;
431 let value_sql = literal_to_sql(&value);
432
433 let op = match predicate.op {
434 CompOp::Eq => "=",
435 CompOp::Ne => "!=",
436 CompOp::Gt => ">",
437 CompOp::Lt => "<",
438 CompOp::Ge => ">=",
439 CompOp::Le => "<=",
440 CompOp::Contains => {
441 return Err(OmniError::manifest(
442 "contains predicate not supported in mutations".to_string(),
443 ));
444 }
445 };
446
447 Ok(format!("{} {} {}", column, op, value_sql))
448}
449
/// Rebuilds a scanned batch with the given assignments applied to every row.
///
/// Columns without an assignment are copied through from `batch` unchanged;
/// assigned columns are replaced by `num_rows` copies of the literal value.
///
/// Blob columns are special-cased: a string assignment re-encodes the blob
/// for every row. A blob field with no assignment (or a non-string
/// assignment) is omitted from the output batch entirely — presumably so the
/// subsequent merge-insert leaves the stored blob data untouched.
/// NOTE(review): this also means a non-string blob assignment (e.g. null) is
/// silently ignored rather than rejected — confirm that is intentional.
fn apply_assignments(
    full_schema: &SchemaRef,
    batch: &RecordBatch,
    assignments: &HashMap<String, Literal>,
    blob_properties: &HashSet<String>,
) -> Result<RecordBatch> {
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(full_schema.fields().len());
    let mut out_fields: Vec<Field> = Vec::with_capacity(full_schema.fields().len());

    for field in full_schema.fields().iter() {
        if blob_properties.contains(field.name()) {
            if let Some(Literal::String(uri)) = assignments.get(field.name()) {
                // One blob entry per matched row, all pointing at the same URI.
                let mut builder = BlobArrayBuilder::new(batch.num_rows());
                for _ in 0..batch.num_rows() {
                    crate::loader::append_blob_value(&mut builder, uri)?;
                }
                let blob_field = lance::blob::blob_field(field.name(), true);
                out_fields.push(blob_field);
                columns.push(
                    builder
                        .finish()
                        .map_err(|e| OmniError::Lance(e.to_string()))?,
                );
            }
        } else if let Some(lit) = assignments.get(field.name()) {
            // Assigned column: broadcast the literal across all rows.
            out_fields.push(field.as_ref().clone());
            columns.push(literal_to_typed_array(
                lit,
                field.data_type(),
                batch.num_rows(),
            )?);
        } else {
            // Unassigned column: carry the scanned values through verbatim.
            let col = batch.column_by_name(field.name()).ok_or_else(|| {
                OmniError::Lance(format!(
                    "column '{}' not found in scan result",
                    field.name()
                ))
            })?;
            out_fields.push(field.as_ref().clone());
            columns.push(col.clone());
        }
    }

    RecordBatch::try_new(Arc::new(Schema::new(out_fields)), columns)
        .map_err(|e| OmniError::Lance(e.to_string()))
}
507
impl Omnigraph {
    /// Runs the named mutation from `query_source` against `branch`, keeping
    /// whatever audit actor is currently configured on the database.
    ///
    /// Thin wrapper over `mutate_as` with no actor override.
    pub async fn mutate(
        &mut self,
        branch: &str,
        query_source: &str,
        query_name: &str,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        self.mutate_as(branch, query_source, query_name, params, None)
            .await
    }

    /// Runs the named mutation while temporarily attributing it to
    /// `actor_id` for auditing.
    ///
    /// The previous audit actor is restored after the mutation completes,
    /// whether it succeeded or failed, so the override never leaks.
    pub async fn mutate_as(
        &mut self,
        branch: &str,
        query_source: &str,
        query_name: &str,
        params: &ParamMap,
        actor_id: Option<&str>,
    ) -> Result<MutationResult> {
        let previous_actor = self.audit_actor_id.clone();
        self.audit_actor_id = actor_id.map(str::to_string);
        let result = self
            .mutate_with_current_actor(branch, query_source, query_name, params)
            .await;
        // Restore unconditionally before propagating the result.
        self.audit_actor_id = previous_actor;
        result
    }

    /// Transactional driver for a named mutation.
    ///
    /// Unless the requested branch is already an internal run branch (in
    /// which case the mutation executes directly as part of the enclosing
    /// run), the mutation is staged on a fresh run branch, checked for
    /// concurrent head movement on the target branch, and then published.
    /// On any failure the run is marked failed (best-effort) and the error
    /// is returned.
    async fn mutate_with_current_actor(
        &mut self,
        branch: &str,
        query_source: &str,
        query_name: &str,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        self.ensure_schema_state_valid().await?;
        let requested = Self::normalize_branch_name(branch)?;
        // Inject implicit parameters (e.g. the now() timestamp) up front so
        // every operation in the mutation sees the same values.
        let resolved_params = enrich_mutation_params(params)?;
        let operation = format!(
            "mutation:{}:branch={}",
            query_name,
            requested.as_deref().unwrap_or("main")
        );

        // Already inside a run branch: execute directly, never nest runs.
        if requested.as_deref().is_some_and(is_internal_run_branch) {
            return self
                .execute_named_mutation_on_branch(
                    requested.as_deref(),
                    query_source,
                    query_name,
                    &resolved_params,
                )
                .await;
        }

        let target_branch = requested.clone().unwrap_or_else(|| "main".to_string());
        // Snapshot the target head so concurrent writers can be detected
        // before publish (optimistic concurrency control).
        let target_head_before = self.latest_branch_snapshot_id(&target_branch).await?;
        let run = self
            .begin_run(&target_branch, Some(operation.as_str()))
            .await?;

        let staged_result = match self
            .execute_named_mutation_on_branch(
                Some(run.run_branch.as_str()),
                query_source,
                query_name,
                &resolved_params,
            )
            .await
        {
            Ok(result) => result,
            Err(err) => {
                // Best-effort cleanup; the original error takes precedence
                // over any failure to mark the run failed.
                let _ = self.fail_run(&run.run_id).await;
                return Err(err);
            }
        };

        // Refuse to publish over a head that moved while we were staging.
        let target_head_now = self.latest_branch_snapshot_id(&target_branch).await?;
        if target_head_now.as_str() != target_head_before.as_str() {
            let _ = self.fail_run(&run.run_id).await;
            return Err(OmniError::manifest_conflict(format!(
                "target branch '{}' advanced during transactional mutation; retry",
                target_branch
            )));
        }

        if let Err(err) = self.publish_run(&run.run_id).await {
            let _ = self.fail_run(&run.run_id).await;
            return Err(err);
        }

        Ok(staged_result)
    }

    /// Executes the named mutation with the coordinator pointed at `branch`.
    ///
    /// If `branch` already matches the active branch the mutation runs
    /// in place; otherwise the coordinator is swapped for the duration and
    /// restored afterwards regardless of the outcome.
    async fn execute_named_mutation_on_branch(
        &mut self,
        branch: Option<&str>,
        query_source: &str,
        query_name: &str,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        let requested = match branch {
            Some(branch) => Self::normalize_branch_name(branch)?,
            None => None,
        };
        let current = self.active_branch().map(str::to_string);
        if requested == current {
            return self
                .execute_named_mutation(query_source, query_name, params)
                .await;
        }

        let previous = self
            .swap_coordinator_for_branch(requested.as_deref())
            .await?;
        let result = self
            .execute_named_mutation(query_source, query_name, params)
            .await;
        // Always restore the coordinator, even when the mutation failed.
        self.restore_coordinator(previous);
        result
    }

    /// Compiles, typechecks, lowers, and executes the named mutation,
    /// dispatching each lowered op to its insert/update/delete executor and
    /// summing the affected-row counts.
    async fn execute_named_mutation(
        &mut self,
        query_source: &str,
        query_name: &str,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        let query_decl = omnigraph_compiler::find_named_query(query_source, query_name)
            .map_err(|e| OmniError::manifest(e.to_string()))?;

        // Reject read queries early: they must go through the query path.
        let checked = typecheck_query_decl(self.catalog(), &query_decl)?;
        match checked {
            CheckedQuery::Mutation(_) => {}
            CheckedQuery::Read(_) => {
                return Err(OmniError::manifest(
                    "mutation execution called on a read query; use query instead".to_string(),
                ));
            }
        }

        let ir = lower_mutation_query(&query_decl)?;

        let mut total = MutationResult::default();
        for op in &ir.ops {
            let result = match op {
                MutationOpIR::Insert {
                    type_name,
                    assignments,
                } => self.execute_insert(type_name, assignments, params).await?,
                MutationOpIR::Update {
                    type_name,
                    assignments,
                    predicate,
                } => {
                    self.execute_update(type_name, assignments, predicate, params)
                        .await?
                }
                MutationOpIR::Delete {
                    type_name,
                    predicate,
                } => self.execute_delete(type_name, predicate, params).await?,
            };
            total.affected_nodes += result.affected_nodes;
            total.affected_edges += result.affected_edges;
        }
        Ok(total)
    }

    /// Inserts one node or edge row.
    ///
    /// Nodes with an `@key` property derive their id from that property and
    /// go through an upsert (merge-insert); keyless nodes and all edges get
    /// a fresh ULID id and are appended. Edge inserts additionally validate
    /// that both endpoints exist and invalidate the in-memory graph index.
    async fn execute_insert(
        &mut self,
        type_name: &str,
        assignments: &[IRAssignment],
        params: &ParamMap,
    ) -> Result<MutationResult> {
        let mut resolved: HashMap<String, Literal> = HashMap::new();
        for a in assignments {
            resolved.insert(a.property.clone(), resolve_expr_value(&a.value, params)?);
        }

        let is_node = self.catalog().node_types.contains_key(type_name);
        let is_edge = self.catalog().edge_types.contains_key(type_name);

        if is_node {
            let node_type = &self.catalog().node_types[type_name];
            let schema = node_type.arrow_schema.clone();
            let blob_props = node_type.blob_properties.clone();
            // Keyed types derive the row id from the @key property value;
            // the literal_to_sql/trim_matches path strips SQL quoting from
            // non-string key literals.
            let id = if let Some(key_prop) = node_type.key_property() {
                match resolved.get(key_prop) {
                    Some(Literal::String(s)) => s.clone(),
                    Some(other) => literal_to_sql(other).trim_matches('\'').to_string(),
                    None => {
                        return Err(OmniError::manifest(format!(
                            "insert missing @key property '{}'",
                            key_prop
                        )));
                    }
                }
            } else {
                ulid::Ulid::new().to_string()
            };

            let batch = build_insert_batch(&schema, &id, &resolved, &blob_props)?;
            crate::loader::validate_value_constraints(&batch, node_type)?;
            // Keyed inserts are idempotent upserts; keyless inserts append.
            let has_key = node_type.key_property().is_some();
            let (state, table_branch) = if has_key {
                self.upsert_batch(type_name, true, schema, batch).await?
            } else {
                self.append_batch(type_name, true, schema, batch).await?
            };

            let table_key = format!("node:{}", type_name);
            self.commit_updates(&[crate::db::SubTableUpdate {
                table_key,
                table_version: state.version,
                table_branch,
                row_count: state.row_count,
                version_metadata: state.version_metadata,
            }])
            .await?;

            Ok(MutationResult {
                affected_nodes: 1,
                affected_edges: 0,
            })
        } else if is_edge {
            let edge_type = &self.catalog().edge_types[type_name];
            let schema = edge_type.arrow_schema.clone();
            let blob_props = edge_type.blob_properties.clone();
            let id = ulid::Ulid::new().to_string();

            let batch = build_insert_batch(&schema, &id, &resolved, &blob_props)?;
            // Referential integrity: both endpoints must already exist.
            validate_edge_insert_endpoints(self, type_name, &resolved).await?;
            let (state, table_branch) = self.append_batch(type_name, false, schema, batch).await?;

            let table_key = format!("edge:{}", type_name);
            self.commit_updates(&[crate::db::SubTableUpdate {
                table_key,
                table_version: state.version,
                table_branch,
                row_count: state.row_count,
                version_metadata: state.version_metadata,
            }])
            .await?;

            // Topology changed: drop the cached adjacency index.
            self.invalidate_graph_index().await;

            Ok(MutationResult {
                affected_nodes: 0,
                affected_edges: 1,
            })
        } else {
            Err(OmniError::manifest(format!("unknown type '{}'", type_name)))
        }
    }

    /// Appends `batch` to the node/edge table for `type_name`, returning the
    /// resulting table state and the branch it was written on.
    async fn append_batch(
        &self,
        type_name: &str,
        is_node: bool,
        _schema: SchemaRef,
        batch: RecordBatch,
    ) -> Result<(crate::table_store::TableState, Option<String>)> {
        let table_key = if is_node {
            format!("node:{}", type_name)
        } else {
            format!("edge:{}", type_name)
        };
        let (mut ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
        let state = self
            .table_store()
            .append_batch(&full_path, &mut ds, batch)
            .await?;
        Ok((state, table_branch))
    }

    /// Merge-inserts `batch` keyed on `id`: matching rows are fully updated,
    /// non-matching rows are inserted. Used for keyed (idempotent) inserts.
    async fn upsert_batch(
        &self,
        type_name: &str,
        is_node: bool,
        _schema: SchemaRef,
        batch: RecordBatch,
    ) -> Result<(crate::table_store::TableState, Option<String>)> {
        let table_key = if is_node {
            format!("node:{}", type_name)
        } else {
            format!("edge:{}", type_name)
        };
        let (ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
        let state = self
            .table_store()
            .merge_insert_batch(
                &full_path,
                ds,
                batch,
                vec!["id".to_string()],
                lance::dataset::WhenMatched::UpdateAll,
                lance::dataset::WhenNotMatched::InsertAll,
            )
            .await?;
        Ok((state, table_branch))
    }

    /// Updates all node rows matching `predicate` with the given assignments.
    ///
    /// Scans the matching rows (excluding blob columns when any exist),
    /// rewrites them via `apply_assignments`, then merge-inserts the result
    /// back keyed on `id`. The dataset is reopened pinned at the version it
    /// was first read from so the scan and write see consistent data.
    /// `@key` properties cannot be updated.
    async fn execute_update(
        &mut self,
        type_name: &str,
        assignments: &[IRAssignment],
        predicate: &IRMutationPredicate,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        if !self.catalog().node_types.contains_key(type_name) {
            return Err(OmniError::manifest(format!(
                "update is only supported for node types, not '{}'",
                type_name
            )));
        }

        // The key property doubles as the row id, so changing it would
        // silently re-identify the row; force delete + re-insert instead.
        if let Some(key_prop) = self.catalog().node_types[type_name].key_property() {
            if assignments.iter().any(|a| a.property == key_prop) {
                return Err(OmniError::manifest(format!(
                    "cannot update @key property '{}' — delete and re-insert instead",
                    key_prop
                )));
            }
        }

        let pred_sql = predicate_to_sql(predicate, params, false)?;
        let schema = self.catalog().node_types[type_name].arrow_schema.clone();
        let blob_props = self.catalog().node_types[type_name].blob_properties.clone();

        let table_key = format!("node:{}", type_name);
        let (ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
        let initial_version = ds.version().version;

        // Project blob columns out of the scan: blob data is only rewritten
        // when explicitly assigned (see apply_assignments).
        let non_blob_cols: Vec<&str> = schema
            .fields()
            .iter()
            .filter(|f| !blob_props.contains(f.name()))
            .map(|f| f.name().as_str())
            .collect();
        let batches = self
            .table_store()
            .scan(
                &ds,
                (!blob_props.is_empty()).then_some(non_blob_cols.as_slice()),
                Some(&pred_sql),
                None,
            )
            .await?;

        if batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0) {
            return Ok(MutationResult {
                affected_nodes: 0,
                affected_edges: 0,
            });
        }

        // Collapse the scan result into a single batch for rewriting.
        let matched = if batches.len() == 1 {
            batches.into_iter().next().unwrap()
        } else {
            let s = batches[0].schema();
            arrow_select::concat::concat_batches(&s, &batches)
                .map_err(|e| OmniError::Lance(e.to_string()))?
        };

        let affected_count = matched.num_rows();

        let mut resolved: HashMap<String, Literal> = HashMap::new();
        for a in assignments {
            resolved.insert(a.property.clone(), resolve_expr_value(&a.value, params)?);
        }
        let updated = apply_assignments(&schema, &matched, &resolved, &blob_props)?;
        crate::loader::validate_value_constraints(&updated, &self.catalog().node_types[type_name])?;

        // Reopen pinned at the version we scanned so the write cannot race a
        // concurrent change between scan and merge.
        let ds = self
            .reopen_for_mutation(
                &table_key,
                &full_path,
                table_branch.as_deref(),
                initial_version,
            )
            .await?;
        let update_state = self
            .table_store()
            .merge_insert_batch(
                &full_path,
                ds,
                updated,
                vec!["id".to_string()],
                lance::dataset::WhenMatched::UpdateAll,
                lance::dataset::WhenNotMatched::DoNothing,
            )
            .await?;

        self.commit_updates(&[crate::db::SubTableUpdate {
            table_key,
            table_version: update_state.version,
            table_branch,
            row_count: update_state.row_count,
            version_metadata: update_state.version_metadata,
        }])
        .await?;

        Ok(MutationResult {
            affected_nodes: affected_count,
            affected_edges: 0,
        })
    }

    /// Dispatches a delete to the node or edge executor based on whether
    /// `type_name` names a node type in the catalog.
    async fn execute_delete(
        &mut self,
        type_name: &str,
        predicate: &IRMutationPredicate,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        let is_node = self.catalog().node_types.contains_key(type_name);
        if is_node {
            self.execute_delete_node(type_name, predicate, params).await
        } else {
            self.execute_delete_edge(type_name, predicate, params).await
        }
    }

    /// Deletes node rows matching `predicate` and cascades the delete to any
    /// edges whose `src`/`dst` referenced the removed ids.
    ///
    /// The matching ids are collected first (from a version-pinned scan),
    /// then the node rows are deleted, then every edge type whose endpoints
    /// include this node type is swept with an `IN (...)` filter.
    async fn execute_delete_node(
        &mut self,
        type_name: &str,
        predicate: &IRMutationPredicate,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        let pred_sql = predicate_to_sql(predicate, params, false)?;

        let table_key = format!("node:{}", type_name);
        let (ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;
        let initial_version = ds.version().version;

        // Only the id column is needed to drive the edge cascade.
        let batches = self
            .table_store()
            .scan(&ds, Some(&["id"]), Some(&pred_sql), None)
            .await?;

        let deleted_ids: Vec<String> = batches
            .iter()
            .flat_map(|batch| {
                // Column 0 is "id" because of the projection above.
                let ids = batch
                    .column(0)
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .unwrap();
                (0..ids.len())
                    .map(|i| ids.value(i).to_string())
                    .collect::<Vec<_>>()
            })
            .collect();

        if deleted_ids.is_empty() {
            return Ok(MutationResult {
                affected_nodes: 0,
                affected_edges: 0,
            });
        }

        let affected_nodes = deleted_ids.len();

        // Reopen pinned at the scanned version so the delete matches exactly
        // the rows whose ids were collected above.
        let mut ds = self
            .reopen_for_mutation(
                &table_key,
                &full_path,
                table_branch.as_deref(),
                initial_version,
            )
            .await?;
        let delete_state = self
            .table_store()
            .delete_where(&full_path, &mut ds, &pred_sql)
            .await?;

        let mut updates = vec![crate::db::SubTableUpdate {
            table_key,
            table_version: delete_state.version,
            table_branch: table_branch.clone(),
            row_count: delete_state.row_count,
            version_metadata: delete_state.version_metadata,
        }];

        let mut affected_edges = 0usize;
        // SQL-escape the ids once for reuse across all cascade filters.
        let escaped: Vec<String> = deleted_ids
            .iter()
            .map(|id| format!("'{}'", id.replace('\'', "''")))
            .collect();
        let id_list = escaped.join(", ");

        // Snapshot edge metadata first so the catalog borrow does not
        // overlap the mutable self borrows in the loop below.
        let edge_info: Vec<(String, String, String)> = self
            .catalog()
            .edge_types
            .iter()
            .map(|(name, et)| (name.clone(), et.from_type.clone(), et.to_type.clone()))
            .collect();

        for (edge_name, from_type, to_type) in &edge_info {
            let mut cascade_filters = Vec::new();
            if from_type == type_name {
                cascade_filters.push(format!("src IN ({})", id_list));
            }
            if to_type == type_name {
                cascade_filters.push(format!("dst IN ({})", id_list));
            }
            if cascade_filters.is_empty() {
                continue;
            }

            let edge_table_key = format!("edge:{}", edge_name);
            let cascade_filter = cascade_filters.join(" OR ");
            let (mut edge_ds, edge_full_path, edge_table_branch) =
                self.open_for_mutation(&edge_table_key).await?;

            let edge_delete = self
                .table_store()
                .delete_where(&edge_full_path, &mut edge_ds, &cascade_filter)
                .await?;

            affected_edges += edge_delete.deleted_rows;

            // Only record a sub-table update when rows actually changed.
            if edge_delete.deleted_rows > 0 {
                updates.push(crate::db::SubTableUpdate {
                    table_key: edge_table_key,
                    table_version: edge_delete.version,
                    table_branch: edge_table_branch,
                    row_count: edge_delete.row_count,
                    version_metadata: edge_delete.version_metadata,
                });
            }
        }

        // Node + cascaded edge deletes commit as one manifest update.
        self.commit_updates(&updates).await?;

        if affected_edges > 0 {
            self.invalidate_graph_index().await;
        }

        Ok(MutationResult {
            affected_nodes,
            affected_edges,
        })
    }

    /// Deletes edge rows matching `predicate` directly via a SQL filter;
    /// commits and invalidates the graph index only when rows were removed.
    async fn execute_delete_edge(
        &mut self,
        type_name: &str,
        predicate: &IRMutationPredicate,
        params: &ParamMap,
    ) -> Result<MutationResult> {
        let pred_sql = predicate_to_sql(predicate, params, true)?;

        let table_key = format!("edge:{}", type_name);
        let (mut ds, full_path, table_branch) = self.open_for_mutation(&table_key).await?;

        let delete_state = self
            .table_store()
            .delete_where(&full_path, &mut ds, &pred_sql)
            .await?;
        let affected = delete_state.deleted_rows;

        if affected > 0 {
            self.commit_updates(&[crate::db::SubTableUpdate {
                table_key,
                table_version: delete_state.version,
                table_branch,
                row_count: delete_state.row_count,
                version_metadata: delete_state.version_metadata,
            }])
            .await?;

            self.invalidate_graph_index().await;
        }

        Ok(MutationResult {
            affected_nodes: 0,
            affected_edges: affected,
        })
    }
}
1102
1103fn enrich_mutation_params(params: &ParamMap) -> Result<ParamMap> {
1104 let mut resolved = params.clone();
1105 if !resolved.contains_key(NOW_PARAM_NAME) {
1106 let now = OffsetDateTime::now_utc()
1107 .format(&Rfc3339)
1108 .map_err(|e| OmniError::manifest(format!("failed to format now(): {}", e)))?;
1109 resolved.insert(NOW_PARAM_NAME.to_string(), Literal::DateTime(now));
1110 }
1111 Ok(resolved)
1112}