use crate::features::runtime::api::ExecutablePredicate;
use crate::features::storage::api as storage_api;
use crate::features::storage::sstable::EntryKind;
// FFI binding to the Zig batch cosine kernel. Only declared when the `zig` feature is enabled
// and the `has_zig` cfg is set; otherwise `cosine_scores_batch` uses its Rust fallback.
#[cfg(all(feature = "zig", has_zig))]
extern "C" {
    /// Batch cosine scoring kernel implemented outside Rust.
    ///
    /// Caller contract, as upheld by `cosine_scores_batch`:
    /// - `query` points to `dim` readable `f32` values.
    /// - `embeddings` points to `rows * dim` readable `f32` values, stored one row after another.
    /// - `out_scores` points to `rows` writable `f64` slots.
    ///
    /// NOTE(review): presumably writes one cosine score per row into `out_scores`. How zero-norm
    /// rows are scored is defined by the Zig source — confirm it matches the Rust fallback,
    /// which reports `0.0`.
    fn iridium_zig_cosine_score_batch(
        query: *const f32,
        dim: usize,
        embeddings: *const f32,
        rows: usize,
        out_scores: *mut f64,
    );
}
/// Evaluates `score <operator> threshold` for a textual comparison operator.
///
/// Supported operators are `>`, `<`, `>=`, `<=` and `=`. The `=` operator is an approximate
/// equality: it holds when the absolute difference is strictly below `f64::EPSILON`. Any other
/// operator string evaluates to `false`, as does every comparison involving NaN.
pub(crate) fn compare_score(score: f64, operator: &str, threshold: f64) -> bool {
    let approx_equal = (score - threshold).abs() < f64::EPSILON;
    match operator {
        ">=" => score >= threshold,
        "<=" => score <= threshold,
        ">" => score > threshold,
        "<" => score < threshold,
        "=" => approx_equal,
        _ => false,
    }
}
/// Decodes the most recent vector payload recorded for `logical`.
///
/// Scans `logical.deltas` for `EntryKind::VectorDelta` entries and selects the one with the
/// highest `version`. On ties, `max_by_key` picks the last such entry in iteration order. That
/// entry's value is decoded with `storage_api::decode_runtime_vector`, which receives
/// `expected_metric` as the expected metric.
///
/// Returns `Ok(None)` when no vector delta exists. A decoding failure is returned as `Err`
/// holding the `Debug` rendering of the storage error.
pub(super) fn extract_latest_vector(
    handle: &storage_api::StorageHandle,
    logical: &storage_api::LogicalNode,
    expected_metric: storage_api::VectorMetric,
) -> Result<Option<storage_api::StructuredVector>, String> {
    let newest_vector_delta = logical
        .deltas
        .iter()
        .filter(|delta| delta.kind == EntryKind::VectorDelta)
        .max_by_key(|delta| delta.version);
    match newest_vector_delta {
        None => Ok(None),
        Some(delta) => {
            storage_api::decode_runtime_vector(handle, &delta.value, Some(expected_metric))
                .map_err(|err| format!("{err:?}"))
        }
    }
}
/// Produces the query vector for a vector-similarity predicate.
///
/// If `pred.inline_vector` is present, a copy of it is returned after checking that it has
/// exactly `dim` components. Otherwise a deterministic pseudo-random vector of length `dim` is
/// derived from `pred.param`:
/// - the parameter bytes are hashed with an FNV-1a-style loop to seed a 64-bit LCG;
/// - the top 24 bits of each successive LCG state are mapped into `[-1.0, 1.0]`.
///
/// # Errors
/// Returns an error message when the inline vector's length differs from `dim`.
pub(super) fn build_query_vector(
    pred: &ExecutablePredicate,
    dim: usize,
) -> Result<Vec<f32>, String> {
    // NOTE(review): the canonical 64-bit FNV offset basis is 14695981039346656037; this literal
    // is missing its final digit. It is kept unchanged because altering it would change every
    // derived query vector — confirm whether anything depends on the current values.
    const FNV_OFFSET: u64 = 1469598103934665603;
    const FNV_PRIME: u64 = 1099511628211;
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    match pred.inline_vector.as_ref() {
        Some(inline) if inline.len() != dim => Err(format!(
            "query vector dimension mismatch: expected {}, got {}",
            dim,
            inline.len()
        )),
        Some(inline) => Ok(inline.clone()),
        None => {
            let mut state = pred
                .param
                .as_bytes()
                .iter()
                .fold(FNV_OFFSET, |hash, &byte| {
                    (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
                });
            let scale = u24_max() as f32;
            let vector = (0..dim)
                .map(|_| {
                    state = state
                        .wrapping_mul(LCG_MULTIPLIER)
                        .wrapping_add(LCG_INCREMENT);
                    // `state >> 40` keeps the top 24 bits, which fit exactly in an f32 mantissa.
                    let top_bits = (state >> 40) as u32;
                    let unit = top_bits as f32 / scale;
                    unit * 2.0 - 1.0
                })
                .collect();
            Ok(vector)
        }
    }
}
pub(super) fn score_embeddings_batch(
metric: storage_api::VectorMetric,
query: &[f32],
embeddings_flat: &[f32],
dim: usize,
) -> Vec<f64> {
match metric {
storage_api::VectorMetric::Cosine => cosine_scores_batch(query, embeddings_flat, dim),
storage_api::VectorMetric::Euclidean => euclidean_scores_batch(query, embeddings_flat, dim),
}
}
/// Scores every row of a flattened embedding matrix against `query` by cosine similarity.
///
/// `embeddings_flat` is read as consecutive rows of `dim` values each. One score per row is
/// returned, in row order. An empty vector is returned when:
/// - `dim` is zero;
/// - `query` does not have exactly `dim` elements;
/// - `embeddings_flat` is empty;
/// - its length is not a whole number of rows.
///
/// When the `zig` feature and the `has_zig` cfg are both enabled, scoring is delegated to the
/// Zig kernel `iridium_zig_cosine_score_batch`. Otherwise the Rust fallback scores a row as
/// `0.0` whenever `cosine_similarity` declines to produce a value, e.g. for (near-)zero-norm
/// vectors.
pub(super) fn cosine_scores_batch(query: &[f32], embeddings_flat: &[f32], dim: usize) -> Vec<f64> {
    if dim == 0 || query.len() != dim || embeddings_flat.is_empty() {
        return Vec::new();
    }
    if !embeddings_flat.len().is_multiple_of(dim) {
        return Vec::new();
    }
    let rows = embeddings_flat.len() / dim;
    let mut out_scores = vec![0.0_f64; rows];
    #[cfg(all(feature = "zig", has_zig))]
    {
        // SAFETY: the guards above establish `query.len() == dim` and
        // `embeddings_flat.len() == rows * dim`, and `out_scores` was just allocated with
        // exactly `rows` elements. Every pointer therefore covers the extent implied by
        // `dim`/`rows`. All three buffers outlive the call, and `out_scores` is uniquely
        // borrowed, so the write target is not aliased.
        // NOTE(review): this also assumes the Zig kernel stays within those extents and does
        // not retain the pointers after returning — confirm against the Zig source.
        unsafe {
            iridium_zig_cosine_score_batch(
                query.as_ptr(),
                dim,
                embeddings_flat.as_ptr(),
                rows,
                out_scores.as_mut_ptr(),
            );
        }
        out_scores
    }
    #[cfg(not(all(feature = "zig", has_zig)))]
    {
        // The length check above guarantees `chunks_exact` yields exactly `rows` full rows
        // with no remainder, so every output slot is written.
        for (score, row) in out_scores.iter_mut().zip(embeddings_flat.chunks_exact(dim)) {
            *score = cosine_similarity(query, row).unwrap_or(0.0);
        }
        out_scores
    }
}
/// Computes the Euclidean distance from `query` to every row of a flattened embedding matrix.
///
/// `embeddings_flat` is read as consecutive rows of `dim` values each. One distance per row is
/// returned, in row order. An empty vector is returned when:
/// - `dim` is zero;
/// - `query` does not have exactly `dim` elements;
/// - `embeddings_flat` is empty;
/// - its length is not a whole number of rows.
///
/// A row for which `euclidean_distance` yields no value is reported as `f64::INFINITY`.
fn euclidean_scores_batch(query: &[f32], embeddings_flat: &[f32], dim: usize) -> Vec<f64> {
    if dim == 0 || query.len() != dim || embeddings_flat.is_empty() {
        return Vec::new();
    }
    if !embeddings_flat.len().is_multiple_of(dim) {
        return Vec::new();
    }
    // The length check above guarantees `chunks_exact` leaves no remainder, so this yields
    // exactly `embeddings_flat.len() / dim` distances.
    embeddings_flat
        .chunks_exact(dim)
        .map(|row| euclidean_distance(query, row).unwrap_or(f64::INFINITY))
        .collect()
}
/// Test-only entry point that forwards unchanged to `cosine_scores_batch`.
///
/// Exists so crate tests can exercise the batch scorer, whichever backend is compiled in.
#[cfg(test)]
pub(crate) fn cosine_scores_batch_for_test(
    query: &[f32],
    embeddings_flat: &[f32],
    dim: usize,
) -> Vec<f64> {
    self::cosine_scores_batch(query, embeddings_flat, dim)
}
/// Test-only entry point that forwards unchanged to the scalar `cosine_similarity`.
///
/// Lets crate tests compare the scalar reference against the batch scorer.
#[cfg(test)]
pub(crate) fn cosine_similarity_scalar_for_test(lhs: &[f32], rhs: &[f32]) -> Option<f64> {
    self::cosine_similarity(lhs, rhs)
}
/// Cosine similarity of two equal-length vectors, accumulated in `f64`.
///
/// Returns `None` when:
/// - the slices differ in length;
/// - the slices are empty;
/// - either vector's squared norm is at most `f64::EPSILON` (treated as a zero vector).
///
/// The result is not clamped, so rounding may place it marginally outside `[-1.0, 1.0]`.
// When the Zig kernel is compiled in, the only caller in this file is a `#[cfg(test)]` helper.
// Non-test Zig builds would therefore warn about dead code, so the lint is silenced only there.
#[cfg_attr(all(feature = "zig", has_zig), allow(dead_code))]
fn cosine_similarity(lhs: &[f32], rhs: &[f32]) -> Option<f64> {
    if lhs.len() != rhs.len() || lhs.is_empty() {
        return None;
    }
    let mut dot = 0.0_f64;
    let mut lhs_norm = 0.0_f64;
    let mut rhs_norm = 0.0_f64;
    for (a, b) in lhs.iter().zip(rhs.iter()) {
        let a = f64::from(*a);
        let b = f64::from(*b);
        dot += a * b;
        lhs_norm += a * a;
        rhs_norm += b * b;
    }
    // Squared norms at or below EPSILON make the denominator effectively zero.
    if lhs_norm <= f64::EPSILON || rhs_norm <= f64::EPSILON {
        return None;
    }
    Some(dot / (lhs_norm.sqrt() * rhs_norm.sqrt()))
}
/// Euclidean (L2) distance between two equal-length vectors, accumulated in `f64`.
///
/// Returns `None` when the slices differ in length or are empty.
fn euclidean_distance(lhs: &[f32], rhs: &[f32]) -> Option<f64> {
    let comparable = lhs.len() == rhs.len() && !lhs.is_empty();
    if !comparable {
        return None;
    }
    let squared_sum = lhs
        .iter()
        .zip(rhs)
        .map(|(&x, &y)| f64::from(x) - f64::from(y))
        .fold(0.0_f64, |acc, diff| acc + diff * diff);
    Some(squared_sum.sqrt())
}
/// Largest unsigned value representable in 24 bits, `2^24 - 1`.
///
/// This equals the largest possible value of `state >> 40` for a 64-bit `state`.
fn u24_max() -> u32 {
    u32::MAX >> 8
}