talon-core 0.4.2

Core retrieval engine for Talon: hybrid search (BM25 + semantic + reranker), indexing, and graph-aware ranking over markdown corpora.
Documentation
use std::collections::HashSet;
use std::time::{Duration, Instant};

use crate::config::TalonConfig;
use crate::expansion::{ExpansionClient, ExpansionError};
use crate::text::{estimate_tokens, is_fence_line, nfd};

mod phrases;

use phrases::{WeightedPhrase, clean_phrase, extract_weighted_phrases, strip_code_blocks};

// Tokens reserved for the distiller's own prompt scaffolding; subtracted from the expansion
// model's context window when sizing the distiller input.
const DISTILLER_OVERHEAD_TOKENS: usize = 512;
// Additional headroom subtracted from the expansion context window (a safety margin —
// presumably to absorb token-estimation error).
const DISTILLER_SAFETY_MARGIN_TOKENS: usize = 2048;
// Query-embedding budget used when no config is given or its context size is unusable (0).
const DEFAULT_QUERY_EMBEDDING_CONTEXT_TOKENS: usize = 512;
// Upper bound on the number of deduplicated queries in one plan.
const MAX_QUERY_SET_SIZE: usize = 6;
// Token cap applied to the main query (distilled or locally compacted).
const MAX_MAIN_QUERY_TOKENS: usize = 96;
// Maximum number of phrase hints taken from the extracted phrases.
const MAX_HINTS: usize = 16;
// Distillation is skipped unless strictly more than this remains before the deadline.
const DISTILLATION_MIN_REMAINING: Duration = Duration::from_millis(3_000);

/// Retrieval plan produced by [`plan_recall_queries`] for one recall prompt, together with
/// diagnostics describing how it was built.
#[derive(Debug, Clone)]
pub(super) struct RecallQueryPlan {
    /// Primary search query; equal to `queries[0]`, or empty when `queries` is empty.
    pub(super) main_query: String,
    /// Cleaned, case-insensitively deduplicated queries, capped at [`MAX_QUERY_SET_SIZE`].
    pub(super) queries: Vec<String>,
    /// Estimated token count of the raw prompt.
    pub(super) input_tokens: usize,
    /// Estimated token count of `main_query`.
    pub(super) query_tokens: usize,
    /// Number of weighted phrases extracted locally from the raw prompt (phrases returned by
    /// the distiller are not counted).
    pub(super) phrase_count: usize,
    /// Estimated token count of the prompt view sent to the distiller; `None` when no
    /// distillation call was made.
    pub(super) distillation_input_tokens: Option<usize>,
    /// Whether a distillation call to the expansion client was attempted.
    pub(super) distillation_ran: bool,
    /// Elapsed milliseconds recorded for distillation; `Some` only when `distillation_ran`.
    pub(super) distillation_ms: Option<u64>,
    /// Whether the distiller returned a response body that was used for planning.
    pub(super) distillation_succeeded: bool,
    /// Why distillation was wanted but no distilled output was used, e.g.
    /// `"deadline-too-close"`, `"client-unavailable"`, `"empty-or-malformed-response"`,
    /// `"timeout"`, `"http-<status>"`, `"transport-error"` or `"client-build-error"`.
    /// `None` when distillation was not needed or succeeded.
    pub(super) distillation_fallback_reason: Option<String>,
}

