1use std::sync::{Arc, RwLock};
4
5use std::collections::HashMap;
6
7use std::time::Instant;
8
9use kyu_binder::{
10 BindContext, BoundMatchClause, BoundNodePattern, BoundPatternElement, BoundQuery,
11 BoundReadingClause, BoundUpdatingClause, Binder, BoundStatement,
12};
13use kyu_catalog::{Catalog, NodeTableEntry, Property, RelTableEntry};
14use kyu_common::id::TableId;
15use kyu_common::{KyuError, KyuResult};
16use kyu_delta::{DeltaBatch, DeltaStats, GraphDelta};
17use kyu_executor::{ExecutionContext, QueryResult, Storage, execute};
18use kyu_expression::{FunctionRegistry, evaluate, evaluate_constant};
19use kyu_planner::{build_query_plan, optimize, resolve_properties};
20use kyu_transaction::{Checkpointer, TransactionManager, TransactionType, Wal};
21use kyu_types::{LogicalType, TypedValue};
22use smol_str::SmolStr;
23
24use crate::storage::NodeGroupStorage;
25
/// A client connection to the database.
///
/// Every field is shared (`Arc`) with the owning database object, so all
/// connections observe the same catalog, storage, and transaction state.
pub struct Connection {
    // Versioned schema catalog; read via `read()`, mutated via `begin_write()`.
    catalog: Arc<Catalog>,
    // Table storage; `RwLock` allows concurrent readers, exclusive writers.
    storage: Arc<RwLock<NodeGroupStorage>>,
    // Hands out read-only / write transactions (`begin`/`commit`/`rollback`).
    txn_mgr: Arc<TransactionManager>,
    // Write-ahead log handed to the transaction manager at commit time.
    wal: Arc<Wal>,
    // Flushes durable state; triggered after writes and by CHECKPOINT.
    checkpointer: Arc<Checkpointer>,
    // Loaded extensions, dispatched via `CALL <ext>.<proc>(...)`.
    extensions: Arc<Vec<Box<dyn kyu_extension::Extension>>>,
}
40
41impl Connection {
42 pub(crate) fn new(
43 catalog: Arc<Catalog>,
44 storage: Arc<RwLock<NodeGroupStorage>>,
45 txn_mgr: Arc<TransactionManager>,
46 wal: Arc<Wal>,
47 checkpointer: Arc<Checkpointer>,
48 extensions: Arc<Vec<Box<dyn kyu_extension::Extension>>>,
49 ) -> Self {
50 Self { catalog, storage, txn_mgr, wal, checkpointer, extensions }
51 }
52
53 pub fn query(&self, cypher: &str) -> KyuResult<QueryResult> {
59 self.query_internal(cypher, BindContext::empty())
60 }
61
62 pub fn query_with_params(
79 &self,
80 cypher: &str,
81 params: HashMap<String, TypedValue>,
82 ) -> KyuResult<QueryResult> {
83 let ctx = BindContext {
84 params: params
85 .into_iter()
86 .map(|(k, v)| (SmolStr::new(k), v))
87 .collect(),
88 env: HashMap::new(),
89 };
90 self.query_internal(cypher, ctx)
91 }
92
93 pub fn execute(
116 &self,
117 cypher: &str,
118 params: HashMap<String, TypedValue>,
119 env: HashMap<String, TypedValue>,
120 ) -> KyuResult<QueryResult> {
121 let ctx = BindContext {
122 params: params
123 .into_iter()
124 .map(|(k, v)| (SmolStr::new(k), v))
125 .collect(),
126 env: env
127 .into_iter()
128 .map(|(k, v)| (SmolStr::new(k), v))
129 .collect(),
130 };
131 self.query_internal(cypher, ctx)
132 }
133
134 fn query_internal(&self, cypher: &str, ctx: BindContext) -> KyuResult<QueryResult> {
135 if cypher.trim().eq_ignore_ascii_case("CHECKPOINT")
137 || cypher.trim().eq_ignore_ascii_case("CHECKPOINT;")
138 {
139 self.checkpointer.checkpoint().map_err(|e| {
140 KyuError::Transaction(format!("checkpoint failed: {e}"))
141 })?;
142 return Ok(QueryResult::new(vec![], vec![]));
143 }
144
145 if let Some(result) = self.try_call_extension(cypher)? {
147 return Ok(result);
148 }
149
150 let parse_result = kyu_parser::parse(cypher);
152 let stmt = parse_result
153 .ast
154 .ok_or_else(|| KyuError::Parser(format!("{:?}", parse_result.errors)))?;
155
156 let catalog_snapshot = self.catalog.read();
158 let mut binder = Binder::new(catalog_snapshot, FunctionRegistry::with_builtins())
159 .with_context(ctx);
160 let bound = binder.bind(&stmt)?;
161
162 let is_ddl = matches!(
164 &bound,
165 BoundStatement::CreateNodeTable(_)
166 | BoundStatement::CreateRelTable(_)
167 | BoundStatement::Drop(_)
168 );
169 let is_write = match &bound {
170 BoundStatement::Query(q) => self.is_standalone_dml(q) || self.has_match_mutations(q),
171 BoundStatement::CopyFrom(_) => true,
172 _ => is_ddl,
173 };
174
175 let txn_type = if is_write { TransactionType::Write } else { TransactionType::ReadOnly };
177 let mut txn = self.txn_mgr.begin(txn_type).map_err(|e| {
178 KyuError::Transaction(e.to_string())
179 })?;
180
181 let result = self.execute_bound(bound);
183
184 match &result {
186 Ok(_) => {
187 if is_ddl {
189 let snapshot = self.catalog.read().serialize_json();
190 txn.log_catalog_snapshot(snapshot.into_bytes());
191 }
192 self.txn_mgr.commit(&mut txn, &self.wal, |_, _| {}).map_err(|e| {
193 KyuError::Transaction(e.to_string())
194 })?;
195 if is_write {
197 let _ = self.checkpointer.try_checkpoint();
198 }
199 }
200 Err(_) => {
201 let _ = self.txn_mgr.rollback(&mut txn, |_| {});
202 }
203 }
204
205 result
206 }
207
208 fn execute_bound(&self, bound: BoundStatement) -> KyuResult<QueryResult> {
210 match bound {
211 BoundStatement::Query(query) => {
212 if self.is_standalone_dml(&query) {
213 return self.exec_dml(&query);
214 }
215 if self.has_match_mutations(&query) {
216 return self.exec_match_dml(&query);
217 }
218 let catalog_snapshot = self.catalog.read();
219 let plan = build_query_plan(&query, &catalog_snapshot)?;
220 let plan = optimize(plan, &catalog_snapshot);
221 let storage_guard = self.storage.read().unwrap();
222 let ctx = ExecutionContext::new(catalog_snapshot, &*storage_guard);
223 execute(&plan, &query.output_schema, &ctx)
224 }
225 BoundStatement::CreateNodeTable(create) => self.exec_create_node_table(&create),
226 BoundStatement::CreateRelTable(create) => self.exec_create_rel_table(&create),
227 BoundStatement::Drop(drop) => self.exec_drop(&drop),
228 BoundStatement::CopyFrom(copy) => self.exec_copy_from(©),
229 _ => Err(KyuError::NotImplemented(
230 "statement type not yet supported".into(),
231 )),
232 }
233 }
234
235 fn is_standalone_dml(&self, query: &BoundQuery) -> bool {
239 query.parts.iter().all(|part| {
240 part.reading_clauses.is_empty() && !part.updating_clauses.is_empty()
241 })
242 }
243
244 fn exec_dml(&self, query: &BoundQuery) -> KyuResult<QueryResult> {
246 let catalog_snapshot = self.catalog.read();
247
248 for part in &query.parts {
249 let mut created_nodes: Vec<(Option<u32>, TableId, Vec<TypedValue>)> = Vec::new();
251
252 for clause in &part.updating_clauses {
253 match clause {
254 BoundUpdatingClause::Create(patterns) => {
255 for pattern in patterns {
256 for element in &pattern.elements {
257 match element {
258 BoundPatternElement::Node(node) => {
259 let values =
260 self.exec_create_node(node, &catalog_snapshot)?;
261 created_nodes.push((
262 node.variable_index,
263 node.table_id,
264 values,
265 ));
266 }
267 BoundPatternElement::Relationship(_rel) => {
268 return Err(KyuError::NotImplemented(
269 "CREATE relationship not yet supported".into(),
270 ));
271 }
272 }
273 }
274 }
275 }
276 BoundUpdatingClause::Set(_) => {
277 return Err(KyuError::NotImplemented(
278 "standalone SET without MATCH".into(),
279 ));
280 }
281 BoundUpdatingClause::Delete(_) => {
282 return Err(KyuError::NotImplemented(
283 "standalone DELETE without MATCH".into(),
284 ));
285 }
286 }
287 }
288
289 if let Some(ref proj) = part.projection {
291 let mut prop_map: HashMap<(u32, SmolStr), u32> = HashMap::new();
292 let mut combined_values: Vec<TypedValue> = Vec::new();
293 let mut offset = 0u32;
294
295 for (var_idx, table_id, values) in &created_nodes {
296 if let Some(entry) = catalog_snapshot.find_by_id(*table_id) {
297 if let Some(vi) = var_idx {
298 for (i, prop) in entry.properties().iter().enumerate() {
299 prop_map.insert((*vi, prop.name.clone()), offset + i as u32);
300 }
301 }
302 offset += entry.properties().len() as u32;
303 }
304 combined_values.extend(values.iter().cloned());
305 }
306
307 let col_names: Vec<SmolStr> =
308 proj.items.iter().map(|item| item.alias.clone()).collect();
309 let col_types: Vec<LogicalType> = proj
310 .items
311 .iter()
312 .map(|item| item.expression.result_type().clone())
313 .collect();
314
315 let mut row: Vec<TypedValue> = Vec::with_capacity(proj.items.len());
316 for item in &proj.items {
317 let resolved = resolve_properties(&item.expression, &prop_map);
318 let value = evaluate(&resolved, combined_values.as_slice())?;
319 row.push(value);
320 }
321
322 let mut result = QueryResult::new(col_names, col_types);
323 result.push_row(row);
324 return Ok(result);
325 }
326 }
327
328 Ok(QueryResult::new(vec![], vec![]))
329 }
330
331 fn exec_create_node(
334 &self,
335 node: &BoundNodePattern,
336 catalog: &kyu_catalog::CatalogContent,
337 ) -> KyuResult<Vec<TypedValue>> {
338 let entry = catalog.find_by_id(node.table_id).ok_or_else(|| {
339 KyuError::Catalog(format!("table {:?} not found", node.table_id))
340 })?;
341 let properties = entry.properties();
342
343 let mut values = Vec::with_capacity(properties.len());
344 for prop in properties {
345 let value = if let Some((_pid, expr)) =
346 node.properties.iter().find(|(pid, _)| *pid == prop.id)
347 {
348 evaluate_constant(expr)?
349 } else {
350 TypedValue::Null
351 };
352 values.push(value);
353 }
354
355 self.storage
356 .write()
357 .unwrap()
358 .insert_row(node.table_id, &values)?;
359
360 Ok(values)
361 }
362
363 fn has_match_mutations(&self, query: &BoundQuery) -> bool {
365 query.parts.iter().any(|part| {
366 !part.reading_clauses.is_empty() && !part.updating_clauses.is_empty()
367 })
368 }
369
    /// Executes MATCH ... SET / DELETE over a single node table.
    ///
    /// Mutations are collected during a full table scan and applied only
    /// afterwards, so the scan never observes its own writes. Only the first
    /// MATCH clause and its first node pattern drive the scan.
    fn exec_match_dml(&self, query: &BoundQuery) -> KyuResult<QueryResult> {
        let catalog_snapshot = self.catalog.read();

        for part in &query.parts {
            // Locate the MATCH clause that drives the scan.
            let match_clause = part
                .reading_clauses
                .iter()
                .find_map(|c| match c {
                    BoundReadingClause::Match(m) => Some(m),
                    _ => None,
                })
                .ok_or_else(|| {
                    KyuError::NotImplemented("MATCH...SET/DELETE requires a MATCH clause".into())
                })?;

            let (table_id, var_idx) = self.extract_match_node(match_clause)?;

            let entry = catalog_snapshot.find_by_id(table_id).ok_or_else(|| {
                KyuError::Catalog(format!("table {:?} not found", table_id))
            })?;
            let properties = entry.properties();
            // (variable index, property name) -> column position; used to
            // rewrite property references into column offsets. Empty when the
            // matched node is anonymous (no variable index).
            let prop_map: HashMap<(u32, SmolStr), u32> = properties
                .iter()
                .enumerate()
                .filter_map(|(i, p)| var_idx.map(|vi| ((vi, p.name.clone()), i as u32)))
                .collect();

            let resolved_where = match_clause
                .where_clause
                .as_ref()
                .map(|w| resolve_properties(w, &prop_map));

            let rows = self.storage.read().unwrap().scan_rows(table_id)?;

            // Deferred mutations: cell updates and row deletions.
            let mut set_mutations: Vec<(u64, usize, TypedValue)> = Vec::new();
            let mut delete_rows: Vec<u64> = Vec::new();

            for (row_idx, row_values) in &rows {
                // The WHERE predicate must evaluate to exactly TRUE;
                // FALSE and NULL both filter the row out.
                if let Some(ref pred) = resolved_where {
                    let result = evaluate(pred, row_values.as_slice())?;
                    if result != TypedValue::Bool(true) {
                        continue;
                    }
                }

                for clause in &part.updating_clauses {
                    match clause {
                        BoundUpdatingClause::Set(items) => {
                            for item in items {
                                // Evaluate the RHS against the current row.
                                let resolved_value =
                                    resolve_properties(&item.value, &prop_map);
                                let new_value =
                                    evaluate(&resolved_value, row_values.as_slice())?;
                                let col_idx = properties
                                    .iter()
                                    .position(|p| p.id == item.property_id)
                                    .ok_or_else(|| {
                                        KyuError::Storage(format!(
                                            "property {:?} not found",
                                            item.property_id
                                        ))
                                    })?;
                                set_mutations.push((*row_idx, col_idx, new_value));
                            }
                        }
                        BoundUpdatingClause::Delete(_) => {
                            delete_rows.push(*row_idx);
                        }
                        _ => {}
                    }
                }
            }

            // Apply all updates, then all deletes, under one write lock.
            let mut storage = self.storage.write().unwrap();
            for (row_idx, col_idx, value) in &set_mutations {
                storage.update_cell(table_id, *row_idx, *col_idx, value)?;
            }
            for row_idx in &delete_rows {
                storage.delete_row(table_id, *row_idx)?;
            }
        }

        Ok(QueryResult::new(vec![], vec![]))
    }
462
463 fn extract_match_node(
465 &self,
466 match_clause: &BoundMatchClause,
467 ) -> KyuResult<(TableId, Option<u32>)> {
468 for pattern in &match_clause.patterns {
469 for element in &pattern.elements {
470 if let BoundPatternElement::Node(node) = element {
471 return Ok((node.table_id, node.variable_index));
472 }
473 }
474 }
475 Err(KyuError::NotImplemented(
476 "MATCH clause must contain at least one node pattern".into(),
477 ))
478 }
479
    /// Applies a batch of graph deltas (node/edge upserts and deletes) inside
    /// a single write transaction and returns per-kind statistics.
    ///
    /// NOTE(review): unlike `query_internal`, errors propagate via `?` here
    /// without calling `txn_mgr.rollback`, leaving the write transaction
    /// open — confirm whether the transaction manager cleans up on drop.
    pub fn apply_delta(&self, batch: DeltaBatch) -> KyuResult<DeltaStats> {
        let start = Instant::now();
        let mut stats = DeltaStats {
            total_deltas: batch.len() as u64,
            ..DeltaStats::default()
        };

        let mut txn = self.txn_mgr.begin(TransactionType::Write).map_err(|e| {
            KyuError::Transaction(e.to_string())
        })?;

        // Hold the catalog snapshot and the storage write lock for the whole
        // batch so all deltas apply without interleaved writers.
        let catalog = self.catalog.read();
        let mut storage = self.storage.write().unwrap();

        for delta in batch.iter() {
            match delta {
                GraphDelta::UpsertNode { key, labels: _, props } => {
                    let entry = catalog.find_by_name(key.label.as_str()).ok_or_else(|| {
                        KyuError::Catalog(format!("node table '{}' not found", key.label))
                    })?;
                    let node_entry = entry.as_node_table().ok_or_else(|| {
                        KyuError::Catalog(format!("'{}' is not a node table", key.label))
                    })?;
                    let table_id = node_entry.table_id;
                    let pk_col_idx = node_entry.primary_key_idx;
                    // The delta key string is parsed with the PK column type.
                    let pk_type = &node_entry.properties[pk_col_idx].data_type;
                    let pk_value = parse_primary_key(key.primary_key.as_str(), pk_type)?;

                    let existing = find_row_by_pk(&storage, table_id, pk_col_idx, &pk_value)?;

                    if let Some(row_idx) = existing {
                        // Upsert hit: update only the supplied properties.
                        // Property names not in the schema are skipped.
                        for (prop_name, value) in props {
                            if let Some(col_idx) = find_property_index(node_entry, prop_name.as_str()) {
                                storage.update_cell(table_id, row_idx, col_idx, value)?;
                            }
                        }
                        stats.nodes_updated += 1;
                    } else {
                        // Upsert miss: insert a full row (missing props NULL).
                        let values = build_node_row(node_entry, &pk_value, props);
                        storage.insert_row(table_id, &values)?;
                        stats.nodes_created += 1;
                    }
                }

                GraphDelta::UpsertEdge { src, rel_type, dst, props } => {
                    let entry = catalog.find_by_name(rel_type.as_str()).ok_or_else(|| {
                        KyuError::Catalog(format!("rel table '{}' not found", rel_type))
                    })?;
                    let rel_entry = entry.as_rel_table().ok_or_else(|| {
                        KyuError::Catalog(format!("'{}' is not a rel table", rel_type))
                    })?;
                    let rel_table_id = rel_entry.table_id;

                    // Endpoint PKs are parsed with each endpoint table's own
                    // primary-key type.
                    let src_node = catalog.find_by_name(src.label.as_str())
                        .and_then(|e| e.as_node_table())
                        .ok_or_else(|| KyuError::Catalog(format!("node table '{}' not found", src.label)))?;
                    let dst_node = catalog.find_by_name(dst.label.as_str())
                        .and_then(|e| e.as_node_table())
                        .ok_or_else(|| KyuError::Catalog(format!("node table '{}' not found", dst.label)))?;

                    let src_pk_type = &src_node.properties[src_node.primary_key_idx].data_type;
                    let dst_pk_type = &dst_node.properties[dst_node.primary_key_idx].data_type;
                    let src_pk = parse_primary_key(src.primary_key.as_str(), src_pk_type)?;
                    let dst_pk = parse_primary_key(dst.primary_key.as_str(), dst_pk_type)?;

                    let existing = find_edge_row(&storage, rel_table_id, &src_pk, &dst_pk)?;

                    if let Some(row_idx) = existing {
                        for (prop_name, value) in props {
                            if let Some(prop_idx) = find_rel_property_index(rel_entry, prop_name.as_str()) {
                                // +2 skips the leading [src PK, dst PK] columns.
                                let col_idx = prop_idx + 2;
                                storage.update_cell(rel_table_id, row_idx, col_idx, value)?;
                            }
                        }
                        stats.edges_updated += 1;
                    } else {
                        let values = build_edge_row(rel_entry, &src_pk, &dst_pk, props);
                        storage.insert_row(rel_table_id, &values)?;
                        stats.edges_created += 1;
                    }
                }

                GraphDelta::DeleteNode { key } => {
                    let entry = catalog.find_by_name(key.label.as_str())
                        .and_then(|e| e.as_node_table())
                        .ok_or_else(|| KyuError::Catalog(format!("node table '{}' not found", key.label)))?;
                    let table_id = entry.table_id;
                    let pk_col_idx = entry.primary_key_idx;
                    let pk_type = &entry.properties[pk_col_idx].data_type;
                    let pk_value = parse_primary_key(key.primary_key.as_str(), pk_type)?;

                    // Deleting a missing node is a no-op, not an error.
                    if let Some(row_idx) = find_row_by_pk(&storage, table_id, pk_col_idx, &pk_value)? {
                        storage.delete_row(table_id, row_idx)?;
                        stats.nodes_deleted += 1;
                    }
                }

                GraphDelta::DeleteEdge { src, rel_type, dst } => {
                    let rel_entry = catalog.find_by_name(rel_type.as_str())
                        .and_then(|e| e.as_rel_table())
                        .ok_or_else(|| KyuError::Catalog(format!("rel table '{}' not found", rel_type)))?;
                    let rel_table_id = rel_entry.table_id;

                    let src_node = catalog.find_by_name(src.label.as_str())
                        .and_then(|e| e.as_node_table())
                        .ok_or_else(|| KyuError::Catalog(format!("node table '{}' not found", src.label)))?;
                    let dst_node = catalog.find_by_name(dst.label.as_str())
                        .and_then(|e| e.as_node_table())
                        .ok_or_else(|| KyuError::Catalog(format!("node table '{}' not found", dst.label)))?;

                    let src_pk = parse_primary_key(src.primary_key.as_str(), &src_node.properties[src_node.primary_key_idx].data_type)?;
                    let dst_pk = parse_primary_key(dst.primary_key.as_str(), &dst_node.properties[dst_node.primary_key_idx].data_type)?;

                    // Deleting a missing edge is a no-op, not an error.
                    if let Some(row_idx) = find_edge_row(&storage, rel_table_id, &src_pk, &dst_pk)? {
                        storage.delete_row(rel_table_id, row_idx)?;
                        stats.edges_deleted += 1;
                    }
                }
            }
        }

        // Release the guards before committing.
        drop(storage);
        drop(catalog);

        self.txn_mgr.commit(&mut txn, &self.wal, |_, _| {}).map_err(|e| {
            KyuError::Transaction(e.to_string())
        })?;
        // Best-effort checkpoint after the committed batch.
        let _ = self.checkpointer.try_checkpoint();

        stats.elapsed_micros = start.elapsed().as_micros() as u64;
        Ok(stats)
    }
628
629 fn try_call_extension(&self, cypher: &str) -> KyuResult<Option<QueryResult>> {
634 let trimmed = cypher.trim();
635 if !trimmed.to_uppercase().starts_with("CALL ") {
636 return Ok(None);
637 }
638
639 let rest = trimmed[5..].trim();
641 let dot_pos = rest.find('.').ok_or_else(|| {
642 KyuError::Binder("CALL requires <extension>.<procedure>(...) syntax".into())
643 })?;
644 let ext_name = &rest[..dot_pos];
645 let after_dot = &rest[dot_pos + 1..];
646
647 let paren_pos = after_dot.find('(').ok_or_else(|| {
648 KyuError::Binder("CALL requires <extension>.<procedure>(...) syntax".into())
649 })?;
650 let proc_name = &after_dot[..paren_pos];
651 let args_str = after_dot[paren_pos + 1..].trim_end_matches([')', ';']);
652
653 let args: Vec<String> = if args_str.trim().is_empty() {
654 Vec::new()
655 } else {
656 args_str.split(',').map(|s| s.trim().trim_matches('\'').to_string()).collect()
657 };
658
659 let ext = self.extensions.iter().find(|e| e.name() == ext_name).ok_or_else(|| {
661 KyuError::Binder(format!("unknown extension '{ext_name}'"))
662 })?;
663
664 let adjacency = if ext.needs_graph() {
666 self.build_graph_adjacency()
667 } else {
668 std::collections::HashMap::new()
669 };
670
671 let rows = ext.execute(proc_name, &args, &adjacency).map_err(|e| {
673 KyuError::Runtime(format!("extension error: {e}"))
674 })?;
675
676 let proc_sig = ext.procedures().into_iter().find(|p| p.name == proc_name).ok_or_else(|| {
678 KyuError::Binder(format!("unknown procedure '{proc_name}' in extension '{ext_name}'"))
679 })?;
680
681 let col_names: Vec<SmolStr> = proc_sig.columns.iter().map(|c| SmolStr::new(&c.name)).collect();
682 let col_types: Vec<LogicalType> = proc_sig.columns.iter().map(|c| c.data_type.clone()).collect();
683
684 let mut result = QueryResult::new(col_names, col_types);
685 for proc_row in rows {
686 result.push_row(proc_row);
687 }
688
689 Ok(Some(result))
690 }
691
    /// Scans every rel table and builds an adjacency list keyed by source
    /// node id, with a fixed weight of 1.0 per edge, for graph extensions.
    ///
    /// NOTE(review): assumes rel-table columns 0 and 1 are the INT64 src/dst
    /// keys; rows whose endpoints are not INT64 are silently skipped by the
    /// fallback path — confirm this is intended for string-keyed tables.
    fn build_graph_adjacency(&self) -> std::collections::HashMap<i64, Vec<(i64, f64)>> {
        use kyu_executor::value_vector::ValueVector;

        let mut adjacency: std::collections::HashMap<i64, Vec<(i64, f64)>> = std::collections::HashMap::new();
        let catalog = self.catalog.read();
        let storage = self.storage.read().unwrap();

        for rel in catalog.rel_tables() {
            let table_id = rel.table_id;
            for chunk in storage.scan_table(table_id) {
                let n = chunk.num_rows();
                if n == 0 {
                    continue;
                }

                let src_col = chunk.column(0);
                let dst_col = chunk.column(1);

                // Fast path: identity selection and both columns flat —
                // read the raw i64 slices directly.
                if chunk.selection().is_identity()
                    && let (ValueVector::Flat(src_flat), ValueVector::Flat(dst_flat)) =
                        (src_col, dst_col)
                {
                    let src_slice = src_flat.data_as_i64_slice();
                    let dst_slice = dst_flat.data_as_i64_slice();
                    let src_nm = src_flat.null_mask();
                    let dst_nm = dst_flat.null_mask();
                    for i in 0..n {
                        // Skip edges with a NULL endpoint.
                        if !src_nm.is_null(i as u64) && !dst_nm.is_null(i as u64) {
                            adjacency
                                .entry(src_slice[i])
                                .or_default()
                                .push((dst_slice[i], 1.0));
                        }
                    }
                    continue;
                }

                // Slow path: per-row typed getter.
                for row_idx in 0..n {
                    let src = chunk.get_value(row_idx, 0);
                    let dst = chunk.get_value(row_idx, 1);
                    if let (TypedValue::Int64(s), TypedValue::Int64(d)) = (src, dst) {
                        adjacency.entry(s).or_default().push((d, 1.0));
                    }
                }
            }
        }

        adjacency
    }
747
748 fn exec_create_node_table(
751 &self,
752 create: &kyu_binder::BoundCreateNodeTable,
753 ) -> KyuResult<QueryResult> {
754 let mut catalog = self.catalog.begin_write();
755
756 let table_id = catalog.alloc_table_id();
757 let properties: Vec<Property> = create
758 .columns
759 .iter()
760 .map(|col| {
761 let prop_id = catalog.alloc_property_id();
762 Property::new(
763 prop_id,
764 col.name.clone(),
765 col.data_type.clone(),
766 col.property_id.0 as usize == create.primary_key_idx,
767 )
768 })
769 .collect();
770
771 let schema: Vec<LogicalType> = create.columns.iter().map(|c| c.data_type.clone()).collect();
772
773 catalog.add_node_table(NodeTableEntry {
774 table_id,
775 name: create.name.clone(),
776 properties,
777 primary_key_idx: create.primary_key_idx,
778 num_rows: 0,
779 comment: None,
780 })?;
781
782 self.catalog.commit_write(catalog);
783
784 self.storage.write().unwrap().create_table(table_id, schema);
786
787 Ok(QueryResult::new(vec![], vec![]))
788 }
789
790 fn exec_create_rel_table(
791 &self,
792 create: &kyu_binder::BoundCreateRelTable,
793 ) -> KyuResult<QueryResult> {
794 let mut catalog = self.catalog.begin_write();
795
796 let table_id = catalog.alloc_table_id();
797 let properties: Vec<Property> = create
798 .columns
799 .iter()
800 .map(|col| {
801 let prop_id = catalog.alloc_property_id();
802 Property::new(prop_id, col.name.clone(), col.data_type.clone(), false)
803 })
804 .collect();
805
806 let from_key_type = catalog
808 .find_by_id(create.from_table_id)
809 .and_then(|e| e.as_node_table())
810 .map(|n| n.primary_key_property().data_type.clone())
811 .unwrap_or(LogicalType::Int64);
812 let to_key_type = catalog
813 .find_by_id(create.to_table_id)
814 .and_then(|e| e.as_node_table())
815 .map(|n| n.primary_key_property().data_type.clone())
816 .unwrap_or(LogicalType::Int64);
817 let mut schema = vec![from_key_type, to_key_type];
818 schema.extend(create.columns.iter().map(|c| c.data_type.clone()));
819
820 catalog.add_rel_table(RelTableEntry {
821 table_id,
822 name: create.name.clone(),
823 from_table_id: create.from_table_id,
824 to_table_id: create.to_table_id,
825 properties,
826 num_rows: 0,
827 comment: None,
828 })?;
829
830 self.catalog.commit_write(catalog);
831
832 self.storage.write().unwrap().create_table(table_id, schema);
833
834 Ok(QueryResult::new(vec![], vec![]))
835 }
836
837 fn exec_drop(&self, drop: &kyu_binder::BoundDrop) -> KyuResult<QueryResult> {
838 let mut catalog = self.catalog.begin_write();
839 catalog.remove_by_id(drop.table_id).ok_or_else(|| {
840 KyuError::Catalog(format!("table '{}' not found", drop.name))
841 })?;
842 self.catalog.commit_write(catalog);
843
844 self.storage.write().unwrap().drop_table(drop.table_id);
845
846 Ok(QueryResult::new(vec![], vec![]))
847 }
848
849 fn exec_copy_from(&self, copy: &kyu_binder::BoundCopyFrom) -> KyuResult<QueryResult> {
852 let path_val = evaluate_constant(©.source)?;
854 let path = match &path_val {
855 TypedValue::String(s) => s.as_str().to_string(),
856 _ => {
857 return Err(KyuError::Copy(
858 "COPY FROM source must be a string path".into(),
859 ))
860 }
861 };
862
863 let catalog_snapshot = self.catalog.read();
865 let entry = catalog_snapshot.find_by_id(copy.table_id).ok_or_else(|| {
866 KyuError::Catalog(format!("table {:?} not found", copy.table_id))
867 })?;
868 let properties = entry.properties();
869 let schema: Vec<LogicalType> = properties.iter().map(|p| p.data_type.clone()).collect();
870 drop(catalog_snapshot);
871
872 let reader = kyu_copy::open_reader(&path, &schema)?;
874
875 let mut storage = self.storage.write().unwrap();
876 for row_result in reader {
877 let values = row_result?;
878 storage.insert_row(copy.table_id, &values)?;
879 }
880
881 Ok(QueryResult::new(vec![], vec![]))
882 }
883}
884
885fn parse_primary_key(value: &str, ty: &LogicalType) -> KyuResult<TypedValue> {
889 match ty {
890 LogicalType::Int8 => value.parse::<i8>().map(TypedValue::Int8)
891 .map_err(|e| KyuError::Delta(format!("cannot parse PK '{value}' as INT8: {e}"))),
892 LogicalType::Int16 => value.parse::<i16>().map(TypedValue::Int16)
893 .map_err(|e| KyuError::Delta(format!("cannot parse PK '{value}' as INT16: {e}"))),
894 LogicalType::Int32 => value.parse::<i32>().map(TypedValue::Int32)
895 .map_err(|e| KyuError::Delta(format!("cannot parse PK '{value}' as INT32: {e}"))),
896 LogicalType::Int64 | LogicalType::Serial => value.parse::<i64>().map(TypedValue::Int64)
897 .map_err(|e| KyuError::Delta(format!("cannot parse PK '{value}' as INT64: {e}"))),
898 LogicalType::String => Ok(TypedValue::String(SmolStr::new(value))),
899 _ => Err(KyuError::Delta(format!(
900 "unsupported primary key type '{}' for delta upsert",
901 ty.type_name()
902 ))),
903 }
904}
905
906fn find_row_by_pk(
908 storage: &crate::storage::NodeGroupStorage,
909 table_id: TableId,
910 pk_col_idx: usize,
911 pk_value: &TypedValue,
912) -> KyuResult<Option<u64>> {
913 let rows = storage.scan_rows(table_id)?;
914 for (row_idx, row_values) in &rows {
915 if row_values.get(pk_col_idx) == Some(pk_value) {
916 return Ok(Some(*row_idx));
917 }
918 }
919 Ok(None)
920}
921
922fn find_edge_row(
925 storage: &crate::storage::NodeGroupStorage,
926 rel_table_id: TableId,
927 src_pk: &TypedValue,
928 dst_pk: &TypedValue,
929) -> KyuResult<Option<u64>> {
930 let rows = storage.scan_rows(rel_table_id)?;
931 for (row_idx, row_values) in &rows {
932 if row_values.first() == Some(src_pk) && row_values.get(1) == Some(dst_pk) {
933 return Ok(Some(*row_idx));
934 }
935 }
936 Ok(None)
937}
938
939fn find_property_index(entry: &NodeTableEntry, name: &str) -> Option<usize> {
941 let lower = name.to_lowercase();
942 entry.properties.iter().position(|p| p.name.to_lowercase() == lower)
943}
944
945fn find_rel_property_index(entry: &RelTableEntry, name: &str) -> Option<usize> {
947 let lower = name.to_lowercase();
948 entry.properties.iter().position(|p| p.name.to_lowercase() == lower)
949}
950
951fn build_node_row(
953 entry: &NodeTableEntry,
954 pk_value: &TypedValue,
955 props: &hashbrown::HashMap<SmolStr, TypedValue>,
956) -> Vec<TypedValue> {
957 entry.properties.iter().enumerate().map(|(i, prop)| {
958 if i == entry.primary_key_idx {
959 pk_value.clone()
960 } else if let Some(val) = props.get(&prop.name) {
961 val.clone()
962 } else {
963 TypedValue::Null
964 }
965 }).collect()
966}
967
968fn build_edge_row(
970 entry: &RelTableEntry,
971 src_pk: &TypedValue,
972 dst_pk: &TypedValue,
973 props: &hashbrown::HashMap<SmolStr, TypedValue>,
974) -> Vec<TypedValue> {
975 let mut row = vec![src_pk.clone(), dst_pk.clone()];
976 for prop in &entry.properties {
977 if let Some(val) = props.get(&prop.name) {
978 row.push(val.clone());
979 } else {
980 row.push(TypedValue::Null);
981 }
982 }
983 row
984}
985
986#[cfg(test)]
987mod tests {
988 use crate::database::Database;
989 use kyu_types::TypedValue;
990 use smol_str::SmolStr;
991
    // A fresh in-memory database starts with an empty catalog.
    #[test]
    fn create_database_and_connect() {
        let db = Database::in_memory();
        let _conn = db.connect();
        assert_eq!(db.catalog().num_tables(), 0);
    }

    // RETURN of a bare literal yields a single one-column row.
    #[test]
    fn return_literal() {
        let db = Database::in_memory();
        let conn = db.connect();
        let result = conn.query("RETURN 1 AS x").unwrap();
        assert_eq!(result.num_rows(), 1);
        assert_eq!(result.row(0), vec![TypedValue::Int64(1)]);
    }

    // Constant arithmetic is evaluated by the expression engine.
    #[test]
    fn return_arithmetic() {
        let db = Database::in_memory();
        let conn = db.connect();
        let result = conn.query("RETURN 2 + 3 AS sum").unwrap();
        assert_eq!(result.row(0), vec![TypedValue::Int64(5)]);
    }
1015
    // CREATE NODE TABLE registers a catalog entry and backing storage.
    #[test]
    fn create_node_table() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();

        assert_eq!(db.catalog().num_tables(), 1);
        let snapshot = db.catalog().read();
        let entry = snapshot.find_by_name("Person").unwrap();
        assert!(entry.is_node_table());
        assert_eq!(entry.properties().len(), 2);

        assert!(db.storage().read().unwrap().has_table(entry.table_id()));
    }

    // MATCH over a freshly created (empty) table returns zero rows.
    #[test]
    fn create_and_query_empty_table() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
        assert_eq!(result.num_rows(), 0);
    }

    // CREATE REL TABLE registers a rel-table catalog entry.
    #[test]
    fn create_rel_table() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))")
            .unwrap();
        conn.query("CREATE REL TABLE KNOWS (FROM Person TO Person, since INT64)")
            .unwrap();

        assert_eq!(db.catalog().num_tables(), 2);
        let snapshot = db.catalog().read();
        let entry = snapshot.find_by_name("KNOWS").unwrap();
        assert!(entry.is_rel_table());
    }

    // DROP TABLE removes the catalog entry.
    #[test]
    fn drop_table() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))")
            .unwrap();
        assert_eq!(db.catalog().num_tables(), 1);

        conn.query("DROP TABLE Person").unwrap();
        assert_eq!(db.catalog().num_tables(), 0);
    }

    // Creating a table with a duplicate name surfaces an error.
    #[test]
    fn create_duplicate_error() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))")
            .unwrap();
        let result = conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))");
        assert!(result.is_err());
    }
