iridium-db 0.2.0

A high-performance vector-graph hybrid storage and indexing engine
use std::sync::{OnceLock, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};

use crate::features::query::TypedAst;
use crate::features::storage::api as storage_api;

use super::types::{
    ExecutableBitmapPredicate, ExecutablePredicate, ExecutableScalarPredicate, ExecuteParams,
    ExplainError, ExplainPlan, PhysicalOp, PlannerStatsSnapshot, Result,
};

/// Schema version stamped on every snapshot built by `collect_planner_stats`;
/// `set_planner_stats` rejects snapshots carrying any other version.
const PLANNER_STATS_SCHEMA_VERSION: u32 = 1;

/// Process-wide slot for the currently installed planner statistics.
/// `None` means nothing is installed (never set, or cleared).
static PLANNER_STATS: OnceLock<RwLock<Option<PlannerStatsSnapshot>>> = OnceLock::new();

/// Returns the global planner-stats lock, lazily creating it (empty) on first use.
fn planner_stats_cell() -> &'static RwLock<Option<PlannerStatsSnapshot>> {
    PLANNER_STATS.get_or_init(Default::default)
}

/// Derives a [`PlannerStatsSnapshot`] from the current storage state and the
/// scan window `[scan_start, scan_end_exclusive)` taken from `params`.
///
/// All costs are unit-less heuristics:
/// - `node_scan_base_cost`: `(window / 1000 + 1) * (runs + 1) * 10`, where
///   `runs` is the number of entries in `l0_runs` plus `sstable_cache`; the
///   products saturate at `u64::MAX`.
/// - `vector_scan_base_cost`: `floor(log2(hnsw_total_vectors)) * 15 + 20`, or
///   three times the node-scan cost when `hnsw_total_vectors` is 0.
/// - `vector_selectivity_ppm`: `min(hnsw_total_vectors, window) / window` in
///   parts per million (never above 1,000,000).
/// - `graph_selectivity_ppm`: 100,000 ppm per run, capped at 800,000.
/// - `skew_penalty_cost`: twice the gap between the largest and the (floored)
///   mean value in `pending_deltas_per_node`; 0 when that map is empty.
///
/// The snapshot is stamped with the current wall-clock time in milliseconds,
/// used as both `stats_version` and `collected_at_millis`, and a 60 s TTL.
pub fn collect_planner_stats(
    handle: &storage_api::StorageHandle,
    params: &ExecuteParams,
) -> PlannerStatsSnapshot {
    const STATS_TTL_MILLIS: u64 = 60_000;
    const FILTER_BASE_COST: u64 = 10;
    const PPM_SCALE: u64 = 1_000_000;

    let scan_range = params.scan_end_exclusive.saturating_sub(params.scan_start);
    let sstable_count = handle.l0_runs.len() as u64 + handle.sstable_cache.len() as u64;

    let node_scan_base_cost = (scan_range / 1_000 + 1)
        .saturating_mul(sstable_count + 1)
        .saturating_mul(10);

    let vector_scan_base_cost = if handle.hnsw_total_vectors > 0 {
        (handle.hnsw_total_vectors as f64).log2() as u64 * 15 + 20
    } else {
        node_scan_base_cost.saturating_mul(3)
    };

    // Compute the ratio exactly in u128. A saturating u64 multiply would clamp
    // the numerator once the covered count exceeds ~1.8e13 and silently shrink
    // the resulting ppm.
    let covered_vectors = u128::from(handle.hnsw_total_vectors.min(scan_range));
    let vector_selectivity_ppm = (covered_vectors * u128::from(PPM_SCALE)
        / u128::from(scan_range.max(1)))
    .min(u128::from(PPM_SCALE)) as u32;

    let graph_selectivity_ppm = sstable_count.saturating_mul(100_000).min(800_000) as u32;

    let skew_penalty_cost = if handle.pending_deltas_per_node.is_empty() {
        0
    } else {
        let max_pending = handle
            .pending_deltas_per_node
            .values()
            .copied()
            .max()
            .unwrap_or(0) as u64;
        // Saturating sum: `Iterator::sum` panics on overflow in debug builds
        // and wraps in release builds, either of which would corrupt the mean.
        let total_pending = handle
            .pending_deltas_per_node
            .values()
            .map(|&v| v as u64)
            .fold(0_u64, u64::saturating_add);
        let avg_pending = total_pending / handle.pending_deltas_per_node.len() as u64;
        max_pending.saturating_sub(avg_pending).saturating_mul(2)
    };

    let ts = now_millis();
    PlannerStatsSnapshot {
        schema_version: PLANNER_STATS_SCHEMA_VERSION,
        stats_version: ts,
        collected_at_millis: ts,
        ttl_millis: STATS_TTL_MILLIS,
        node_scan_base_cost,
        vector_scan_base_cost,
        filter_base_cost: FILTER_BASE_COST,
        vector_selectivity_ppm,
        graph_selectivity_ppm,
        skew_penalty_cost,
    }
}

