use super::*;
/// The batched cosine kernel must agree with the scalar reference
/// implementation row-for-row, including the zero-vector row.
#[test]
fn cosine_batch_kernel_matches_scalar_scores() {
    let query = vec![0.5_f32, -0.25, 1.0, 0.75];
    // Three row-major 4-d embeddings; the final row is all zeros to
    // exercise the zero-norm path (the scalar reference maps it to 0.0
    // via `unwrap_or` below).
    let embeddings = vec![
        0.4_f32, -0.1, 0.9, 0.8, // row 0
        -0.2, 0.7, 0.1, -0.9, // row 1
        0.0, 0.0, 0.0, 0.0, // row 2: zero vector
    ];
    let batch = cosine_scores_batch_for_test(&query, &embeddings, 4);
    assert_eq!(batch.len(), 3);
    for (score, row) in batch.iter().zip(embeddings.chunks_exact(4)) {
        let scalar = cosine_similarity_scalar_for_test(&query, row).unwrap_or(0.0);
        // f32 arithmetic: the previous tolerance of 1e-9 was tighter than
        // single precision can guarantee (f32::EPSILON ~ 1.19e-7) and would
        // flake if the batch kernel reorders or fuses operations relative
        // to the scalar path. 1e-6 still catches any real divergence.
        assert!(
            (*score - scalar).abs() < 1e-6,
            "batch score {score} diverges from scalar {scalar}"
        );
    }
}
/// An inline vector parameter whose declared dimension does not match the
/// stored embedding must make `execute` fail with a dimension-mismatch
/// error instead of returning rows.
#[test]
fn execute_rejects_inline_vector_dimension_mismatch() {
    let base = temp_dir("runtime_inline_vector_dim_mismatch");
    let config = storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    };
    let mut handle = storage_api::open_store(config).unwrap();
    // Single node carrying a 2-d cosine embedding.
    storage_api::put_full_node(&mut handle, 1, 1, &[2]).unwrap();
    let payload = storage_api::encode_vector_payload_f32(
        1,
        storage_api::VectorMetric::Cosine,
        &[1.0, 0.0],
        false,
    );
    storage_api::put_vector_delta(&mut handle, &storage_api::encode_delta(1, 2, &payload)).unwrap();
    // The inline parameter `$q:1:0:2` is expected to disagree with the
    // stored embedding's dimension.
    let query = "MATCH (n) WHERE vector.cosine(n.embedding, $q:1:0:2) > 0.3 RETURN n LIMIT 10";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    let err = execute(&plan, &ExecuteParams::default(), &mut handle).unwrap_err();
    assert!(format!("{:?}", err).contains("dimension mismatch"));
}
/// Rerank output must be deterministic across worker counts: a 4-worker
/// execution has to yield exactly the same node ids, in the same order,
/// as a single-worker execution of the same plan.
#[test]
fn execute_parallel_rerank_matches_single_worker() {
    let base = temp_dir("runtime_parallel_match");
    let mut handle = storage_api::open_store(storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    })
    .unwrap();
    // Seed 200 nodes, each with an identical 2-d euclidean embedding.
    for node_id in 1..=200 {
        storage_api::put_full_node(&mut handle, node_id, 1, &[node_id + 1]).unwrap();
        let payload = storage_api::encode_vector_payload_f32(
            1,
            storage_api::VectorMetric::Euclidean,
            &[1.0, 0.0],
            false,
        );
        storage_api::put_vector_delta(&mut handle, &storage_api::encode_delta(node_id, 2, &payload))
            .unwrap();
    }
    storage_api::flush(&mut handle).unwrap();
    let query = "MATCH (n) WHERE vector.euclidean(n.embedding, $q:1:0) < 0.1 RETURN n LIMIT 60";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    // Identical scan parameters except for the worker count.
    let make_params = |workers| ExecuteParams {
        scan_start: 0,
        scan_end_exclusive: 300,
        morsel_size: 16,
        parallel_workers: workers,
    };
    let single = execute(&plan, &make_params(1), &mut handle).unwrap();
    let parallel = execute(&plan, &make_params(4), &mut handle).unwrap();
    let single_ids: Vec<u64> = single.rows.iter().map(|row| row.node_id).collect();
    let parallel_ids: Vec<u64> = parallel.rows.iter().map(|row| row.node_id).collect();
    assert_eq!(single_ids, parallel_ids);
    assert!(parallel.parallel_workers >= 1);
    assert!(parallel.morsels_processed >= 1);
}
/// With `parallel_workers: 0` (auto-selection, per the test name) and only
/// a couple dozen candidates, the executor should not spin up extra
/// threads: the result stream must report exactly one worker.
#[test]
fn execute_auto_workers_uses_single_worker_for_small_rerank_sets() {
    let base = temp_dir("runtime_auto_workers_small");
    let mut handle = storage_api::open_store(storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    })
    .unwrap();
    // Small corpus: 24 nodes, each with the same 2-d euclidean embedding.
    for node_id in 1..=24 {
        storage_api::put_full_node(&mut handle, node_id, 1, &[node_id + 1]).unwrap();
        let payload = storage_api::encode_vector_payload_f32(
            1,
            storage_api::VectorMetric::Euclidean,
            &[1.0, 0.0],
            false,
        );
        storage_api::put_vector_delta(&mut handle, &storage_api::encode_delta(node_id, 2, &payload))
            .unwrap();
    }
    storage_api::flush(&mut handle).unwrap();
    let query = "MATCH (n) WHERE vector.euclidean(n.embedding, $q:1:0) < 0.1 RETURN n LIMIT 60";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    let auto = execute(
        &plan,
        &ExecuteParams {
            scan_start: 0,
            scan_end_exclusive: 64,
            morsel_size: 16,
            parallel_workers: 0,
        },
        &mut handle,
    )
    .unwrap();
    assert_eq!(auto.parallel_workers, 1);
}
/// RAII guard that sets a process-wide environment variable and restores
/// the previous value (or removes the variable if it was unset) on drop.
///
/// Environment variables are process-global and Rust runs tests on
/// multiple threads by default, so the guard also holds a global mutex for
/// its entire lifetime; concurrent env-mutating tests serialize instead of
/// interleaving and restoring stale values.
struct EnvVarGuard {
    key: String,
    // Previous value of `key`; `None` means it was unset (or not valid
    // Unicode, which we treat the same way for test purposes).
    old: Option<String>,
    // Held until the guard drops; released after the env var is restored
    // (fields drop in declaration order, after `Drop::drop` runs).
    _lock: std::sync::MutexGuard<'static, ()>,
}
// Serializes every EnvVarGuard in the process.
static ENV_GUARD_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
impl EnvVarGuard {
    /// Acquires the global env lock, records the current value of `key`,
    /// then sets it to `value`.
    fn set(key: &str, value: &str) -> Self {
        // Recover from poisoning: a panic in another env-mutating test
        // must not cascade a spurious failure into this one.
        let lock = ENV_GUARD_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let old = std::env::var(key).ok();
        std::env::set_var(key, value);
        Self {
            key: key.to_string(),
            old,
            _lock: lock,
        }
    }
}
impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        // Restore the prior state before the lock field is released.
        match self.old.take() {
            Some(previous) => std::env::set_var(&self.key, previous),
            None => std::env::remove_var(&self.key),
        }
    }
}
/// With IR_THREAD_PER_CORE enabled, `execute` must ignore the requested
/// internal worker count and run single-worker: 4 workers are requested
/// below but the stream must report exactly 1.
#[test]
fn execute_disables_internal_parallelism_when_thread_per_core_mode_enabled() {
    let _guard = EnvVarGuard::set("IR_THREAD_PER_CORE", "true");
    let base = temp_dir("runtime_thread_per_core_workers");
    let mut handle = storage_api::open_store(storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    })
    .unwrap();
    // 200 nodes, each with the same 2-d euclidean embedding.
    for node_id in 1..=200 {
        storage_api::put_full_node(&mut handle, node_id, 1, &[node_id + 1]).unwrap();
        let payload = storage_api::encode_vector_payload_f32(
            1,
            storage_api::VectorMetric::Euclidean,
            &[1.0, 0.0],
            false,
        );
        storage_api::put_vector_delta(&mut handle, &storage_api::encode_delta(node_id, 2, &payload))
            .unwrap();
    }
    storage_api::flush(&mut handle).unwrap();
    let query = "MATCH (n) WHERE vector.euclidean(n.embedding, $q:1:0) < 0.1 RETURN n LIMIT 60";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    let stream = execute(
        &plan,
        &ExecuteParams {
            scan_start: 0,
            scan_end_exclusive: 300,
            morsel_size: 16,
            // Deliberately requests parallelism that the env flag should veto.
            parallel_workers: 4,
        },
        &mut handle,
    )
    .unwrap();
    assert_eq!(stream.parallel_workers, 1);
}
/// A vector predicate with LIMIT 50 over 2000 stored nodes should
/// short-circuit: once the limit is satisfied the scan stops, so far
/// fewer than the full corpus is visited.
#[test]
fn execute_vector_predicate_stops_early_after_limit() {
    let base = temp_dir("runtime_vector_early_exit");
    let mut handle = storage_api::open_store(storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    })
    .unwrap();
    // Large corpus: 2000 nodes, all with the same matching embedding.
    for node_id in 1..=2000 {
        storage_api::put_full_node(&mut handle, node_id, 1, &[node_id + 1]).unwrap();
        let payload = storage_api::encode_vector_payload_f32(
            1,
            storage_api::VectorMetric::Euclidean,
            &[1.0, 0.0],
            false,
        );
        storage_api::put_vector_delta(&mut handle, &storage_api::encode_delta(node_id, 2, &payload))
            .unwrap();
    }
    storage_api::flush(&mut handle).unwrap();
    let query = "MATCH (n) WHERE vector.euclidean(n.embedding, $q:1:0) < 0.1 RETURN n LIMIT 50";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    let stream = execute(
        &plan,
        &ExecuteParams {
            scan_start: 0,
            scan_end_exclusive: 2001,
            morsel_size: 32,
            parallel_workers: 4,
        },
        &mut handle,
    )
    .unwrap();
    assert_eq!(stream.rows.len(), 50);
    // Early exit: well under a fifth of the 2000-node corpus was scanned.
    assert!(stream.scanned_nodes < 400);
}
/// Cross-checks the early-exit top-k path against an exhaustive reference:
/// the same plan with the limit removed (then truncated to 40 rows) must
/// yield identical node ids, while the limited run scans fewer nodes.
#[test]
fn execute_vector_predicate_matches_full_scan_reference_for_top_k() {
    let base = temp_dir("runtime_vector_early_exit_reference");
    let mut handle = storage_api::open_store(storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    })
    .unwrap();
    // 1500 nodes, all carrying the same matching 2-d euclidean embedding.
    for node_id in 1..=1500 {
        storage_api::put_full_node(&mut handle, node_id, 1, &[node_id + 1]).unwrap();
        let payload = storage_api::encode_vector_payload_f32(
            1,
            storage_api::VectorMetric::Euclidean,
            &[1.0, 0.0],
            false,
        );
        storage_api::put_vector_delta(&mut handle, &storage_api::encode_delta(node_id, 2, &payload))
            .unwrap();
    }
    storage_api::flush(&mut handle).unwrap();
    let query = "MATCH (n) WHERE vector.euclidean(n.embedding, $q:1:0) < 0.1 RETURN n LIMIT 40";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    // Fast path: limited plan, parallel workers, early exit allowed.
    let fast = execute(
        &plan,
        &ExecuteParams {
            scan_start: 0,
            scan_end_exclusive: 1501,
            morsel_size: 32,
            parallel_workers: 4,
        },
        &mut handle,
    )
    .unwrap();
    // Reference: identical plan minus the limit, single worker, full scan.
    let mut unlimited = plan.clone();
    unlimited.limit = None;
    let mut reference = execute(
        &unlimited,
        &ExecuteParams {
            scan_start: 0,
            scan_end_exclusive: 1501,
            morsel_size: 32,
            parallel_workers: 1,
        },
        &mut handle,
    )
    .unwrap();
    reference.rows.truncate(40);
    let fast_ids: Vec<u64> = fast.rows.iter().map(|row| row.node_id).collect();
    let reference_ids: Vec<u64> = reference.rows.iter().map(|row| row.node_id).collect();
    assert_eq!(fast_ids, reference_ids);
    assert!(fast.scanned_nodes < reference.scanned_nodes);
}
/// When a bitmap index drives the scan, the LIMIT must be pushed down so
/// the executor touches exactly `limit` postings: 50 rows returned, 50
/// nodes scanned, ids 1..=50 in index order.
#[test]
fn bitmap_limit_pushdown_caps_scan_count() {
    let base = temp_dir("runtime_bitmap_limit_pushdown");
    let mut handle = storage_api::open_store(storage_api::StorageConfig {
        buffer_pool_pages: 8,
        wal_dir: base.join("wal"),
        wal_segment_max_bytes: 1 << 20,
        manifest_path: base.join("ir.manifest"),
        sstable_dir: base.join("sst"),
    })
    .unwrap();
    storage_api::create_bitmap_index(&mut handle, "idx_country", "n.country").unwrap();
    // Every one of the 1000 nodes gets the same posting, so the index
    // matches the whole corpus and only the limit caps the scan.
    for node_id in 1..=1000 {
        storage_api::put_full_node(&mut handle, node_id, 1, &[node_id + 1]).unwrap();
        storage_api::bitmap_add_posting(&mut handle, "idx_country", "US", node_id).unwrap();
    }
    storage_api::flush(&mut handle).unwrap();
    let query = "MATCH (n) WHERE bitmap.contains(idx_country, US) = 1 RETURN n LIMIT 50";
    let typed = validate(&parse(query).unwrap(), &Catalog).unwrap();
    let plan = explain(&typed).unwrap();
    let stream = execute(
        &plan,
        &ExecuteParams {
            scan_start: 0,
            scan_end_exclusive: 2000,
            morsel_size: 16,
            parallel_workers: 1,
        },
        &mut handle,
    )
    .unwrap();
    // Pushdown means scanned count equals the returned row count exactly.
    assert_eq!(stream.rows.len(), 50);
    assert_eq!(stream.scanned_nodes, 50);
    let ids: Vec<u64> = stream.rows.iter().map(|row| row.node_id).collect();
    assert_eq!(ids.first().copied(), Some(1));
    assert_eq!(ids.last().copied(), Some(50));
}