1078
    // Parser failures are surfaced as `Err` rather than panics.
    #[test]
    fn parse_error_propagated() {
        let db = Database::in_memory();
        let conn = db.connect();
        let result = conn.query("THIS IS NOT VALID CYPHER !!!");
        assert!(result.is_err());
    }

    // Connections share the same catalog/storage: DDL from one connection
    // is immediately visible to another.
    #[test]
    fn multiple_connections_share_state() {
        let db = Database::in_memory();
        let conn1 = db.connect();
        let conn2 = db.connect();

        conn1
            .query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))")
            .unwrap();

        assert_eq!(db.catalog().num_tables(), 1);
        let result = conn2.query("MATCH (p:Person) RETURN p.id").unwrap();
        assert_eq!(result.num_rows(), 0);
    }
1102
    // CREATE inserts a row that a subsequent MATCH can read back.
    #[test]
    fn create_node_via_cypher() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();

        conn.query("CREATE (n:Person {id: 1, name: 'Alice'})")
            .unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
        assert_eq!(result.num_rows(), 1);
        assert_eq!(
            result.row(0)[0],
            TypedValue::String(SmolStr::new("Alice"))
        );
    }

    // A single CREATE statement can insert several node patterns.
    #[test]
    fn create_multiple_nodes() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();

        conn.query("CREATE (a:Person {id: 1, name: 'Alice'}), (b:Person {id: 2, name: 'Bob'})")
            .unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
        assert_eq!(result.num_rows(), 2);
    }

    // Properties omitted from the CREATE pattern are stored as NULL.
    #[test]
    fn create_node_partial_properties() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();

        conn.query("CREATE (n:Person {id: 1})").unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
        assert_eq!(result.num_rows(), 1);
        assert_eq!(result.row(0)[0], TypedValue::Int64(1));
        assert_eq!(result.row(0)[1], TypedValue::Null);
    }

    // CREATE ... RETURN projects the freshly created values in one statement.
    #[test]
    fn create_and_return() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();

        let result = conn
            .query("CREATE (n:Person {id: 1, name: 'Alice'}) RETURN n.name, n.id")
            .unwrap();
        assert_eq!(result.num_rows(), 1);
        assert_eq!(
            result.row(0)[0],
            TypedValue::String(SmolStr::new("Alice"))
        );
        assert_eq!(result.row(0)[1], TypedValue::Int64(1));
    }
