roboticus-agent 0.11.4

Agent core with ReAct loop, policy engine, injection defense, memory system, and skill loader
Documentation
//! Unified ranking types for memory retrieval and tool search convergence.
//!
//! Both systems produce `Vec<RankedCandidate>`, enabling shared top-K pruning
//! and budget-aware truncation logic.

use serde::{Deserialize, Serialize};

/// Origin of a [`RankedCandidate`]: which of the two retrieval systems
/// (memory retrieval or tool search) produced it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CandidateKind {
    /// Produced by memory retrieval.
    Memory,
    /// Produced by tool search.
    Tool,
}

/// One scored item, from either memory retrieval or tool search, competing for
/// a slot in a token-budgeted selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedCandidate {
    /// Identifier of the underlying item; this is the value compared against
    /// the `always_include` list by `top_k_with_pinned`.
    pub source_id: String,
    /// Which system produced this candidate.
    pub source_kind: CandidateKind,
    /// Score before adjustment (presumably the producing system's native
    /// relevance score — confirm with callers).
    pub raw_score: f64,
    /// Score used for ranking; the selection helpers in this module expect
    /// their input sorted by this value, highest first.
    pub adjusted_score: f64,
    /// Tokens this item consumes when selected; summed against the budget.
    pub token_cost: usize,
}

/// Select the top K candidates that fit within a token budget.
///
/// Candidates are assumed to be pre-sorted by `adjusted_score` descending; this
/// function does not sort them. It walks the slice in order and returns the
/// longest prefix that satisfies both limits:
///
/// * at most `k` items are returned, and
/// * the cumulative `token_cost` of the returned items does not exceed `budget`.
///
/// Selection stops at the first candidate that would push the total over the
/// budget, even if a later, cheaper candidate would still fit. The result is
/// therefore always a prefix of `candidates`, so rank order is never violated.
///
/// `k` may be arbitrarily large (e.g. `usize::MAX` for "no count limit"), and
/// extreme `token_cost` values are treated as over budget rather than
/// overflowing the running total.
pub fn top_k_within_budget(
    candidates: &[RankedCandidate],
    k: usize,
    budget: usize,
) -> Vec<RankedCandidate> {
    // `k` is only an upper bound; reserving it verbatim would panic with
    // "capacity overflow" for very large values such as `usize::MAX`.
    let mut result = Vec::with_capacity(k.min(candidates.len()));
    let mut spent = 0usize;
    for c in candidates.iter().take(k) {
        // An overflowing total is necessarily larger than any `budget`, so it
        // ends selection exactly like an ordinary over-budget item would.
        match spent.checked_add(c.token_cost) {
            Some(total) if total <= budget => spent = total,
            _ => break,
        }
        result.push(c.clone());
    }
    result
}