/// Builds the retrieval query set for one recall prompt.
///
/// 1. Weighted phrases are always extracted from `raw_query`.
/// 2. Distillation through `expansion` is wanted when the prompt's estimated size exceeds
///    the query-embedding budget or [`noisy_prompt`] flags it. It is attempted only if the
///    deadline (when given) leaves more than [`DISTILLATION_MIN_REMAINING`] and a client is
///    available; otherwise the reason is recorded in `distillation_fallback_reason`.
/// 3. A non-blank distilled `search_query` becomes the main query, capped at
///    [`MAX_MAIN_QUERY_TOKENS`]; otherwise the main query is compacted locally from the raw
///    prompt and its phrases.
/// 4. Phrase queries (local phrases plus any distilled phrases and identifiers) are
///    appended, and the whole set is cleaned, deduplicated and capped.
///
/// Distillation errors never propagate: they are classified into a fallback reason and the
/// locally built plan is returned instead.
pub(super) fn plan_recall_queries(
    raw_query: &str,
    expansion: Option<&ExpansionClient>,
    config: Option<&TalonConfig>,
    deadline_at: Option<Instant>,
) -> RecallQueryPlan {
    let phrases = extract_weighted_phrases(raw_query);
    let embedding_budget = query_embedding_budget(config);
    let query_tokens = estimate_tokens(raw_query);
    let should_distill = query_tokens > embedding_budget || noisy_prompt(raw_query);

    let mut distillation_ran = false;
    let mut distillation_succeeded = false;
    let mut distillation_fallback_reason = None;
    let mut distillation_input_tokens = None;
    let mut distillation_ms = None;
    let distilled = if !should_distill {
        None
    } else if !has_time_for_distillation(deadline_at) {
        distillation_fallback_reason = Some("deadline-too-close".to_owned());
        None
    } else if let Some(client) = expansion {
        // Start the clock here so `distillation_ms` covers only the distillation step
        // (prompt view + client call), not the phrase extraction and noise detection above,
        // which run for every prompt.
        let started = Instant::now();
        distillation_ran = true;
        let view = budgeted_prompt_view(raw_query, config);
        distillation_input_tokens = Some(estimate_tokens(&view));
        let hints = phrase_hints(&phrases);
        let outcome = client.distill_recall_prompt(&view, &hints);
        distillation_ms = Some(u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX));
        match outcome {
            Ok(Some(body)) => {
                distillation_succeeded = true;
                Some(body)
            }
            Ok(None) => {
                distillation_fallback_reason = Some("empty-or-malformed-response".to_owned());
                None
            }
            Err(err) => {
                distillation_fallback_reason = Some(classify_distillation_error(&err));
                None
            }
        }
    } else {
        distillation_fallback_reason = Some("client-unavailable".to_owned());
        None
    };

    // Prefer a non-blank distilled query; otherwise fall back to local compaction.
    let main_query = distilled
        .as_ref()
        .map(|body| body.search_query.as_str())
        .filter(|query| !query.trim().is_empty())
        .map_or_else(
            || compact_main_query(raw_query, &phrases, embedding_budget),
            |query| cap_tokens(query, MAX_MAIN_QUERY_TOKENS),
        );

    let mut phrase_texts: Vec<String> = phrases.iter().map(|phrase| phrase.text.clone()).collect();
    if let Some(body) = distilled {
        phrase_texts.extend(body.phrases);
        phrase_texts.extend(body.identifiers);
    }

    let mut queries = vec![main_query];
    queries.extend(build_phrase_queries(&phrase_texts));
    let mut plan = dedupe_queries(queries);
    plan.input_tokens = query_tokens;
    plan.query_tokens = estimate_tokens(&plan.main_query);
    plan.phrase_count = phrases.len();
    plan.distillation_input_tokens = distillation_input_tokens;
    plan.distillation_ran = distillation_ran;
    plan.distillation_ms = distillation_ms;
    plan.distillation_succeeded = distillation_succeeded;
    plan.distillation_fallback_reason = distillation_fallback_reason;
    plan
}

fn classify_distillation_error(error: &ExpansionError) -> String {
    match error {
        ExpansionError::Http {
            timed_out: true, ..
        } => "timeout".to_owned(),
        ExpansionError::Http {
            status: Some(status),
            ..
        } => format!("http-{status}"),
        ExpansionError::Http { .. } => "transport-error".to_owned(),
        ExpansionError::Build { .. } => "client-build-error".to_owned(),
    }
}

/// Returns `true` when distillation may still be attempted.
///
/// That is the case when there is no deadline at all, or when strictly more than
/// [`DISTILLATION_MIN_REMAINING`] is left before it. A deadline already in the past counts
/// as zero time remaining.
fn has_time_for_distillation(deadline_at: Option<Instant>) -> bool {
    match deadline_at {
        None => true,
        Some(deadline) => {
            let remaining = deadline.saturating_duration_since(Instant::now());
            remaining > DISTILLATION_MIN_REMAINING
        }
    }
}