1168
1169 #[test]
1170 fn match_set_property() {
1171 let db = Database::in_memory();
1172 let conn = db.connect();
1173 conn.query("CREATE NODE TABLE Person (id INT64, name STRING, age INT64, PRIMARY KEY (id))")
1174 .unwrap();
1175 conn.query("CREATE (n:Person {id: 1, name: 'Alice', age: 25})")
1176 .unwrap();
1177
1178 conn.query("MATCH (p:Person) WHERE p.name = 'Alice' SET p.age = 31")
1179 .unwrap();
1180
1181 let result = conn.query("MATCH (p:Person) RETURN p.age").unwrap();
1182 assert_eq!(result.num_rows(), 1);
1183 assert_eq!(result.row(0)[0], TypedValue::Int64(31));
1184 }
1185
1186 #[test]
1187 fn match_set_with_where() {
1188 let db = Database::in_memory();
1189 let conn = db.connect();
1190 conn.query("CREATE NODE TABLE Person (id INT64, name STRING, age INT64, PRIMARY KEY (id))")
1191 .unwrap();
1192 conn.query("CREATE (a:Person {id: 1, name: 'Alice', age: 25})")
1193 .unwrap();
1194 conn.query("CREATE (b:Person {id: 2, name: 'Bob', age: 30})")
1195 .unwrap();
1196
1197 conn.query("MATCH (p:Person) WHERE p.id = 1 SET p.age = 26")
1199 .unwrap();
1200
1201 let result = conn
1202 .query("MATCH (p:Person) RETURN p.name, p.age")
1203 .unwrap();
1204 assert_eq!(result.num_rows(), 2);
1205 let alice_row = result
1207 .iter_rows()
1208 .find(|r| r[0] == TypedValue::String(SmolStr::new("Alice")))
1209 .unwrap();
1210 let bob_row = result
1211 .iter_rows()
1212 .find(|r| r[0] == TypedValue::String(SmolStr::new("Bob")))
1213 .unwrap();
1214 assert_eq!(alice_row[1], TypedValue::Int64(26)); assert_eq!(bob_row[1], TypedValue::Int64(30)); }
1217
1218 #[test]
1219 fn match_set_all_rows() {
1220 let db = Database::in_memory();
1221 let conn = db.connect();
1222 conn.query("CREATE NODE TABLE Person (id INT64, active INT64, PRIMARY KEY (id))")
1223 .unwrap();
1224 conn.query("CREATE (a:Person {id: 1, active: 0})")
1225 .unwrap();
1226 conn.query("CREATE (b:Person {id: 2, active: 0})")
1227 .unwrap();
1228
1229 conn.query("MATCH (p:Person) SET p.active = 1").unwrap();
1231
1232 let result = conn.query("MATCH (p:Person) RETURN p.active").unwrap();
1233 assert_eq!(result.num_rows(), 2);
1234 assert_eq!(result.row(0)[0], TypedValue::Int64(1));
1235 assert_eq!(result.row(1)[0], TypedValue::Int64(1));
1236 }
1237
1238 #[test]
1239 fn match_delete() {
1240 let db = Database::in_memory();
1241 let conn = db.connect();
1242 conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
1243 .unwrap();
1244 conn.query("CREATE (a:Person {id: 1, name: 'Alice'})").unwrap();
1245 conn.query("CREATE (b:Person {id: 2, name: 'Bob'})").unwrap();
1246
1247 conn.query("MATCH (p:Person) WHERE p.name = 'Alice' DELETE p")
1248 .unwrap();
1249
1250 let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
1251 assert_eq!(result.num_rows(), 1);
1252 assert_eq!(
1253 result.row(0)[0],
1254 TypedValue::String(SmolStr::new("Bob"))
1255 );
1256 }
1257
1258 #[test]
1259 fn match_delete_all() {
1260 let db = Database::in_memory();
1261 let conn = db.connect();
1262 conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))")
1263 .unwrap();
1264 conn.query("CREATE (a:Person {id: 1})").unwrap();
1265 conn.query("CREATE (b:Person {id: 2})").unwrap();
1266
1267 conn.query("MATCH (p:Person) DELETE p").unwrap();
1268
1269 let result = conn.query("MATCH (p:Person) RETURN p.id").unwrap();
1270 assert_eq!(result.num_rows(), 0);
1271 }
1272
    #[test]
    fn storage_roundtrip_insert_scan() {
        // A row inserted directly through the storage layer (bypassing Cypher)
        // must be visible to a MATCH executed afterwards.
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();

        // Resolve the table id, then release the catalog snapshot before
        // touching storage — presumably to avoid holding the catalog and
        // storage locks at the same time (TODO confirm lock ordering rules).
        let snapshot = db.catalog().read();
        let table_id = snapshot.find_by_name("Person").unwrap().table_id();
        drop(snapshot);

        // The write guard is a temporary in this expression, so it is released
        // as soon as insert_row returns.
        db.storage()
            .write()
            .unwrap()
            .insert_row(
                table_id,
                &[
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                ],
            )
            .unwrap();

        // The query path observes the row written behind its back.
        let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
        assert_eq!(result.num_rows(), 1);
        assert_eq!(
            result.row(0)[0],
            TypedValue::String(SmolStr::new("Alice"))
        );
    }