/// Validates `stats` and installs it as the process-wide planner statistics,
/// replacing whatever snapshot was installed before.
///
/// # Errors
///
/// Returns `ExplainError::InvalidPlan` if the schema version is not
/// `PLANNER_STATS_SCHEMA_VERSION`, if either selectivity exceeds 1,000,000
/// ppm, if `ttl_millis` is zero, or if the stats lock is poisoned. Checks run
/// in that order and the first failure is reported.
pub fn set_planner_stats(stats: PlannerStatsSnapshot) -> Result<()> {
    let rejection = if stats.schema_version != PLANNER_STATS_SCHEMA_VERSION {
        Some(format!(
            "unsupported planner stats schema version {} (expected {})",
            stats.schema_version, PLANNER_STATS_SCHEMA_VERSION
        ))
    } else if stats.vector_selectivity_ppm.max(stats.graph_selectivity_ppm) > 1_000_000 {
        Some("planner selectivity ppm must be <= 1000000".to_string())
    } else if stats.ttl_millis == 0 {
        Some("planner stats ttl_millis must be > 0".to_string())
    } else {
        None
    };
    if let Some(reason) = rejection {
        return Err(ExplainError::InvalidPlan(reason));
    }

    let mut slot = planner_stats_cell()
        .write()
        .map_err(|_| ExplainError::InvalidPlan("planner stats lock poisoned".to_string()))?;
    *slot = Some(stats);
    Ok(())
}

/// Removes any installed planner statistics, so later `explain` calls use the
/// heuristic cost model.
///
/// This is best-effort: if the stats lock is poisoned, the installed value is
/// left in place and no error is reported.
pub fn clear_planner_stats() {
    let Ok(mut slot) = planner_stats_cell().write() else {
        return;
    };
    slot.take();
}

/// Returns a copy of the currently installed planner statistics.
///
/// Yields `None` when nothing is installed or when the stats lock is poisoned.
pub fn planner_stats_snapshot() -> Option<PlannerStatsSnapshot> {
    let slot = planner_stats_cell().read().ok()?;
    Option::clone(&slot)
}

