iridium-db 0.2.0

A high-performance vector-graph hybrid storage and indexing engine
use crate::features::runtime::api::ExecutablePredicate;
use crate::features::storage::api as storage_api;
use crate::features::storage::sstable::EntryKind;

#[cfg(all(feature = "zig", has_zig))]
extern "C" {
    /// Batch cosine-similarity kernel implemented outside Rust (Zig). It is declared only when
    /// both the `zig` feature and the `has_zig` cfg are enabled.
    ///
    /// `cosine_scores_batch` calls it with:
    /// - `query` pointing at `dim` floats;
    /// - `embeddings` pointing at `rows * dim` floats, laid out as consecutive `dim`-wide rows;
    /// - `out_scores` pointing at `rows` writable `f64` slots, one score per row.
    ///
    /// NOTE(review): the kernel's exact contract is defined on the Zig side. This covers how
    /// many elements it reads/writes and how it scores zero-magnitude rows. Confirm it matches
    /// the pure-Rust fallback, which scores such rows as 0.0.
    fn iridium_zig_cosine_score_batch(
        query: *const f32,
        dim: usize,
        embeddings: *const f32,
        rows: usize,
        out_scores: *mut f64,
    );
}

/// Evaluates `score <operator> threshold` for a textual comparison operator.
///
/// Supported operators are `>`, `>=`, `<`, `<=` and `=`. `=` is an approximate equality:
/// true when the absolute difference is below `f64::EPSILON`. Any other operator string
/// yields `false`, and so does any comparison involving NaN.
pub(crate) fn compare_score(score: f64, operator: &str, threshold: f64) -> bool {
    // `None` when either operand is NaN, which makes every ordered comparison false.
    let ordering = score.partial_cmp(&threshold);
    match operator {
        ">" => ordering.is_some_and(std::cmp::Ordering::is_gt),
        ">=" => ordering.is_some_and(std::cmp::Ordering::is_ge),
        "<" => ordering.is_some_and(std::cmp::Ordering::is_lt),
        "<=" => ordering.is_some_and(std::cmp::Ordering::is_le),
        "=" => (score - threshold).abs() < f64::EPSILON,
        _ => false,
    }
}

/// Decodes the newest vector payload recorded for `logical`.
///
/// Among the node's deltas, only `EntryKind::VectorDelta` entries are considered. The one with
/// the greatest `version` is chosen; on a version tie `max_by_key` picks the last such entry.
/// Its value is decoded with `expected_metric` as the required metric.
///
/// Returns `Ok(None)` when the node has no vector delta at all.
///
/// # Errors
///
/// Returns the decoder's error rendered with `Debug` formatting.
pub(super) fn extract_latest_vector(
    handle: &storage_api::StorageHandle,
    logical: &storage_api::LogicalNode,
    expected_metric: storage_api::VectorMetric,
) -> Result<Option<storage_api::StructuredVector>, String> {
    let newest = logical
        .deltas
        .iter()
        .filter(|delta| delta.kind == EntryKind::VectorDelta)
        .max_by_key(|delta| delta.version);
    match newest {
        Some(delta) => {
            storage_api::decode_runtime_vector(handle, &delta.value, Some(expected_metric))
                .map_err(|err| format!("{:?}", err))
        }
        None => Ok(None),
    }
}

/// Produces the query vector used to score a vector-similarity predicate.
///
/// If the predicate carries an inline vector, a clone of it is returned, provided its length
/// equals `dim`. Otherwise a deterministic pseudo-random vector of length `dim` is synthesized
/// from the bytes of `pred.param`. Every synthesized component lies in `[-1.0, 1.0]`.
///
/// # Errors
///
/// Returns `Err` with a dimension-mismatch message when the inline vector's length differs
/// from `dim`.
pub(super) fn build_query_vector(
    pred: &ExecutablePredicate,
    dim: usize,
) -> Result<Vec<f32>, String> {
    match pred.inline_vector.as_ref() {
        Some(inline) if inline.len() == dim => Ok(inline.clone()),
        Some(inline) => Err(format!(
            "query vector dimension mismatch: expected {}, got {}",
            dim,
            inline.len()
        )),
        None => {
            // Seed: FNV-1a-style hash of the parameter bytes (xor, then multiply by the
            // 64-bit FNV prime 1099511628211).
            // NOTE(review): the canonical FNV-1a 64-bit offset basis is 14695981039346656037.
            // This literal looks like it is missing its final digit. It is kept unchanged
            // because altering it would change every synthesized query vector.
            let mut state = pred
                .param
                .as_bytes()
                .iter()
                .fold(1469598103934665603_u64, |hash, &byte| {
                    (hash ^ u64::from(byte)).wrapping_mul(1099511628211_u64)
                });
            // `state >> 40` keeps the top 24 bits. Dividing by 2^24 - 1 maps them onto [0, 1],
            // and the affine step below stretches that onto [-1, 1].
            let scale = u24_max() as f32;
            let synthesized: Vec<f32> = (0..dim)
                .map(|_| {
                    // 64-bit LCG step (Knuth's MMIX multiplier/increment).
                    state = state
                        .wrapping_mul(6364136223846793005_u64)
                        .wrapping_add(1442695040888963407_u64);
                    let unit = ((state >> 40) as u32) as f32 / scale;
                    unit * 2.0 - 1.0
                })
                .collect();
            Ok(synthesized)
        }
    }
}

pub(super) fn score_embeddings_batch(
    metric: storage_api::VectorMetric,
    query: &[f32],
    embeddings_flat: &[f32],
    dim: usize,
) -> Vec<f64> {
    match metric {
        storage_api::VectorMetric::Cosine => cosine_scores_batch(query, embeddings_flat, dim),
        storage_api::VectorMetric::Euclidean => euclidean_scores_batch(query, embeddings_flat, dim),
    }
}