1306
    #[test]
    fn storage_roundtrip_multiple_rows() {
        // Same as storage_roundtrip_insert_scan, but with two direct inserts
        // under a single held write guard.
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, age INT64, PRIMARY KEY (id))")
            .unwrap();

        // Release the catalog snapshot before taking the storage write lock.
        let snapshot = db.catalog().read();
        let table_id = snapshot.find_by_name("Person").unwrap().table_id();
        drop(snapshot);

        let mut storage = db.storage().write().unwrap();
        storage
            .insert_row(
                table_id,
                &[
                    TypedValue::Int64(1),
                    TypedValue::String(SmolStr::new("Alice")),
                    TypedValue::Int64(25),
                ],
            )
            .unwrap();
        storage
            .insert_row(
                table_id,
                &[
                    TypedValue::Int64(2),
                    TypedValue::String(SmolStr::new("Bob")),
                    TypedValue::Int64(30),
                ],
            )
            .unwrap();
        // Must release the write guard before querying, or the read path
        // would block on the RwLock.
        drop(storage);

        let result = conn.query("MATCH (p:Person) RETURN p.name, p.age").unwrap();
        assert_eq!(result.num_rows(), 2);
    }
1344
    #[test]
    fn copy_from_csv() {
        use std::io::Write;

        // COPY ... FROM a CSV file with a header row bulk-loads all rows.
        // NOTE(review): the "kyu_test_csv" directory is shared with
        // copy_from_csv_multiple_types, so only the file (not the dir) is
        // removed at the end — removing the dir could race a parallel test.
        let dir = std::env::temp_dir().join("kyu_test_csv");
        let _ = std::fs::create_dir_all(&dir);
        let csv_path = dir.join("persons.csv");
        {
            // Scope the file handle so it is flushed/closed before COPY reads it.
            let mut f = std::fs::File::create(&csv_path).unwrap();
            writeln!(f, "id,name").unwrap();
            writeln!(f, "1,Alice").unwrap();
            writeln!(f, "2,Bob").unwrap();
            writeln!(f, "3,Charlie").unwrap();
        }

        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        conn.query(&format!(
            "COPY Person FROM '{}'",
            csv_path.display()
        ))
        .unwrap();

        // All three data rows (header excluded) are loaded.
        let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
        assert_eq!(result.num_rows(), 3);

        let _ = std::fs::remove_file(&csv_path);
    }