/// Token budget for a query that is embedded directly.
///
/// Uses the configured `embedding.context_tokens` when a config is present and the value is
/// positive and representable as `usize`. Otherwise it falls back to
/// [`DEFAULT_QUERY_EMBEDDING_CONTEXT_TOKENS`].
fn query_embedding_budget(config: Option<&TalonConfig>) -> usize {
    let Some(cfg) = config else {
        return DEFAULT_QUERY_EMBEDDING_CONTEXT_TOKENS;
    };
    match usize::try_from(cfg.embedding.context_tokens) {
        Ok(tokens) if tokens > 0 => tokens,
        _ => DEFAULT_QUERY_EMBEDDING_CONTEXT_TOKENS,
    }
}

/// Token budget for the prompt view sent to the distiller.
///
/// Without a config this is a fixed 4,000 tokens. Otherwise it is derived from the
/// configured expansion context window and output limit.
fn expansion_input_budget(config: Option<&TalonConfig>) -> usize {
    config.map_or(4_000, |cfg| {
        let expansion = &cfg.chat.expansion;
        expansion_input_budget_for_limits(expansion.context_tokens, expansion.max_output_tokens)
    })
}

/// Computes how many prompt tokens the distiller may receive for a given model window.
///
/// The budget is `context_tokens` minus the following reserves:
/// - the expected output (`max_output_tokens`, or 768 when unset);
/// - [`DISTILLER_OVERHEAD_TOKENS`];
/// - [`DISTILLER_SAFETY_MARGIN_TOKENS`].
///
/// The result is floored at 256 tokens, so a tiny window still yields a usable view. For
/// very small windows the floor can therefore exceed the headroom actually left.
///
/// All arithmetic saturates. On 32-bit targets, `u32::MAX` plus the reserves does not fit
/// in `usize`; without saturation that would panic in debug builds or wrap to an absurd
/// budget in release builds.
fn expansion_input_budget_for_limits(context_tokens: u32, max_output_tokens: Option<u32>) -> usize {
    const DEFAULT_OUTPUT_TOKENS: usize = 768;
    const MIN_INPUT_TOKENS: usize = 256;

    let context = usize::try_from(context_tokens).unwrap_or(usize::MAX);
    let output = max_output_tokens
        .and_then(|tokens| usize::try_from(tokens).ok())
        .unwrap_or(DEFAULT_OUTPUT_TOKENS);
    let reserved = output
        .saturating_add(DISTILLER_OVERHEAD_TOKENS)
        .saturating_add(DISTILLER_SAFETY_MARGIN_TOKENS);
    context.saturating_sub(reserved).max(MIN_INPUT_TOKENS)
}

/// Heuristic for prompts worth distilling even when they fit the embedding budget.
///
/// A prompt counts as noisy when any of these holds:
/// - it has more than 80 lines;
/// - at least two lines are fence lines (per `is_fence_line` on the trimmed line);
/// - it contains a triple backtick anywhere;
/// - it contains the literal text `TRACE`.
fn noisy_prompt(query: &str) -> bool {
    let (line_count, fence_count) =
        query
            .lines()
            .fold((0_usize, 0_usize), |(lines, fences), line| {
                let is_fence = is_fence_line(line.trim());
                (lines + 1, fences + usize::from(is_fence))
            });
    let has_markers = query.contains("```") || query.contains("TRACE");
    line_count > 80 || fence_count >= 2 || has_markers
}

/// Produces the text sent to the distiller.
///
/// Code blocks are stripped from the prompt, and the result is fitted to the expansion
/// input budget derived from `config`.
fn budgeted_prompt_view(query: &str, config: Option<&TalonConfig>) -> String {
    budgeted_prompt_view_for_budget(&strip_code_blocks(query), expansion_input_budget(config))
}

/// Fits already-stripped prompt text into `budget` estimated tokens.
///
/// Text within budget is returned unchanged. Otherwise the head and tail are kept around an
/// elision marker, sized to `budget - 16` tokens (presumably leaving slack for the marker).
/// If the estimate is still over budget, only the leading characters are kept via
/// [`cap_tokens`].
fn budgeted_prompt_view_for_budget(stripped: &str, budget: usize) -> String {
    let fits = |text: &str| estimate_tokens(text) <= budget;
    if fits(stripped) {
        return stripped.to_owned();
    }
    let head_and_tail = cap_tokens_head_tail(stripped, budget.saturating_sub(16));
    if fits(&head_and_tail) {
        head_and_tail
    } else {
        cap_tokens(&head_and_tail, budget)
    }
}