pub fn explain(typed: &TypedAst) -> Result<ExplainPlan> {
    let ast = &typed.ast;
    if ast.match_alias.is_empty() {
        return Err(ExplainError::InvalidPlan(
            "MATCH alias must not be empty".to_string(),
        ));
    }

    let has_scalar_predicate = typed.scalar_predicate.is_some();
    let has_vector_predicate = ast.where_predicate.is_some();
    let has_bitmap_predicate = ast.where_bitmap_predicate.is_some();
    let stats = planner_stats_snapshot().filter(is_stats_fresh);
    let (bitmap_first, vector_first, planner_mode, stats_version) =
        choose_strategy(has_vector_predicate, has_bitmap_predicate, stats.as_ref());
    let strategy = if bitmap_first {
        "bitmap-first"
    } else if vector_first {
        "vector-first"
    } else {
        "graph-first"
    }
    .to_string();

    let match_projection = ast.match_aliases.join(", ");
    let mut logical_ops = vec![format!("Match({})", match_projection)];
    if let Some(pred) = &ast.where_bitmap_predicate {
        logical_ops.push(format!(
            "BitmapFilter({} = {})",
            pred.index_name, pred.value_key
        ));
    }
    if let Some(pred) = &ast.where_predicate {
        logical_ops.push(format!(
            "Filter({} {} {})",
            pred.function, pred.operator, pred.threshold
        ));
    }
    if let Some(pred) = &typed.scalar_predicate {
        logical_ops.push(format!(
            "ScalarFilter({} {} {})",
            pred.field, pred.operator, pred.value
        ));
    }
    if !ast.with_items.is_empty() {
        let with_items = ast
            .with_items
            .iter()
            .map(format_projection)
            .collect::<Vec<_>>()
            .join(", ");
        logical_ops.push(format!("With({})", with_items));
    }
    let return_items = ast
        .return_items
        .iter()
        .map(format_projection)
        .collect::<Vec<_>>()
        .join(", ");
    logical_ops.push(format!("Return({})", return_items));
    if let Some(limit) = ast.limit {
        logical_ops.push(format!("Limit({})", limit));
    }

    let mut physical_ops = vec![if bitmap_first {
        PhysicalOp::BitmapScan
    } else if vector_first {
        PhysicalOp::VectorScan
    } else {
        PhysicalOp::NodeScan
    }];
    if ast.where_predicate.is_some() || ast.where_bitmap_predicate.is_some() || has_scalar_predicate
    {
        physical_ops.push(PhysicalOp::Filter);
    }
    physical_ops.push(PhysicalOp::Project);
    if ast.limit.is_some() {
        physical_ops.push(PhysicalOp::Limit);
    }

    let mut estimated_cost: u64 = if let Some(snapshot) = stats.as_ref() {
        estimate_cost(
            snapshot,
            bitmap_first,
            vector_first,
            has_vector_predicate,
            has_bitmap_predicate,
        )
    } else {
        let mut base: u64 = if bitmap_first {
            70
        } else if vector_first {
            90
        } else {
            140
        };
        if has_vector_predicate {
            base += 20;
        }
        if has_bitmap_predicate {
            base = base.saturating_sub(20);
        }
        base
    };
    if let Some(limit) = ast.limit {
        estimated_cost = estimated_cost.saturating_sub(limit.min(100) / 5);
    }

    let predicate = if let Some(pred) = &ast.where_predicate {
        let threshold = pred
            .threshold
            .parse::<f64>()
            .map_err(|_| ExplainError::InvalidPlan("invalid predicate threshold".to_string()))?;
        let metric = storage_api::VectorMetric::from_function(&pred.function).ok_or_else(|| {
            ExplainError::InvalidPlan(format!(
                "unsupported vector predicate function '{}'",
                pred.function
            ))
        })?;
        Some(ExecutablePredicate {
            metric,
            param: pred.param.clone(),
            inline_vector: storage_api::parse_inline_vector_param(&pred.param),
            operator: pred.operator.clone(),
            threshold,
        })
    } else {
        None
    };
    let bitmap_predicate =
        ast.where_bitmap_predicate
            .as_ref()
            .map(|pred| ExecutableBitmapPredicate {
                index_name: pred.index_name.clone(),
                value_key: pred.value_key.clone(),
            });
    let scalar_predicate = typed
        .scalar_predicate
        .as_ref()
        .map(|pred| ExecutableScalarPredicate {
            field: pred.field.clone(),
            operator: pred.operator.clone(),
            value: pred.value,
        });

    Ok(ExplainPlan {
        strategy,
        planner_mode,
        stats_version,
        logical_ops,
        physical_ops,
        estimated_cost,
        limit: ast.limit,
        predicate,
        bitmap_predicate,
        scalar_predicate,
    })
}