1376
    #[test]
    fn copy_from_csv_multiple_types() {
        use std::io::Write;

        // COPY must parse CSV text into each column's declared type:
        // INT64, STRING, DOUBLE, and BOOL.
        let dir = std::env::temp_dir().join("kyu_test_csv");
        let _ = std::fs::create_dir_all(&dir);
        let csv_path = dir.join("typed.csv");
        {
            // Scope the file handle so it is closed before COPY reads it.
            let mut f = std::fs::File::create(&csv_path).unwrap();
            writeln!(f, "id,name,score,active").unwrap();
            writeln!(f, "1,Alice,95.5,true").unwrap();
            writeln!(f, "2,Bob,87.3,false").unwrap();
        }

        let db = Database::in_memory();
        let conn = db.connect();
        conn.query(
            "CREATE NODE TABLE Student (id INT64, name STRING, score DOUBLE, active BOOL, PRIMARY KEY (id))",
        )
        .unwrap();
        conn.query(&format!(
            "COPY Student FROM '{}'",
            csv_path.display()
        ))
        .unwrap();

        let result = conn
            .query("MATCH (s:Student) RETURN s.name, s.score, s.active")
            .unwrap();
        assert_eq!(result.num_rows(), 2);
        // Spot-check the first row's typed values.
        assert_eq!(
            result.row(0)[0],
            TypedValue::String(SmolStr::new("Alice"))
        );
        assert_eq!(result.row(0)[1], TypedValue::Double(95.5));
        assert_eq!(result.row(0)[2], TypedValue::Bool(true));

        // Only the file is removed; the dir is shared with copy_from_csv.
        let _ = std::fs::remove_file(&csv_path);
    }
