axiomsync 1.0.0

Core data-processing engine for the AxiomSync local retrieval runtime.
Documentation
use std::collections::HashMap;
use std::time::Instant;

use uuid::Uuid;

use crate::index::InMemoryIndex;
use crate::models::{
    ContextHit, FindResult, QueryPlan, RetrievalStep, RetrievalTrace, SearchOptions, TracePoint,
    TraceStats, classify_hit_buckets,
};

use super::budget::{ResolvedBudget, resolve_budget};
use super::config::DrrConfig;
use super::expansion::run_single_query;
use super::planner::{PlannedQuery, collect_scope_names, is_om_hint, plan_queries};
use super::scoring::{
    fanout_priority_weight, merge_hits, merge_trace_points, scale_hit_scores,
    scale_trace_point_scores, sort_hits_by_score_desc_uri_asc, sorted_trace_points,
    tokenize_keywords, typed_query_plans,
};

/// Retrieval engine that executes planned queries against an in-memory index.
///
/// The engine itself only stores its configuration; all per-request state is
/// created inside [`DrrEngine::run`].
#[derive(Debug, Clone)]
pub struct DrrEngine {
    /// Engine configuration, passed to budget resolution and to each single-query run.
    config: DrrConfig,
}

/// Note string appended to every query plan's notes to describe the fan-out priority weights.
///
/// NOTE(review): the values are presumably the ones returned by `fanout_priority_weight`;
/// confirm and keep both in sync when the weights change.
const FANOUT_PRIORITY_WEIGHT_NOTE: &str = "fanout_weight:p1=1.00,p2=0.82,p3=0.64,p4+=0.46";

impl DrrEngine {
    #[must_use]
    pub const fn new(config: DrrConfig) -> Self {
        Self { config }
    }

    pub fn run(&self, index: &InMemoryIndex, options: &SearchOptions) -> FindResult {
        let start = Instant::now();
        let trace_id = Uuid::new_v4().to_string();
        let planned_queries = plan_queries(options);
        let request_budget = resolve_budget(&self.config, options.budget.as_ref());
        let fanout = execute_planned_queries(
            &self.config,
            index,
            options,
            &planned_queries,
            request_budget,
            start,
        );

        let limit = options.limit.max(1);
        let mut hits: Vec<_> = fanout.merged_hits.into_values().collect();
        sort_hits_by_score_desc_uri_asc(&mut hits);
        hits.truncate(limit);

        let final_topk = hits
            .iter()
            .map(|h| TracePoint {
                uri: h.uri.clone(),
                score: h.score,
            })
            .collect::<Vec<_>>();

        let start_points = sorted_trace_points(fanout.merged_start_points);
        let stop_reason = build_stop_reason(&fanout.stop_reasons);

        let trace = RetrievalTrace {
            trace_id,
            request_type: options.request_type.clone(),
            query: options.query.clone(),
            target_uri: options.target_uri.as_ref().map(ToString::to_string),
            start_points,
            steps: fanout.merged_steps,
            final_topk,
            stop_reason,
            metrics: TraceStats {
                latency_ms: start.elapsed().as_millis(),
                explored_nodes: fanout.explored_nodes,
                convergence_rounds: fanout.convergence_rounds,
                typed_query_count: planned_queries.len(),
                relation_enriched_hits: 0,
                relation_enriched_links: 0,
            },
        };

        let hit_buckets = classify_hit_buckets(&hits);
        let notes = build_query_notes(options, request_budget, planned_queries.len());
        let memories = hit_buckets
            .memories
            .iter()
            .filter_map(|&index| hits.get(index).cloned())
            .collect::<Vec<_>>();
        let resources = hit_buckets
            .resources
            .iter()
            .filter_map(|&index| hits.get(index).cloned())
            .collect::<Vec<_>>();
        let skills = hit_buckets
            .skills
            .iter()
            .filter_map(|&index| hits.get(index).cloned())
            .collect::<Vec<_>>();

        FindResult {
            query_plan: QueryPlan {
                scopes: collect_scope_names(&planned_queries),
                keywords: tokenize_keywords(&options.query),
                typed_queries: typed_query_plans(&planned_queries),
                notes,
            },
            query_results: hits,
            hit_buckets,
            memories,
            resources,
            skills,
            trace: Some(trace),
            trace_uri: None,
        }
    }
}

/// Accumulator for the results of running every planned query of one request.
#[derive(Debug, Default)]
struct FanoutState {
    /// Hits merged across all executed queries (presumably keyed by hit URI — verify
    /// against `merge_hits`).
    merged_hits: HashMap<String, ContextHit>,
    /// Start-point scores merged across all executed queries (presumably keyed by URI —
    /// verify against `merge_trace_points`).
    merged_start_points: HashMap<String, f32>,
    /// Steps of every executed query, in execution order, with round numbers offset so
    /// later queries continue after the last round recorded so far.
    merged_steps: Vec<RetrievalStep>,
    /// Sum of the explored-node counts reported by each executed query.
    explored_nodes: usize,
    /// Sum of the convergence-round counts reported by each executed query.
    convergence_rounds: u32,
    /// One stop reason per executed query, plus a `budget_nodes` / `budget_ms` entry
    /// when the fan-out loop ended early on an exhausted budget.
    stop_reasons: Vec<String>,
}