/// Picks the scan strategy for a query.
///
/// Returns `(bitmap_first, vector_first, planner_mode, stats_version)`:
/// - a bitmap predicate always selects bitmap-first (mode "heuristic");
/// - without a vector predicate the plan is graph-first (mode "heuristic");
/// - with a vector predicate and `stats`, vector-first is chosen when its
///   weighted cost is not greater than graph-first's (mode "cbo");
/// - with a vector predicate and no `stats`, vector-first is chosen (mode
///   "heuristic").
///
/// `stats_version` echoes the snapshot's version whenever `stats` is `Some`.
fn choose_strategy(
    has_vector_predicate: bool,
    has_bitmap_predicate: bool,
    stats: Option<&PlannerStatsSnapshot>,
) -> (bool, bool, String, Option<u64>) {
    let version = stats.map(|snapshot| snapshot.stats_version);
    match (has_bitmap_predicate, has_vector_predicate, stats) {
        (true, _, _) => (true, false, "heuristic".to_string(), version),
        (false, false, _) => (false, false, "heuristic".to_string(), version),
        (false, true, None) => (false, true, "heuristic".to_string(), None),
        (false, true, Some(snapshot)) => {
            let vector_side = weighted_cost(
                snapshot.vector_scan_base_cost,
                snapshot.vector_selectivity_ppm,
            )
            .saturating_add(snapshot.filter_base_cost);
            let graph_side =
                weighted_cost(snapshot.node_scan_base_cost, snapshot.graph_selectivity_ppm)
                    .saturating_add(snapshot.filter_base_cost)
                    .saturating_add(snapshot.skew_penalty_cost);
            (false, vector_side <= graph_side, "cbo".to_string(), version)
        }
    }
}

/// Estimates plan cost from planner statistics for the chosen strategy.
///
/// The scan term is the strategy's base cost weighted by a selectivity:
/// - bitmap-first uses the node-scan cost at a fixed 100,000 ppm;
/// - vector-first uses the vector-scan cost and selectivity;
/// - graph-first uses the node-scan cost and graph selectivity, plus the skew
///   penalty.
///
/// A vector predicate then adds `filter_base_cost`, and a bitmap predicate
/// subtracts half of it, floored at 0. All arithmetic saturates.
fn estimate_cost(
    stats: &PlannerStatsSnapshot,
    bitmap_first: bool,
    vector_first: bool,
    has_vector_predicate: bool,
    has_bitmap_predicate: bool,
) -> u64 {
    let scan_cost = match (bitmap_first, vector_first) {
        (true, _) => weighted_cost(stats.node_scan_base_cost, 100_000),
        (false, true) => weighted_cost(stats.vector_scan_base_cost, stats.vector_selectivity_ppm),
        (false, false) => {
            weighted_cost(stats.node_scan_base_cost, stats.graph_selectivity_ppm)
                .saturating_add(stats.skew_penalty_cost)
        }
    };
    let vector_filter_cost = if has_vector_predicate {
        stats.filter_base_cost
    } else {
        0
    };
    let bitmap_discount = if has_bitmap_predicate {
        stats.filter_base_cost / 2
    } else {
        0
    };
    scan_cost
        .saturating_add(vector_filter_cost)
        .saturating_sub(bitmap_discount)
}

/// Scales `base` by `selectivity_ppm` parts per million, rounding down.
///
/// A selectivity of 0 is treated as 1 ppm. The product is formed in `u128`,
/// where `u64 * u32` cannot overflow, and the result is clamped to `u64::MAX`.
fn weighted_cost(base: u64, selectivity_ppm: u32) -> u64 {
    let ppm = u128::from(selectivity_ppm.max(1));
    let scaled = u128::from(base) * ppm / 1_000_000;
    u64::try_from(scaled).unwrap_or(u64::MAX)
}

/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch; the value saturates
/// at `u64::MAX`.
fn now_millis() -> u64 {
    let millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis());
    u64::try_from(millis).unwrap_or(u64::MAX)
}

/// Reports whether `stats` is still within its TTL, i.e. the current time is
/// no later than `collected_at_millis + ttl_millis` (saturating).
///
/// A `collected_at_millis` in the future therefore counts as fresh.
fn is_stats_fresh(stats: &PlannerStatsSnapshot) -> bool {
    let deadline = stats.collected_at_millis.saturating_add(stats.ttl_millis);
    now_millis() <= deadline
}

/// Renders a projection item as it appears in plan text.
///
/// An identifier is rendered as itself. A function call is rendered as
/// `name(argument)`, with an ` AS alias` suffix when it has an alias.
fn format_projection(item: &crate::features::query::ProjectionItem) -> String {
    use crate::features::query::ProjectionItem;

    match item {
        ProjectionItem::Identifier(value) => value.clone(),
        ProjectionItem::Function {
            name,
            argument,
            alias: Some(alias),
        } => format!("{}({}) AS {}", name, argument, alias),
        ProjectionItem::Function {
            name,
            argument,
            alias: None,
        } => format!("{}({})", name, argument),
    }
}