1416
    #[test]
    fn copy_from_parquet() {
        use arrow::array::{Int64Array, StringArray};
        use arrow::datatypes::{DataType, Field, Schema};
        use arrow::record_batch::RecordBatch;
        use parquet::arrow::ArrowWriter;
        use std::sync::Arc;

        // COPY dispatches on file extension: a .parquet source goes through
        // the Parquet reader rather than the CSV path.
        let dir = std::env::temp_dir().join("kyu_test_parquet_copy");
        let _ = std::fs::create_dir_all(&dir);
        let parquet_path = dir.join("persons.parquet");
        {
            // Write a single-batch Parquet file with an (id, name) schema
            // matching the node table created below.
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, false),
            ]));
            let ids = Int64Array::from(vec![1, 2, 3]);
            let names = StringArray::from(vec!["Alice", "Bob", "Charlie"]);
            let batch = RecordBatch::try_new(
                Arc::clone(&schema),
                vec![Arc::new(ids), Arc::new(names)],
            )
            .unwrap();
            let file = std::fs::File::create(&parquet_path).unwrap();
            let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), None).unwrap();
            writer.write(&batch).unwrap();
            // close() finalizes the Parquet footer; without it the file is invalid.
            writer.close().unwrap();
        }

        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        conn.query(&format!(
            "COPY Person FROM '{}'",
            parquet_path.display()
        ))
        .unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
        assert_eq!(result.num_rows(), 3);
        assert_eq!(result.row(0)[0], TypedValue::Int64(1));
        assert_eq!(
            result.row(0)[1],
            TypedValue::String(SmolStr::new("Alice"))
        );

        // This dir is private to this test, so the whole tree is removed.
        let _ = std::fs::remove_dir_all(&dir);
    }
1466
    #[test]
    fn call_extension_pagerank() {
        // CALL dispatches to a registered extension; here the algo extension's
        // pageRank over a 3-node cycle (1 -> 2 -> 3 -> 1).
        let mut db = Database::in_memory();
        db.register_extension(Box::new(ext_algo::AlgoExtension));
        let conn = db.connect();

        conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))").unwrap();
        conn.query("CREATE REL TABLE KNOWS (FROM Person TO Person)").unwrap();
        conn.query("CREATE (n:Person {id: 1})").unwrap();
        conn.query("CREATE (n:Person {id: 2})").unwrap();
        conn.query("CREATE (n:Person {id: 3})").unwrap();

        // Edges are inserted directly into the rel table's storage; release
        // the catalog snapshot first, and scope the write guard so it is
        // dropped before the CALL query runs.
        let snapshot = db.catalog().read();
        let rel_table_id = snapshot.find_by_name("KNOWS").unwrap().table_id();
        drop(snapshot);
        {
            let mut storage = db.storage().write().unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(1), TypedValue::Int64(2)]).unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(2), TypedValue::Int64(3)]).unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(3), TypedValue::Int64(1)]).unwrap();
        }

        // Args: damping factor, max iterations, convergence tolerance.
        let result = conn.query("CALL algo.pageRank(0.85, 20, 0.000001)").unwrap();
        assert_eq!(result.num_rows(), 3);
        assert_eq!(result.column_names.len(), 2);
        // NOTE(review): the if-let silently skips rows whose second column is
        // not a Double, so this only asserts positivity when the type matches.
        for row in result.iter_rows() {
            if let TypedValue::Double(rank) = &row[1] {
                assert!(*rank > 0.0);
            }
        }
    }
1501
    #[test]
    fn call_extension_wcc() {
        // Weakly connected components over two disjoint pairs:
        // {1,2} linked and {10,11} linked — every node still gets a row.
        let mut db = Database::in_memory();
        db.register_extension(Box::new(ext_algo::AlgoExtension));
        let conn = db.connect();

        conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))").unwrap();
        conn.query("CREATE REL TABLE KNOWS (FROM Person TO Person)").unwrap();
        conn.query("CREATE (n:Person {id: 1})").unwrap();
        conn.query("CREATE (n:Person {id: 2})").unwrap();
        conn.query("CREATE (n:Person {id: 10})").unwrap();
        conn.query("CREATE (n:Person {id: 11})").unwrap();

        // Insert edges directly; drop the snapshot and scope the write guard
        // so no lock is held when the CALL query executes.
        let snapshot = db.catalog().read();
        let rel_table_id = snapshot.find_by_name("KNOWS").unwrap().table_id();
        drop(snapshot);
        {
            let mut storage = db.storage().write().unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(1), TypedValue::Int64(2)]).unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(10), TypedValue::Int64(11)]).unwrap();
        }

        // One output row per node (4 nodes total).
        let result = conn.query("CALL algo.wcc()").unwrap();
        assert_eq!(result.num_rows(), 4);
    }
