mnemara-core 0.6.0

Local-first, explainable AI memory engine for embedded and service-based systems
Documentation
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;

use crate::Result;
use crate::query::{RecallQuery, RecallResult};
use crate::store::MemoryStore;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct JudgedRecallCase {
    pub name: String,
    pub query: RecallQuery,
    pub relevant_record_ids: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(default)]
pub struct RecallEvaluationAssertions {
    pub expected_record_ids: Vec<String>,
    pub optional_record_ids: Vec<String>,
    pub disallowed_record_ids: Vec<String>,
    pub required_explanation_notes: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RecallEvaluationCase {
    pub name: String,
    pub query: RecallQuery,
    pub assertions: RecallEvaluationAssertions,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RecallEvaluationCaseReport {
    pub name: String,
    pub passed: bool,
    pub ranked_record_ids: Vec<String>,
    pub missing_expected_record_ids: Vec<String>,
    pub present_disallowed_record_ids: Vec<String>,
    pub missing_explanation_notes: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RecallEvaluationReport {
    pub cases: usize,
    pub passed_cases: usize,
    pub failed_cases: usize,
    pub pass_rate: f32,
    pub ranking_metrics: RankingMetrics,
    pub case_reports: Vec<RecallEvaluationCaseReport>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RankingMetrics {
    pub cases: usize,
    pub hit_rate_at_k: f32,
    pub recall_at_k: f32,
    pub mrr: f32,
    pub ndcg_at_k: f32,
}

pub fn evaluate_rankings_at_k(rankings: &[(&[String], &[String])], k: usize) -> RankingMetrics {
    if rankings.is_empty() || k == 0 {
        return RankingMetrics {
            cases: rankings.len(),
            hit_rate_at_k: 0.0,
            recall_at_k: 0.0,
            mrr: 0.0,
            ndcg_at_k: 0.0,
        };
    }

    let mut hits = 0.0f32;
    let mut recall = 0.0f32;
    let mut reciprocal_rank = 0.0f32;
    let mut ndcg = 0.0f32;

    for (ranked_ids, relevant_ids) in rankings {
        let relevant = relevant_ids.iter().cloned().collect::<BTreeSet<_>>();
        if relevant.is_empty() {
            continue;
        }
        let top_k = ranked_ids.iter().take(k).collect::<Vec<_>>();
        let matches = top_k
            .iter()
            .filter(|record_id| relevant.contains(record_id.as_str()))
            .count() as f32;
        if matches > 0.0 {
            hits += 1.0;
        }
        recall += matches / relevant.len() as f32;
        if let Some(rank) = ranked_ids
            .iter()
            .position(|record_id| relevant.contains(record_id.as_str()))
        {
            reciprocal_rank += 1.0 / (rank as f32 + 1.0);
        }

        let dcg = top_k
            .iter()
            .enumerate()
            .filter(|(_, record_id)| relevant.contains(record_id.as_str()))
            .map(|(index, _)| 1.0 / ((index as f32 + 2.0).log2()))
            .sum::<f32>();
        let ideal_hits = relevant.len().min(k);
        let ideal_dcg = (0..ideal_hits)
            .map(|index| 1.0 / ((index as f32 + 2.0).log2()))
            .sum::<f32>();
        if ideal_dcg > 0.0 {
            ndcg += dcg / ideal_dcg;
        }
    }

    let cases = rankings.len() as f32;
    RankingMetrics {
        cases: rankings.len(),
        hit_rate_at_k: hits / cases,
        recall_at_k: recall / cases,
        mrr: reciprocal_rank / cases,
        ndcg_at_k: ndcg / cases,
    }
}

pub async fn run_recall_evaluation<S>(
    store: &S,
    cases: &[RecallEvaluationCase],
    k: usize,
) -> Result<RecallEvaluationReport>
where
    S: MemoryStore + ?Sized,
{
    let mut results = Vec::with_capacity(cases.len());
    for case in cases {
        results.push(store.recall(case.query.clone()).await?);
    }
    Ok(evaluate_recall_results(cases, &results, k))
}

pub fn evaluate_recall_results(
    cases: &[RecallEvaluationCase],
    results: &[RecallResult],
    k: usize,
) -> RecallEvaluationReport {
    let mut ranked = Vec::<Vec<String>>::with_capacity(results.len());
    let mut relevant = Vec::<Vec<String>>::with_capacity(cases.len());
    let mut case_reports = Vec::with_capacity(cases.len());

    for (case, result) in cases.iter().zip(results.iter()) {
        let ranked_record_ids = result
            .hits
            .iter()
            .map(|hit| hit.record.id.clone())
            .collect::<Vec<_>>();
        let ranked_set = ranked_record_ids.iter().cloned().collect::<BTreeSet<_>>();
        let expected_set = case
            .assertions
            .expected_record_ids
            .iter()
            .cloned()
            .collect::<BTreeSet<_>>();
        let optional = case
            .assertions
            .optional_record_ids
            .iter()
            .cloned()
            .collect::<BTreeSet<_>>();
        let mut relevant_ids = expected_set
            .union(&optional)
            .cloned()
            .collect::<Vec<String>>();
        relevant_ids.sort();

        let missing_expected_record_ids = expected_set
            .difference(&ranked_set)
            .cloned()
            .collect::<Vec<_>>();
        let present_disallowed_record_ids = case
            .assertions
            .disallowed_record_ids
            .iter()
            .filter(|record_id| ranked_set.contains(*record_id))
            .cloned()
            .collect::<Vec<_>>();

        let explanation_notes = result
            .explanation
            .as_ref()
            .map(|explanation| explanation.policy_notes.as_slice())
            .unwrap_or_default();
        let missing_explanation_notes = case
            .assertions
            .required_explanation_notes
            .iter()
            .filter(|required| {
                !explanation_notes
                    .iter()
                    .any(|note| note.contains(required.as_str()))
            })
            .cloned()
            .collect::<Vec<_>>();

        let passed = missing_expected_record_ids.is_empty()
            && present_disallowed_record_ids.is_empty()
            && missing_explanation_notes.is_empty();

        ranked.push(ranked_record_ids.clone());
        relevant.push(relevant_ids);
        case_reports.push(RecallEvaluationCaseReport {
            name: case.name.clone(),
            passed,
            ranked_record_ids,
            missing_expected_record_ids,
            present_disallowed_record_ids,
            missing_explanation_notes,
        });
    }

    let ranking_pairs = ranked
        .iter()
        .zip(relevant.iter())
        .map(|(ranked_ids, relevant_ids)| (ranked_ids.as_slice(), relevant_ids.as_slice()))
        .collect::<Vec<_>>();
    let passed_cases = case_reports.iter().filter(|report| report.passed).count();
    let cases_len = cases.len();

    RecallEvaluationReport {
        cases: cases_len,
        passed_cases,
        failed_cases: cases_len.saturating_sub(passed_cases),
        pass_rate: if cases_len == 0 {
            0.0
        } else {
            passed_cases as f32 / cases_len as f32
        },
        ranking_metrics: evaluate_rankings_at_k(&ranking_pairs, k),
        case_reports,
    }
}

#[cfg(test)]
mod tests {
    use super::{
        RecallEvaluationAssertions, RecallEvaluationCase, evaluate_rankings_at_k,
        evaluate_recall_results,
    };
    use crate::{
        MemoryQualityState, MemoryRecord, MemoryRecordKind, MemoryScope, MemoryTrustLevel,
        RecallExplanation, RecallFilters, RecallHit, RecallQuery, RecallResult,
        RecallScoreBreakdown,
    };
    use std::collections::BTreeMap;

    #[test]
    fn computes_standard_ranking_metrics() {
        let ranked_a = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let relevant_a = vec!["b".to_string(), "d".to_string()];
        let ranked_b = vec!["x".to_string(), "y".to_string(), "z".to_string()];
        let relevant_b = vec!["x".to_string()];

        let metrics =
            evaluate_rankings_at_k(&[(&ranked_a, &relevant_a), (&ranked_b, &relevant_b)], 3);

        assert_eq!(metrics.cases, 2);
        assert!(metrics.hit_rate_at_k > 0.9);
        assert!(metrics.recall_at_k > 0.7);
        assert!(metrics.mrr > 0.6);
        assert!(metrics.ndcg_at_k > 0.6);
    }

    fn scope() -> MemoryScope {
        MemoryScope {
            tenant_id: "tenant-a".to_string(),
            namespace: "ops".to_string(),
            actor_id: "ava".to_string(),
            conversation_id: None,
            session_id: None,
            source: "test".to_string(),
            labels: vec!["eval".to_string()],
            trust_level: MemoryTrustLevel::Verified,
        }
    }

    fn hit(record_id: &str) -> RecallHit {
        RecallHit {
            record: MemoryRecord {
                id: record_id.to_string(),
                scope: scope(),
                kind: MemoryRecordKind::Fact,
                content: format!("record {record_id}"),
                summary: None,
                source_id: None,
                metadata: BTreeMap::new(),
                quality_state: MemoryQualityState::Active,
                created_at_unix_ms: 1,
                updated_at_unix_ms: 1,
                expires_at_unix_ms: None,
                importance_score: 0.5,
                artifact: None,
                episode: None,
                historical_state: Default::default(),
                lineage: Vec::new(),
                conflict: None,
            },
            breakdown: RecallScoreBreakdown {
                lexical: 1.0,
                semantic: 0.0,
                graph: 0.0,
                temporal: 0.0,
                metadata: 0.0,
                episodic: 0.0,
                salience: 0.0,
                curation: 0.0,
                policy: 0.0,
                total: 1.0,
            },
            explanation: None,
        }
    }

    #[test]
    fn evaluates_judged_recall_cases_with_assertions() {
        let query = RecallQuery {
            scope: scope(),
            query_text: "release".to_string(),
            max_items: 3,
            token_budget: None,
            filters: RecallFilters::default(),
            include_explanation: true,
        };
        let cases = vec![RecallEvaluationCase {
            name: "release memory".to_string(),
            query,
            assertions: RecallEvaluationAssertions {
                expected_record_ids: vec!["a".to_string()],
                optional_record_ids: vec!["b".to_string()],
                disallowed_record_ids: vec!["x".to_string()],
                required_explanation_notes: vec!["policy=general".to_string()],
            },
        }];
        let results = vec![RecallResult {
            hits: vec![hit("a"), hit("b")],
            total_candidates_examined: 2,
            explanation: Some(RecallExplanation {
                selected_channels: vec!["lexical".to_string()],
                policy_notes: vec!["recall_policy=general".to_string()],
                trace_id: None,
                planning_trace: None,
                planning_profile: None,
                policy_profile: None,
                scorer_kind: None,
                scoring_profile: None,
            }),
        }];

        let report = evaluate_recall_results(&cases, &results, 3);

        assert_eq!(report.cases, 1);
        assert_eq!(report.passed_cases, 1);
        assert_eq!(report.failed_cases, 0);
        assert!(report.pass_rate > 0.99);
        assert!(report.ranking_metrics.hit_rate_at_k > 0.99);
        assert!(report.case_reports[0].passed);
    }
}