/// Local (non-LLM) fallback for building the main query.
///
/// The raw prompt is returned unchanged when it fits `budget`. Otherwise the first eight
/// phrase hints are joined with spaces and capped at [`MAX_MAIN_QUERY_TOKENS`]. If that
/// joined text is empty, the tail of the raw prompt is kept instead, within
/// `min(budget, MAX_MAIN_QUERY_TOKENS)` tokens.
fn compact_main_query(raw_query: &str, phrases: &[WeightedPhrase], budget: usize) -> String {
    if estimate_tokens(raw_query) <= budget {
        return raw_query.to_owned();
    }
    let hints = phrase_hints(phrases);
    let joined_hints = hints[..hints.len().min(8)].join(" ");
    if joined_hints.is_empty() {
        cap_tokens_tail(raw_query, budget.min(MAX_MAIN_QUERY_TOKENS))
    } else {
        cap_tokens(&joined_hints, MAX_MAIN_QUERY_TOKENS)
    }
}

/// Texts of the first [`MAX_HINTS`] phrases, in their original order.
///
/// These serve as distiller hints and as the source of the local main-query fallback.
fn phrase_hints(phrases: &[WeightedPhrase]) -> Vec<String> {
    let limit = phrases.len().min(MAX_HINTS);
    phrases[..limit]
        .iter()
        .map(|weighted| weighted.text.clone())
        .collect()
}

/// Turns phrase texts into supplementary search queries.
///
/// Each phrase is cleaned; empty results are dropped, and the rest are split by
/// [`looks_literal`]. Non-literal phrases are grouped four per query, up to three queries.
/// Literal-looking phrases are joined into one final query of at most eight phrases.
/// Literal-looking means containing `/`, `#`, `_` or `::`, or mixing letter case.
fn build_phrase_queries(phrases: &[String]) -> Vec<String> {
    let (literals, semantic_phrases): (Vec<_>, Vec<_>) = phrases
        .iter()
        .map(|phrase| clean_phrase(phrase))
        .filter(|cleaned| !cleaned.is_empty())
        .partition(|cleaned| looks_literal(cleaned));

    let mut queries: Vec<String> = semantic_phrases
        .chunks(4)
        .take(3)
        .map(|group| group.join(" "))
        .collect();
    if !literals.is_empty() {
        let literal_query = literals[..literals.len().min(8)].join(" ");
        queries.push(literal_query);
    }
    queries
}

/// Cleans, deduplicates and caps candidate queries, wrapping them in a fresh plan.
///
/// - Empty queries (after cleaning) are skipped.
/// - Duplicates are detected by comparing NFD-normalized, lowercased text; the first
///   spelling wins.
/// - At most [`MAX_QUERY_SET_SIZE`] queries are kept.
///
/// The first surviving query becomes `main_query`, which is empty when none survive. All
/// other plan fields are left at neutral defaults for the caller to fill in.
fn dedupe_queries(queries: Vec<String>) -> RecallQueryPlan {
    let mut seen_keys = HashSet::new();
    let mut kept: Vec<String> = Vec::with_capacity(MAX_QUERY_SET_SIZE);
    for candidate in queries {
        let cleaned = clean_phrase(&candidate);
        if cleaned.is_empty() {
            continue;
        }
        if !seen_keys.insert(nfd::normalize(&cleaned).to_lowercase()) {
            continue;
        }
        kept.push(cleaned);
        if kept.len() >= MAX_QUERY_SET_SIZE {
            break;
        }
    }
    let main_query = kept.first().map_or_else(String::new, Clone::clone);
    RecallQueryPlan {
        main_query,
        queries: kept,
        input_tokens: 0,
        query_tokens: 0,
        phrase_count: 0,
        distillation_input_tokens: None,
        distillation_ran: false,
        distillation_ms: None,
        distillation_succeeded: false,
        distillation_fallback_reason: None,
    }
}

/// Heuristic: does this phrase look like a literal token rather than natural language?
///
/// Returns `true` when the text contains `/`, `#`, `_` or `::`, or contains both upper- and
/// lower-case letters (e.g. `HashMap`). Ordinary sentence-case text such as `Hello world`
/// also counts as mixed case.
fn looks_literal(value: &str) -> bool {
    const MARKER_CHARS: [char; 3] = ['/', '#', '_'];
    if value.contains(MARKER_CHARS) || value.contains("::") {
        return true;
    }
    let has_upper = value.chars().any(char::is_uppercase);
    has_upper && value.chars().any(char::is_lowercase)
}