1527
    #[test]
    fn call_extension_betweenness() {
        // Betweenness centrality on a 3-node path (1 -> 2 -> 3):
        // expects one output row per node.
        let mut db = Database::in_memory();
        db.register_extension(Box::new(ext_algo::AlgoExtension));
        let conn = db.connect();

        conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))").unwrap();
        conn.query("CREATE REL TABLE KNOWS (FROM Person TO Person)").unwrap();
        conn.query("CREATE (n:Person {id: 1})").unwrap();
        conn.query("CREATE (n:Person {id: 2})").unwrap();
        conn.query("CREATE (n:Person {id: 3})").unwrap();

        // Insert edges directly; drop the snapshot and scope the write guard
        // so no lock is held when the CALL query executes.
        let snapshot = db.catalog().read();
        let rel_table_id = snapshot.find_by_name("KNOWS").unwrap().table_id();
        drop(snapshot);
        {
            let mut storage = db.storage().write().unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(1), TypedValue::Int64(2)]).unwrap();
            storage.insert_row(rel_table_id, &[TypedValue::Int64(2), TypedValue::Int64(3)]).unwrap();
        }

        let result = conn.query("CALL algo.betweenness()").unwrap();
        assert_eq!(result.num_rows(), 3);
    }
1552
1553 #[test]
1554 fn call_unknown_extension() {
1555 let db = Database::in_memory();
1556 let conn = db.connect();
1557 let result = conn.query("CALL nonexistent.proc()");
1558 assert!(result.is_err());
1559 }
1560
    #[test]
    fn persistence_survives_restart() {
        // End-to-end durability: schema and rows written through one Database
        // instance must be readable after the instance is dropped and the same
        // directory is reopened.
        let dir = std::env::temp_dir().join("kyu_test_persist_e2e");
        let _ = std::fs::remove_dir_all(&dir);

        // First "session": create schema and two nodes. The db is dropped at
        // the end of the scope — presumably persisting state via WAL and/or
        // checkpoint on drop (TODO confirm against Database's Drop impl).
        {
            let db = Database::open(&dir).unwrap();
            let conn = db.connect();
            conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
                .unwrap();
            conn.query("CREATE (n:Person {id: 1, name: 'Alice'})").unwrap();
            conn.query("CREATE (n:Person {id: 2, name: 'Bob'})").unwrap();
        }

        // Second "session": reopen and verify both the catalog entry and the
        // row data survived the restart.
        {
            let db = Database::open(&dir).unwrap();
            let conn = db.connect();

            assert_eq!(db.catalog().num_tables(), 1);
            let snapshot = db.catalog().read();
            assert!(snapshot.find_by_name("Person").is_some());
            drop(snapshot);

            let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
            assert_eq!(result.num_rows(), 2);
        }

        let _ = std::fs::remove_dir_all(&dir);
    }
1595
    #[test]
    fn persistence_ddl_recovery_via_wal() {
        // DDL-only durability: CREATE TABLE statements issued in one session
        // must be recovered into the catalog when the directory is reopened
        // (per the test name, via WAL replay — TODO confirm mechanism).
        let dir = std::env::temp_dir().join("kyu_test_persist_ddl");
        let _ = std::fs::remove_dir_all(&dir);

        // Session 1: create two tables, then drop the db with the scope.
        {
            let db = Database::open(&dir).unwrap();
            let conn = db.connect();
            conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))")
                .unwrap();
            conn.query("CREATE NODE TABLE Organization (id INT64, name STRING, PRIMARY KEY (id))")
                .unwrap();
        }

        // Session 2: both tables are present after recovery.
        {
            let db = Database::open(&dir).unwrap();
            assert_eq!(db.catalog().num_tables(), 2);
            let snapshot = db.catalog().read();
            assert!(snapshot.find_by_name("Person").is_some());
            assert!(snapshot.find_by_name("Organization").is_some());
        }

        let _ = std::fs::remove_dir_all(&dir);
    }
1622
1623 #[test]
1624 fn persistence_empty_database() {
1625 let dir = std::env::temp_dir().join("kyu_test_persist_empty_db");
1626 let _ = std::fs::remove_dir_all(&dir);
1627
1628 { let _db = Database::open(&dir).unwrap(); }
1630
1631 {
1633 let db = Database::open(&dir).unwrap();
1634 assert_eq!(db.catalog().num_tables(), 0);
1635 }
1636
1637 let _ = std::fs::remove_dir_all(&dir);
1638 }
1639
1640 #[test]
1643 fn return_param() {
1644 let db = Database::in_memory();
1645 let conn = db.connect();
1646 let mut params = std::collections::HashMap::new();
1647 params.insert("x".to_string(), TypedValue::Int64(42));
1648 let result = conn
1649 .query_with_params("RETURN $x AS val", params)
1650 .unwrap();
1651 assert_eq!(result.num_rows(), 1);
1652 assert_eq!(result.row(0), vec![TypedValue::Int64(42)]);
1653 }
1654
1655 #[test]
1656 fn parameterized_where() {
1657 let db = Database::in_memory();
1658 let conn = db.connect();
1659 conn.query("CREATE NODE TABLE Person (id INT64, name STRING, age INT64, PRIMARY KEY (id))")
1660 .unwrap();
1661 conn.query("CREATE (n:Person {id: 1, name: 'Alice', age: 30})")
1662 .unwrap();
1663 conn.query("CREATE (n:Person {id: 2, name: 'Bob', age: 20})")
1664 .unwrap();
1665
1666 let mut params = std::collections::HashMap::new();
1667 params.insert("min_age".to_string(), TypedValue::Int64(25));
1668 let result = conn
1669 .query_with_params(
1670 "MATCH (p:Person) WHERE p.age > $min_age RETURN p.name",
1671 params,
1672 )
1673 .unwrap();
1674 assert_eq!(result.num_rows(), 1);
1675 assert_eq!(
1676 result.row(0)[0],
1677 TypedValue::String(SmolStr::new("Alice"))
1678 );
1679 }
1680
1681 #[test]
1682 fn parameterized_create() {
1683 let db = Database::in_memory();
1684 let conn = db.connect();
1685 conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
1686 .unwrap();
1687
1688 let mut params = std::collections::HashMap::new();
1689 params.insert("id".to_string(), TypedValue::Int64(1));
1690 params.insert(
1691 "name".to_string(),
1692 TypedValue::String(SmolStr::new("Alice")),
1693 );
1694 conn.query_with_params(
1695 "CREATE (n:Person {id: $id, name: $name})",
1696 params,
1697 )
1698 .unwrap();
1699
1700 let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
1701 assert_eq!(result.num_rows(), 1);
1702 assert_eq!(
1703 result.row(0)[0],
1704 TypedValue::String(SmolStr::new("Alice"))
1705 );
1706 }
1707
1708 #[test]
1709 fn parameterized_set() {
1710 let db = Database::in_memory();
1711 let conn = db.connect();
1712 conn.query("CREATE NODE TABLE Person (id INT64, age INT64, PRIMARY KEY (id))")
1713 .unwrap();
1714 conn.query("CREATE (n:Person {id: 1, age: 25})").unwrap();
1715
1716 let mut params = std::collections::HashMap::new();
1717 params.insert("new_age".to_string(), TypedValue::Int64(31));
1718 conn.query_with_params(
1719 "MATCH (p:Person) WHERE p.id = 1 SET p.age = $new_age",
1720 params,
1721 )
1722 .unwrap();
1723
1724 let result = conn.query("MATCH (p:Person) RETURN p.age").unwrap();
1725 assert_eq!(result.row(0)[0], TypedValue::Int64(31));
1726 }
1727
1728 #[test]
1729 fn unresolved_param_error() {
1730 let db = Database::in_memory();
1731 let conn = db.connect();
1732 let result = conn.query("RETURN $missing AS val");
1733 assert!(result.is_err());
1734 assert!(result.unwrap_err().to_string().contains("unresolved parameter"));
1735 }
1736
1737 #[test]
1738 fn env_resolved() {
1739 let db = Database::in_memory();
1740 let conn = db.connect();
1741 let mut env = std::collections::HashMap::new();
1742 env.insert(
1743 "GREETING".to_string(),
1744 TypedValue::String(SmolStr::new("hello")),
1745 );
1746 let result = conn
1747 .execute("RETURN env('GREETING') AS val", std::collections::HashMap::new(), env)
1748 .unwrap();
1749 assert_eq!(result.num_rows(), 1);
1750 assert_eq!(
1751 result.row(0)[0],
1752 TypedValue::String(SmolStr::new("hello"))
1753 );
1754 }
1755
1756 #[test]
1757 fn env_missing_returns_null() {
1758 let db = Database::in_memory();
1759 let conn = db.connect();
1760 let result = conn
1761 .execute(
1762 "RETURN env('MISSING') AS val",
1763 std::collections::HashMap::new(),
1764 std::collections::HashMap::new(),
1765 )
1766 .unwrap();
1767 assert_eq!(result.num_rows(), 1);
1768 assert_eq!(result.row(0)[0], TypedValue::Null);
1769 }
1770
1771 #[test]
1774 fn delta_upsert_new_nodes() {
1775 use kyu_delta::DeltaBatchBuilder;
1776
1777 let db = Database::in_memory();
1778 let conn = db.connect();
1779 conn.query("CREATE NODE TABLE Function (name STRING, lines INT64, PRIMARY KEY (name))")
1780 .unwrap();
1781
1782 let batch = DeltaBatchBuilder::new("file:main.rs", 1)
1783 .upsert_node("Function", "main", vec![], [("lines", TypedValue::Int64(42))])
1784 .upsert_node("Function", "helper", vec![], [("lines", TypedValue::Int64(10))])
1785 .build();
1786
1787 let stats = conn.apply_delta(batch).unwrap();
1788 assert_eq!(stats.nodes_created, 2);
1789 assert_eq!(stats.nodes_updated, 0);
1790
1791 let result = conn.query("MATCH (f:Function) RETURN f.name, f.lines").unwrap();
1792 assert_eq!(result.num_rows(), 2);
1793 }
1794
    #[test]
    fn delta_upsert_existing_node_merges() {
        use kyu_delta::DeltaBatchBuilder;

        // Upserting the same primary key twice must merge: the second batch
        // counts as an update, not a creation, and the property wins.
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Function (name STRING, lines INT64, PRIMARY KEY (name))")
            .unwrap();

        // Batch 1 (version 1): creates node "main" with lines = 42.
        let batch1 = DeltaBatchBuilder::new("file:main.rs", 1)
            .upsert_node("Function", "main", vec![], [("lines", TypedValue::Int64(42))])
            .build();
        conn.apply_delta(batch1).unwrap();

        // Batch 2 (version 2): same key, new lines value.
        let batch2 = DeltaBatchBuilder::new("file:main.rs", 2)
            .upsert_node("Function", "main", vec![], [("lines", TypedValue::Int64(50))])
            .build();
        let stats = conn.apply_delta(batch2).unwrap();
        assert_eq!(stats.nodes_created, 0);
        assert_eq!(stats.nodes_updated, 1);

        // Still exactly one node, carrying the newer property value.
        let result = conn.query("MATCH (f:Function) WHERE f.name = 'main' RETURN f.lines").unwrap();
        assert_eq!(result.num_rows(), 1);
        assert_eq!(result.row(0)[0], TypedValue::Int64(50));
    }