/// Computes the cosine similarity between `query` and each `dim`-wide row of
/// `embeddings_flat`, returning one score per row in row order.
///
/// Returns an empty vector when any of these hold:
/// - `dim` is 0;
/// - `query.len() != dim`;
/// - `embeddings_flat` is empty;
/// - `embeddings_flat.len()` is not a multiple of `dim`.
///
/// With the Zig kernel enabled (`feature = "zig"` and `has_zig`), scoring is delegated to
/// `iridium_zig_cosine_score_batch`. Otherwise the pure-Rust path scores each row with
/// `cosine_similarity` and substitutes 0.0 for rows where it is undefined (zero magnitude).
pub(super) fn cosine_scores_batch(query: &[f32], embeddings_flat: &[f32], dim: usize) -> Vec<f64> {
    let well_formed = dim != 0
        && query.len() == dim
        && !embeddings_flat.is_empty()
        && embeddings_flat.len().is_multiple_of(dim);
    if !well_formed {
        return Vec::new();
    }

    #[cfg(all(feature = "zig", has_zig))]
    {
        let rows = embeddings_flat.len() / dim;
        let mut out_scores = vec![0.0_f64; rows];
        // SAFETY: the guard above ensures `query` holds exactly `dim` floats and
        // `embeddings_flat` holds exactly `rows * dim` floats. `out_scores` was just allocated
        // with `rows` initialized slots. This is sound provided the Zig kernel reads and
        // writes no more than those counts. NOTE(review): confirm against the Zig source.
        unsafe {
            iridium_zig_cosine_score_batch(
                query.as_ptr(),
                dim,
                embeddings_flat.as_ptr(),
                rows,
                out_scores.as_mut_ptr(),
            );
        }
        out_scores
    }

    #[cfg(not(all(feature = "zig", has_zig)))]
    {
        embeddings_flat
            .chunks_exact(dim)
            .map(|row| cosine_similarity(query, row).unwrap_or(0.0))
            .collect()
    }
}

/// Computes the Euclidean distance between `query` and each `dim`-wide row of
/// `embeddings_flat`, returning one distance per row in row order.
///
/// Returns an empty vector when any of these hold:
/// - `dim` is 0;
/// - `query.len() != dim`;
/// - `embeddings_flat` is empty;
/// - `embeddings_flat.len()` is not a multiple of `dim`.
///
/// A row whose distance cannot be computed would score `f64::INFINITY`.
fn euclidean_scores_batch(query: &[f32], embeddings_flat: &[f32], dim: usize) -> Vec<f64> {
    let shape_ok = dim > 0
        && query.len() == dim
        && !embeddings_flat.is_empty()
        && embeddings_flat.len().is_multiple_of(dim);
    if !shape_ok {
        return Vec::new();
    }
    // `chunks_exact` leaves no remainder here because the length is a multiple of `dim`.
    embeddings_flat
        .chunks_exact(dim)
        .map(|row| euclidean_distance(query, row).unwrap_or(f64::INFINITY))
        .collect()
}

/// Test-only, crate-visible entry point to `cosine_scores_batch`.
///
/// It forwards unchanged, so it exercises whichever backend (Zig kernel or pure-Rust
/// fallback) the current build selected.
#[cfg(test)]
pub(crate) fn cosine_scores_batch_for_test(
    query: &[f32],
    embeddings_flat: &[f32],
    dim: usize,
) -> Vec<f64> {
    cosine_scores_batch(query, embeddings_flat, dim)
}

/// Test-only, crate-visible entry point to the pure-Rust scalar `cosine_similarity`.
#[cfg(test)]
pub(crate) fn cosine_similarity_scalar_for_test(lhs: &[f32], rhs: &[f32]) -> Option<f64> {
    cosine_similarity(lhs, rhs)
}

/// Cosine similarity between two equal-length vectors, accumulated in `f64`.
///
/// Returns `None` when any of these hold:
/// - the slices differ in length;
/// - the slices are empty;
/// - either vector's squared magnitude is at most `f64::EPSILON`, so its direction is
///   treated as undefined.
// Called from the non-Zig branch of `cosine_scores_batch` and from a `#[cfg(test)]` helper.
// With the Zig kernel enabled outside tests it may have no callers, hence the allow.
#[allow(dead_code)]
fn cosine_similarity(lhs: &[f32], rhs: &[f32]) -> Option<f64> {
    if lhs.is_empty() || lhs.len() != rhs.len() {
        return None;
    }
    let (dot, lhs_sq, rhs_sq) = lhs.iter().zip(rhs).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, lhs_sq, rhs_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, lhs_sq + x * x, rhs_sq + y * y)
        },
    );
    if lhs_sq <= f64::EPSILON || rhs_sq <= f64::EPSILON {
        return None;
    }
    Some(dot / (lhs_sq.sqrt() * rhs_sq.sqrt()))
}

/// Euclidean (L2) distance between two equal-length vectors, accumulated in `f64`.
///
/// Returns `None` when the slices differ in length or are empty.
fn euclidean_distance(lhs: &[f32], rhs: &[f32]) -> Option<f64> {
    if lhs.is_empty() || lhs.len() != rhs.len() {
        return None;
    }
    let squared = lhs.iter().zip(rhs).fold(0.0_f64, |acc, (&x, &y)| {
        let diff = f64::from(x) - f64::from(y);
        acc + diff * diff
    });
    Some(squared.sqrt())
}

/// Largest value representable in 24 bits: `2^24 - 1` (`0x00FF_FFFF`).
fn u24_max() -> u32 {
    u32::MAX >> 8
}