/// Like [`top_k_within_budget`], but guarantees that pinned items are always
/// included regardless of their score.
///
/// A candidate is pinned when its `source_id` appears in `always_include`.
/// Every pinned candidate present in `candidates` is returned first, in input
/// order, even if the pinned items alone exceed `k` or `budget`; their token
/// cost is charged against the budget before anything else.
///
/// The remaining slots are then filled from the non-pinned candidates the same
/// way `top_k_within_budget` does. They are taken in input order (assumed
/// sorted by `adjusted_score` descending), stopping at the first item that
/// would exceed either `k` total items or the budget.
///
/// Ids in `always_include` that match no candidate are simply ignored. The
/// output is *not* globally sorted by score: pinned items precede the rest.
pub fn top_k_with_pinned(
    candidates: &[RankedCandidate],
    k: usize,
    budget: usize,
    always_include: &[&str],
) -> Vec<RankedCandidate> {
    // Split by reference so only the items actually returned get cloned.
    let (pinned, rest): (Vec<&RankedCandidate>, Vec<&RankedCandidate>) = candidates
        .iter()
        .partition(|c| always_include.contains(&c.source_id.as_str()));

    // Bounded by `candidates.len()`; never reserves `k` verbatim, which would
    // panic with "capacity overflow" for very large `k`.
    let mut result = Vec::with_capacity(pinned.len() + rest.len().min(k));
    let mut spent = 0usize;

    // Pinned items are unconditional; saturate so huge costs cannot overflow.
    for c in pinned {
        spent = spent.saturating_add(c.token_cost);
        result.push(c.clone());
    }

    for c in rest {
        if result.len() >= k {
            break;
        }
        // An overflowing total is necessarily over any budget.
        match spent.checked_add(c.token_cost) {
            Some(total) if total <= budget => spent = total,
            _ => break,
        }
        result.push(c.clone());
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate whose raw and adjusted scores are equal.
    fn cand(id: &str, kind: CandidateKind, score: f64, token_cost: usize) -> RankedCandidate {
        RankedCandidate {
            source_id: id.into(),
            source_kind: kind,
            raw_score: score,
            adjusted_score: score,
            token_cost,
        }
    }

    /// Extracts the ids of a selection, preserving order.
    fn ids(v: &[RankedCandidate]) -> Vec<&str> {
        v.iter().map(|c| c.source_id.as_str()).collect()
    }

    #[test]
    fn ranked_candidate_ordering() {
        // `b` has the lower raw score but the higher adjusted score; ranking
        // must follow `adjusted_score`.
        let a = RankedCandidate {
            source_id: "tool_a".into(),
            source_kind: CandidateKind::Tool,
            raw_score: 0.9,
            adjusted_score: 0.85,
            token_cost: 50,
        };
        let b = RankedCandidate {
            source_id: "mem_b".into(),
            source_kind: CandidateKind::Memory,
            raw_score: 0.7,
            adjusted_score: 0.95,
            token_cost: 100,
        };
        let mut v = [a, b];
        v.sort_by(|x, y| y.adjusted_score.total_cmp(&x.adjusted_score));
        assert_eq!(v[0].source_id, "mem_b");
        assert_eq!(v[0].source_kind, CandidateKind::Memory);
    }

    #[test]
    fn top_k_prunes_to_budget() {
        let candidates: Vec<RankedCandidate> = (0..20)
            .map(|i| {
                cand(
                    &format!("item_{i}"),
                    CandidateKind::Tool,
                    1.0 - (i as f64 * 0.05),
                    50,
                )
            })
            .collect();
        let pruned = top_k_within_budget(&candidates, 10, 300);
        // 300 / 50 = 6 items fit, so the budget (not k = 10) is the binding limit.
        assert_eq!(pruned.len(), 6);
        assert_eq!(pruned.iter().map(|c| c.token_cost).sum::<usize>(), 300);
        // The highest-ranked prefix is kept, in order.
        assert_eq!(
            ids(&pruned),
            ["item_0", "item_1", "item_2", "item_3", "item_4", "item_5"]
        );
    }

    #[test]
    fn top_k_respects_k() {
        let candidates = vec![
            cand("a", CandidateKind::Tool, 0.9, 10),
            cand("b", CandidateKind::Tool, 0.8, 10),
            cand("c", CandidateKind::Tool, 0.7, 10),
        ];
        assert_eq!(ids(&top_k_within_budget(&candidates, 2, 1_000)), ["a", "b"]);
    }

    #[test]
    fn top_k_stops_at_first_over_budget_item() {
        // "b" does not fit, and selection stops there even though "c" would fit.
        let candidates = vec![
            cand("a", CandidateKind::Memory, 0.9, 10),
            cand("b", CandidateKind::Memory, 0.8, 100),
            cand("c", CandidateKind::Memory, 0.7, 10),
        ];
        assert_eq!(ids(&top_k_within_budget(&candidates, 10, 50)), ["a"]);
    }

    #[test]
    fn empty_input_yields_empty_selection() {
        assert!(top_k_within_budget(&[], 5, 100).is_empty());
        assert!(top_k_with_pinned(&[], 5, 100, &["memory_store"]).is_empty());
    }

    #[test]
    fn always_include_list_preserved() {
        let candidates = vec![
            cand("memory_store", CandidateKind::Tool, 0.1, 30),
            cand("random_tool", CandidateKind::Tool, 0.9, 30),
        ];
        let always_include = &["memory_store"];
        let pruned = top_k_with_pinned(&candidates, 1, 100, always_include);
        // The pinned item takes the single slot despite its lower score.
        assert_eq!(ids(&pruned), ["memory_store"]);
    }

    #[test]
    fn pinned_items_come_first_and_consume_budget() {
        let candidates = vec![
            cand("top", CandidateKind::Tool, 0.9, 40),
            cand("pin", CandidateKind::Tool, 0.1, 40),
            cand("next", CandidateKind::Tool, 0.8, 40),
        ];
        // "pin" spends 40 of 80 first, leaving room only for "top".
        let pruned = top_k_with_pinned(&candidates, 5, 80, &["pin"]);
        assert_eq!(ids(&pruned), ["pin", "top"]);
    }

    #[test]
    fn pinned_items_included_even_over_budget() {
        let candidates = vec![
            cand("big", CandidateKind::Memory, 0.5, 500),
            cand("small", CandidateKind::Memory, 0.9, 1),
        ];
        let pruned = top_k_with_pinned(&candidates, 5, 100, &["big"]);
        assert_eq!(ids(&pruned), ["big"]);
    }
}