1822
1823 #[test]
1824 fn delta_delete_node() {
1825 use kyu_delta::DeltaBatchBuilder;
1826
1827 let db = Database::in_memory();
1828 let conn = db.connect();
1829 conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
1830 .unwrap();
1831 conn.query("CREATE (n:Person {id: 1, name: 'Alice'})").unwrap();
1832 conn.query("CREATE (n:Person {id: 2, name: 'Bob'})").unwrap();
1833
1834 let batch = DeltaBatchBuilder::new("cleanup", 1)
1835 .delete_node("Person", "1")
1836 .build();
1837 let stats = conn.apply_delta(batch).unwrap();
1838 assert_eq!(stats.nodes_deleted, 1);
1839
1840 let result = conn.query("MATCH (p:Person) RETURN p.name").unwrap();
1841 assert_eq!(result.num_rows(), 1);
1842 assert_eq!(result.row(0)[0], TypedValue::String(SmolStr::new("Bob")));
1843 }
1844
1845 #[test]
1846 fn delta_upsert_and_delete_edges() {
1847 use kyu_delta::DeltaBatchBuilder;
1848
1849 let db = Database::in_memory();
1850 let conn = db.connect();
1851 conn.query("CREATE NODE TABLE Person (id INT64, PRIMARY KEY (id))").unwrap();
1852 conn.query("CREATE REL TABLE KNOWS (FROM Person TO Person, since INT64)").unwrap();
1853 conn.query("CREATE (n:Person {id: 1})").unwrap();
1854 conn.query("CREATE (n:Person {id: 2})").unwrap();
1855
1856 let batch = DeltaBatchBuilder::new("social", 1)
1858 .upsert_edge("Person", "1", "KNOWS", "Person", "2", [("since", TypedValue::Int64(2024))])
1859 .build();
1860 let stats = conn.apply_delta(batch).unwrap();
1861 assert_eq!(stats.edges_created, 1);
1862
1863 let storage = db.storage().read().unwrap();
1865 let catalog = db.catalog().read();
1866 let rel_table_id = catalog.find_by_name("KNOWS").unwrap().table_id();
1867 let rows = storage.scan_rows(rel_table_id).unwrap();
1868 assert_eq!(rows.len(), 1);
1869 assert_eq!(rows[0].1[0], TypedValue::Int64(1)); assert_eq!(rows[0].1[1], TypedValue::Int64(2)); assert_eq!(rows[0].1[2], TypedValue::Int64(2024)); drop(storage);
1873 drop(catalog);
1874
1875 let batch2 = DeltaBatchBuilder::new("social", 2)
1877 .upsert_edge("Person", "1", "KNOWS", "Person", "2", [("since", TypedValue::Int64(2025))])
1878 .build();
1879 let stats2 = conn.apply_delta(batch2).unwrap();
1880 assert_eq!(stats2.edges_updated, 1);
1881
1882 let storage = db.storage().read().unwrap();
1883 let rows = storage.scan_rows(rel_table_id).unwrap();
1884 assert_eq!(rows[0].1[2], TypedValue::Int64(2025));
1885 drop(storage);
1886
1887 let batch3 = DeltaBatchBuilder::new("social", 3)
1889 .delete_edge("Person", "1", "KNOWS", "Person", "2")
1890 .build();
1891 let stats3 = conn.apply_delta(batch3).unwrap();
1892 assert_eq!(stats3.edges_deleted, 1);
1893
1894 let storage = db.storage().read().unwrap();
1895 let rows = storage.scan_rows(rel_table_id).unwrap();
1896 assert_eq!(rows.len(), 0);
1897 }
1898
1899 #[test]
1900 fn delta_idempotent_replay() {
1901 use kyu_delta::DeltaBatchBuilder;
1902
1903 let db = Database::in_memory();
1904 let conn = db.connect();
1905 conn.query("CREATE NODE TABLE File (path STRING, hash STRING, PRIMARY KEY (path))")
1906 .unwrap();
1907
1908 let batch = DeltaBatchBuilder::new("watcher", 100)
1909 .upsert_node("File", "src/main.rs", vec![], [("hash", TypedValue::String(SmolStr::new("abc123")))])
1910 .build();
1911
1912 let stats1 = conn.apply_delta(batch.clone()).unwrap();
1914 assert_eq!(stats1.nodes_created, 1);
1915
1916 let stats2 = conn.apply_delta(batch).unwrap();
1918 assert_eq!(stats2.nodes_created, 0);
1919 assert_eq!(stats2.nodes_updated, 1);
1920
1921 let result = conn.query("MATCH (f:File) RETURN f.path").unwrap();
1923 assert_eq!(result.num_rows(), 1);
1924 }
1925
    #[test]
    fn delta_stats_correct() {
        use kyu_delta::DeltaBatchBuilder;

        // A mixed batch (2 node upserts + 1 edge upsert) must report per-kind
        // creation counts, the total delta count, and elapsed time.
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))").unwrap();
        conn.query("CREATE REL TABLE KNOWS (FROM Person TO Person)").unwrap();

        let batch = DeltaBatchBuilder::new("test", 1)
            .upsert_node("Person", "1", vec![], [("name", TypedValue::String(SmolStr::new("Alice")))])
            .upsert_node("Person", "2", vec![], [("name", TypedValue::String(SmolStr::new("Bob")))])
            // Turbofish needed: the empty property list gives no type to infer.
            .upsert_edge("Person", "1", "KNOWS", "Person", "2", Vec::<(&str, TypedValue)>::new())
            .build();

        let stats = conn.apply_delta(batch).unwrap();
        assert_eq!(stats.nodes_created, 2);
        assert_eq!(stats.edges_created, 1);
        assert_eq!(stats.total_deltas, 3);
        // NOTE(review): could flake if the batch applies in under one
        // microsecond on a coarse clock — confirm elapsed_micros granularity.
        assert!(stats.elapsed_micros > 0);
    }
1947}