/// Runs the planned queries in order under one shared budget and merges their results.
///
/// Before each query the remaining node and time budgets are checked; when either is
/// exhausted, `budget_nodes` or `budget_ms` is recorded as a stop reason and the
/// remaining queries are skipped. Each query receives whatever budget is left, its hit
/// and start-point scores are scaled by the weight for its priority, and its trace steps
/// are appended with their round numbers shifted past the rounds recorded so far.
///
/// `start` is the instant the whole request began; elapsed time is measured from it.
fn execute_planned_queries(
    config: &DrrConfig,
    index: &InMemoryIndex,
    options: &SearchOptions,
    planned_queries: &[PlannedQuery],
    request_budget: ResolvedBudget,
    start: Instant,
) -> FanoutState {
    let mut state = FanoutState::default();
    let mut nodes_left = request_budget.nodes;
    let mut round_base: u32 = 0;

    for planned in planned_queries {
        if nodes_left == 0 {
            state.stop_reasons.push(String::from("budget_nodes"));
            break;
        }

        // `None` means no time limit; otherwise the milliseconds still available.
        let ms_left = request_budget.time_ms.map(|max_ms| {
            let spent = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
            max_ms.saturating_sub(spent)
        });
        if matches!(ms_left, Some(0)) {
            state.stop_reasons.push(String::from("budget_ms"));
            break;
        }

        let query_budget = ResolvedBudget {
            time_ms: ms_left,
            nodes: nodes_left,
            depth: request_budget.depth,
        };
        let mut outcome = run_single_query(config, index, options, planned, query_budget);

        // Scale this query's scores by its priority weight before merging.
        let weight = fanout_priority_weight(planned.priority);
        scale_hit_scores(&mut outcome.hits, weight);
        scale_trace_point_scores(&mut outcome.trace.start_points, weight);
        merge_hits(&mut state.merged_hits, outcome.hits);
        merge_trace_points(&mut state.merged_start_points, &outcome.trace.start_points);

        // Shift this query's rounds so they continue after the previous query's rounds.
        state
            .merged_steps
            .extend(outcome.trace.steps.into_iter().map(|step| RetrievalStep {
                round: step.round.saturating_add(round_base),
                current_uri: step.current_uri,
                children_examined: step.children_examined,
                children_selected: step.children_selected,
                queue_size_after: step.queue_size_after,
            }));
        if let Some(last_step) = state.merged_steps.last() {
            round_base = last_step.round;
        }

        let explored = outcome.trace.metrics.explored_nodes;
        state.explored_nodes += explored;
        state.convergence_rounds += outcome.trace.metrics.convergence_rounds;
        state.stop_reasons.push(outcome.trace.stop_reason);
        nodes_left = nodes_left.saturating_sub(explored);
    }

    state
}

/// Collapses the per-query stop reasons into the single stop reason reported in the trace.
///
/// * no reasons: `"queue_empty"`
/// * exactly one reason: that reason, unchanged
/// * several reasons: `"fanout:"` followed by the reasons joined with `|`
fn build_stop_reason(stop_reasons: &[String]) -> String {
    match stop_reasons {
        [] => "queue_empty".to_string(),
        [only] => only.clone(),
        several => format!("fanout:{}", several.join("|")),
    }
}

/// Builds the notes attached to the query plan.
///
/// The notes always start with `drr`, the fan-out count, the priority-weight note and
/// the node/depth budgets. They are followed, in this order, by the session-hint counts
/// (only when hints are present), a `filter` marker (only when a filter is set) and the
/// time budget (only when one is set).
fn build_query_notes(
    options: &SearchOptions,
    request_budget: ResolvedBudget,
    planned_query_count: usize,
) -> Vec<String> {
    let hints = &options.session_hints;

    let mut notes = Vec::with_capacity(9);
    notes.push(String::from("drr"));
    notes.push(format!("fanout:{planned_query_count}"));
    notes.push(String::from(FANOUT_PRIORITY_WEIGHT_NOTE));
    notes.push(format!("budget_nodes:{}", request_budget.nodes));
    notes.push(format!("budget_depth:{}", request_budget.depth));

    if !hints.is_empty() {
        let om_hint_count = hints.iter().filter(|hint| is_om_hint(hint)).count();
        notes.push(format!("session_hints:{}", hints.len()));
        notes.push(format!("session_om_hints:{om_hint_count}"));
    }

    if options.filter.is_some() {
        notes.push(String::from("filter"));
    }

    // An unset time budget contributes no note.
    notes.extend(
        request_budget
            .time_ms
            .map(|max_ms| format!("budget_ms:{max_ms}")),
    );

    notes
}