/// Caps `input` to roughly `budget` tokens, keeping the beginning.
///
/// Input whose estimate fits the budget is returned trimmed. Otherwise only the leading
/// `budget * TOKEN_CHAR_RATIO` characters (at least one) are kept, then trimmed.
fn cap_tokens(input: &str, budget: usize) -> String {
    let capped = if estimate_tokens(input) > budget {
        let char_limit = budget
            .saturating_mul(usize::from(crate::text::TOKEN_CHAR_RATIO))
            .max(1);
        match input.char_indices().nth(char_limit) {
            Some((cut, _)) => &input[..cut],
            None => input,
        }
    } else {
        input
    };
    capped.trim().to_owned()
}

/// Caps `input` to roughly `budget` tokens, keeping the end.
///
/// Input whose estimate fits the budget is returned trimmed. Otherwise only the trailing
/// `budget * TOKEN_CHAR_RATIO` characters (at least one) are kept, then trimmed.
fn cap_tokens_tail(input: &str, budget: usize) -> String {
    if estimate_tokens(input) <= budget {
        return input.trim().to_owned();
    }
    let char_limit = budget
        .saturating_mul(usize::from(crate::text::TOKEN_CHAR_RATIO))
        .max(1);
    // Byte offset where the last `char_limit` characters begin (0 if the input is shorter).
    let start = input
        .char_indices()
        .rev()
        .nth(char_limit - 1)
        .map_or(0, |(offset, _)| offset);
    input[start..].trim().to_owned()
}

/// Shrinks `input` to roughly `budget` tokens by keeping its beginning and end around an
/// elision marker (`\n\n[...]\n\n`).
///
/// The character allowance is `budget * TOKEN_CHAR_RATIO`, with a minimum of one. The head
/// gets half of it, rounded down, and the tail gets the rest.
///
/// Input that already fits within the allowance is returned unchanged. This avoids
/// duplicating overlapping text or inserting a marker when nothing is actually elided.
/// The marker itself is not counted against `budget`.
fn cap_tokens_head_tail(input: &str, budget: usize) -> String {
    let max_chars = budget
        .saturating_mul(usize::from(crate::text::TOKEN_CHAR_RATIO))
        .max(1);
    let total_chars = input.chars().count();
    if total_chars <= max_chars {
        // Head and tail would overlap or meet exactly: nothing to elide.
        return input.to_owned();
    }
    let head_chars = max_chars / 2;
    let tail_chars = max_chars - head_chars;
    let head: String = input.chars().take(head_chars).collect();
    // `total_chars > max_chars >= tail_chars`, so this cannot underflow.
    let tail: String = input.chars().skip(total_chars - tail_chars).collect();
    format!("{head}\n\n[...]\n\n{tail}")
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A prompt far over the embedding budget, planned without an expansion client, must
    /// still yield a non-empty query set whose main query fits the main-query cap.
    #[test]
    fn plan_recall_queries_compacts_large_prompt() {
        let oversized = "context overflow ".repeat(800);
        let plan = plan_recall_queries(&oversized, None, None, None);
        let main_tokens = estimate_tokens(&plan.main_query);
        assert!(
            main_tokens <= MAX_MAIN_QUERY_TOKENS,
            "main query used {main_tokens} tokens"
        );
        assert!(!plan.queries.is_empty());
    }

    /// The distiller input budget is the context window minus output, overhead and margin.
    #[test]
    fn expansion_input_budget_respects_configured_context() {
        let cases = [((16_000, Some(768)), 12_672), ((32_768, None), 29_440)];
        for ((context, output), expected) in cases {
            assert_eq!(expansion_input_budget_for_limits(context, output), expected);
        }
    }

    /// Even a very large prompt view must be cut down to the configured budget.
    #[test]
    fn budgeted_prompt_view_stays_under_configured_budget() {
        let budget = expansion_input_budget_for_limits(16_000, Some(768));
        let huge_prompt = "supplier order hot sauce launch readiness ".repeat(20_000);
        let view = budgeted_prompt_view_for_budget(&huge_prompt, budget);
        assert!(estimate_tokens(&view) <= budget);
    }
}