lash-core 0.1.0-alpha.37

Sans-IO turn machine and runtime kernel for the lash agent runtime.
Documentation
use std::collections::{HashMap, HashSet};

use regex::Regex;

/// Limit returned by `limit_from_args` when `"limit"` is missing or unusable.
const DEFAULT_LIMIT: usize = 10;
/// Largest limit `limit_from_args` will return; larger requests are clamped to it.
const MAX_LIMIT: usize = 100;

/// Matching strategy selected for a search request.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// Token-based matching with relevance scoring; the default mode.
    Hybrid,
    /// ASCII-case-insensitive substring matching of the whole query.
    Literal,
    /// Regular-expression matching.
    Regex,
}

impl SearchMode {
    /// Parses a mode name, ignoring surrounding whitespace and ASCII case.
    ///
    /// `"literal"` and `"regex"` select their respective modes. A missing
    /// value, an empty string, or any unrecognised name yields
    /// [`SearchMode::Hybrid`]; this function never fails.
    pub fn parse(value: Option<&str>) -> Self {
        let Some(raw) = value else {
            return Self::Hybrid;
        };
        let name = raw.trim();
        if name.eq_ignore_ascii_case("literal") {
            Self::Literal
        } else if name.eq_ignore_ascii_case("regex") {
            Self::Regex
        } else {
            Self::Hybrid
        }
    }
}

/// A searchable document represented as a set of named text fields.
///
/// The ranking helpers in this file look fields up by the names given in
/// their `field_weights` argument; a field that is absent is treated like an
/// empty one.
#[derive(Clone)]
pub struct SearchDoc {
    /// Field name mapped to that field's text.
    pub fields: HashMap<&'static str, String>,
}

/// Reads the result limit from the `"limit"` entry of `args`.
///
/// Returns `DEFAULT_LIMIT` when the entry is absent, is not representable as
/// an `i64`, or is negative. Any other value is clamped into
/// `1..=MAX_LIMIT`, so a requested limit of `0` becomes `1`.
pub fn limit_from_args(args: &serde_json::Value) -> usize {
    let Some(requested) = args.get("limit").and_then(|value| value.as_i64()) else {
        return DEFAULT_LIMIT;
    };
    match usize::try_from(requested) {
        Ok(count) => count.clamp(1, MAX_LIMIT),
        // Negative (or otherwise unrepresentable) requests fall back to the default.
        Err(_) => DEFAULT_LIMIT,
    }
}

/// Compiles `pattern` into a case-insensitive regex.
///
/// Returns `None` when the pattern is missing or empty. A pattern that is not
/// valid regex syntax is not an error: it is escaped and compiled as literal
/// text instead, so `None` is only returned for it if that also fails.
fn compile_regex(pattern: Option<&str>) -> Option<Regex> {
    let source = pattern.filter(|candidate| !candidate.is_empty())?;
    if let Ok(compiled) = Regex::new(&format!("(?i){source}")) {
        return Some(compiled);
    }
    // Invalid syntax: fall back to matching the raw text literally.
    let literal = regex::escape(source);
    Regex::new(&format!("(?i){literal}")).ok()
}

/// Splits `text` into lowercase tokens.
///
/// A token is a maximal run of ASCII letters, ASCII digits and underscores.
/// Every other character (including all non-ASCII characters) acts as a
/// separator and is dropped. Tokens are returned in order of appearance.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    for ch in text.chars() {
        if ch.is_ascii_alphanumeric() || ch == '_' {
            current.push(ch.to_ascii_lowercase());
        } else if !current.is_empty() {
            // Separator reached: emit the token collected so far.
            tokens.push(std::mem::take(&mut current));
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}

/// Decides whether a query token loosely matches a candidate token.
///
/// Equal tokens always match. Otherwise the tokens match when both are at
/// least three bytes long and either one contains the other as a substring.
fn hybrid_token_match(query_token: &str, candidate_token: &str) -> bool {
    if query_token == candidate_token {
        return true;
    }
    // Substring matching is only trusted for tokens of three bytes or more.
    if query_token.len() < 3 || candidate_token.len() < 3 {
        return false;
    }
    candidate_token.contains(query_token) || query_token.contains(candidate_token)
}

/// Reports whether any query token loosely matches any token of `text`.
///
/// `text` is tokenized with `tokenize` and tokens are compared with
/// `hybrid_token_match`. An empty `query_tokens` slice never matches.
fn hybrid_text_match(query_tokens: &[String], text: &str) -> bool {
    if query_tokens.is_empty() {
        return false;
    }
    let candidates = tokenize(text);
    for wanted in query_tokens {
        for candidate in &candidates {
            if hybrid_token_match(wanted, candidate) {
                return true;
            }
        }
    }
    false
}

/// Scores `doc` by loose token matching.
///
/// For every field in `field_weights` with a weight above zero, counts the
/// query tokens (repeats counted individually) that loosely match at least
/// one token of the field text, and adds `count * weight * 0.25` to the
/// total. Fields that are missing, empty, or produce no tokens contribute
/// nothing. Returns `0.0` when nothing matches.
fn hybrid_fallback_score(
    query_tokens: &[String],
    doc: &SearchDoc,
    field_weights: &[(&'static str, f64)],
) -> f64 {
    let mut total = 0.0_f64;
    for &(field, weight) in field_weights {
        if weight <= 0.0 {
            continue;
        }
        let Some(text) = doc.fields.get(field) else {
            continue;
        };
        // An empty or token-free field yields no candidates and thus no matches.
        let candidates = tokenize(text);
        let matched = query_tokens
            .iter()
            .filter(|wanted| {
                candidates
                    .iter()
                    .any(|candidate| hybrid_token_match(wanted, candidate))
            })
            .count();
        if matched > 0 {
            total += matched as f64 * weight * 0.25;
        }
    }
    total
}

/// Computes a field-weighted BM25-style score for every document.
///
/// The returned vector has one entry per element of `docs`, in the same
/// order; it is empty when `docs` is empty. Only exact token equality counts
/// here (tokens come from `tokenize`), so a document sharing no token with
/// the query scores `0.0`.
///
/// * `query_tokens` - tokenized query; repeated tokens raise that token's
///   contribution by a factor of `1 + ln(count)`.
/// * `docs` - the corpus used both for scoring and for the document-frequency
///   and average-length statistics.
/// * `field_weights` - `(field name, weight)` pairs; a field's token counts
///   and length are multiplied by its weight, and fields with a weight of
///   zero or less are ignored.
fn bm25_scores(
    query_tokens: &[String],
    docs: &[SearchDoc],
    field_weights: &[(&'static str, f64)],
) -> Vec<f64> {
    let n_docs = docs.len();
    if n_docs == 0 {
        return Vec::new();
    }

    // Per-document statistics, index-aligned with `docs`.
    let mut doc_tfs: Vec<HashMap<String, f64>> = Vec::with_capacity(n_docs);
    let mut doc_lens: Vec<f64> = Vec::with_capacity(n_docs);
    // Number of documents in which each token appears at least once.
    let mut doc_freq: HashMap<String, usize> = HashMap::new();

    for doc in docs {
        // Weighted term frequencies and weighted length for this document.
        let mut tf: HashMap<String, f64> = HashMap::new();
        let mut dlen = 0.0_f64;
        for (field, weight) in field_weights {
            if *weight <= 0.0 {
                continue;
            }
            // A missing field is treated as empty text.
            let text = doc
                .fields
                .get(field)
                .map(String::as_str)
                .unwrap_or_default();
            let tokens = tokenize(text);
            if tokens.is_empty() {
                continue;
            }
            // Raw occurrence counts within this single field.
            let mut counts: HashMap<String, usize> = HashMap::new();
            for tok in &tokens {
                *counts.entry(tok.clone()).or_insert(0) += 1;
            }
            // Fold the field's counts into the document totals, scaled by the field weight.
            for (tok, count) in counts {
                *tf.entry(tok).or_insert(0.0) += (count as f64) * *weight;
            }
            dlen += (tokens.len() as f64) * *weight;
        }
        // Each distinct token counts once per document towards document frequency.
        for tok in tf.keys() {
            *doc_freq.entry(tok.clone()).or_insert(0) += 1;
        }
        doc_tfs.push(tf);
        doc_lens.push(dlen);
    }

    // Average weighted document length; falls back to 1.0 so the length
    // normalisation below never divides by zero.
    let avgdl = {
        let sum: f64 = doc_lens.iter().sum();
        let avg = sum / (n_docs as f64);
        if avg <= 0.0 { 1.0 } else { avg }
    };

    // Occurrence count of each distinct query token.
    let mut qtf: HashMap<String, usize> = HashMap::new();
    for tok in query_tokens {
        *qtf.entry(tok.clone()).or_insert(0) += 1;
    }

    // BM25 parameters: `k1` appears in the term-frequency saturation term,
    // `b` blends in the document-length ratio.
    let k1 = 1.5_f64;
    let b = 0.75_f64;
    let mut scores = vec![0.0_f64; n_docs];
    for (i, tf) in doc_tfs.iter().enumerate() {
        let dl = doc_lens[i];
        // Length normalisation: 1.0 for an average-length document.
        let norm = 1.0 - b + b * (dl / avgdl);
        // NOTE(review): `qtf` is a `HashMap`, so the order in which terms are
        // added to `scores[i]` is unspecified; for multi-token queries the
        // floating-point sum may differ in its last bits between runs.
        for (tok, qcount) in &qtf {
            let freq = *tf.get(tok).unwrap_or(&0.0);
            if freq <= 0.0 {
                continue;
            }
            let df = *doc_freq.get(tok).unwrap_or(&0) as f64;
            // The `+ 1.0` inside the logarithm keeps the idf strictly positive.
            let idf = ((n_docs as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
            let denom = freq + k1 * norm;
            if denom <= 0.0 {
                continue;
            }
            let term = idf * ((freq * (k1 + 1.0)) / denom);
            // Repeated query tokens boost the term logarithmically (factor 1.0 for a single use).
            scores[i] += term * (1.0 + (*qcount as f64).ln());
        }
    }

    scores
}

/// Lists the names of the fields in `fields` that match the search.
///
/// Empty field values never match. What counts as a match depends on `mode`:
///
/// * `Regex` - `regex_filter` is present and matches the field value.
/// * `Literal` - `query` is non-empty and occurs in the value, ignoring
///   ASCII case.
/// * `Hybrid` - if the query produces tokens, any query token loosely matches
///   any token of the value; if the query is empty, every non-empty field
///   matches; otherwise (a query without tokens, e.g. punctuation only) the
///   query must occur in the value, ignoring ASCII case.
///
/// In `Literal` and `Hybrid` mode a present `regex_filter` is an additional
/// requirement: the field value must match it as well.
///
/// The returned names are sorted so the result does not depend on the
/// iteration order of the `HashMap`, which is unspecified and can differ
/// between otherwise identical maps.
fn field_hits(
    fields: &HashMap<&'static str, String>,
    query: &str,
    mode: SearchMode,
    regex_filter: Option<&Regex>,
) -> Vec<String> {
    let query_lower = query.to_ascii_lowercase();
    let query_tokens: HashSet<String> = tokenize(query).into_iter().collect();
    let regex_matches = |value: &str| regex_filter.is_some_and(|re| re.is_match(value));
    // The lowercase copy is built only by the branches that need it.
    let literal_matches = |value: &str| value.to_ascii_lowercase().contains(&query_lower);

    let mut hits = Vec::new();
    for (field, value) in fields {
        if value.is_empty() {
            continue;
        }
        let value = value.as_str();
        let matched = match mode {
            SearchMode::Regex => regex_matches(value),
            SearchMode::Literal => !query.is_empty() && literal_matches(value),
            SearchMode::Hybrid if !query_tokens.is_empty() => {
                let candidates = tokenize(value);
                query_tokens.iter().any(|wanted| {
                    candidates
                        .iter()
                        .any(|candidate| hybrid_token_match(wanted, candidate))
                })
            }
            SearchMode::Hybrid => query.is_empty() || literal_matches(value),
        };
        // Outside regex mode the regex, when present, acts as a second filter.
        if matched
            && (mode == SearchMode::Regex || regex_filter.is_none() || regex_matches(value))
        {
            hits.push((*field).to_string());
        }
    }
    // Field names are unique map keys, so an unstable sort is sufficient.
    hits.sort_unstable();
    hits
}

/// Filters `docs` against a query and returns the matches, ranked in hybrid mode.
///
/// Each returned tuple is `(index into docs, score, matched field names)`.
/// No truncation to a result limit happens here.
///
/// * `query` - the search text. In `Regex` mode it is used as the pattern
///   when `regex` is `None`.
/// * `mode` - matching strategy. Only `Hybrid` produces non-zero scores and
///   reorders results (descending score, ties broken by ascending index);
///   `Literal` and `Regex` keep the original document order with a score of
///   `0.0`.
/// * `regex` - optional pattern, compiled by `compile_regex`. In `Regex` mode
///   it is the pattern to match; in the other modes it is an additional
///   filter every result must also satisfy.
/// * `field_weights` - `(field name, weight)` pairs used for hybrid scoring.
///   Matching itself looks at every field of the document, weighted or not.
///
/// A document is returned only if it passes the whole-document check and at
/// least one individual field is reported by `field_hits`; the one exception
/// is an empty query outside `Regex` mode, where an empty hit list is allowed.
pub fn rank_docs(
    docs: &[SearchDoc],
    query: &str,
    mode: SearchMode,
    regex: Option<&str>,
    field_weights: &[(&'static str, f64)],
) -> Vec<(usize, f64, Vec<String>)> {
    let query_tokens = tokenize(query);
    let query_lower = query.to_ascii_lowercase();
    // Scores stay at 0.0 unless hybrid scoring below replaces them.
    let mut scores = vec![0.0_f64; docs.len()];
    if mode == SearchMode::Hybrid && !query_tokens.is_empty() {
        scores = bm25_scores(&query_tokens, docs, field_weights);
        // Documents without an exact token match get a chance via loose token matching.
        for (idx, score) in scores.iter_mut().enumerate() {
            if *score <= 0.0 {
                *score = hybrid_fallback_score(&query_tokens, &docs[idx], field_weights);
            }
        }
    }
    // In regex mode the query doubles as the pattern when no explicit regex is given.
    let regex_filter = match mode {
        SearchMode::Regex => compile_regex(regex.or(Some(query))),
        _ => compile_regex(regex),
    };

    let mut indices: Vec<usize> = (0..docs.len()).collect();
    if mode == SearchMode::Hybrid {
        // Highest score first; equal scores keep ascending index order.
        // NOTE(review): a NaN score makes `partial_cmp` return `None`, which is
        // treated as equal here - confirm weights can never produce NaN.
        indices.sort_by(|a, b| {
            scores[*b]
                .partial_cmp(&scores[*a])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.cmp(b))
        });
    }

    let mut ranked = Vec::new();
    for idx in indices {
        // All non-empty field values joined by newlines, for whole-document checks.
        // NOTE(review): values come from a `HashMap`, so their order inside the
        // haystack is unspecified; a pattern spanning a field boundary may match
        // inconsistently.
        let haystack = docs[idx]
            .fields
            .values()
            .filter(|value| !value.is_empty())
            .cloned()
            .collect::<Vec<_>>()
            .join("\n");
        let haystack_lower = haystack.to_ascii_lowercase();

        // Whole-document check for the selected mode.
        let mut include = match mode {
            SearchMode::Regex => regex_filter
                .as_ref()
                .is_some_and(|re| re.is_match(&haystack)),
            SearchMode::Literal => !query.is_empty() && haystack_lower.contains(&query_lower),
            SearchMode::Hybrid => {
                if query.is_empty() {
                    // An empty hybrid query matches every document.
                    true
                } else if !query_tokens.is_empty() {
                    scores[idx] > 0.0
                        || haystack_lower.contains(&query_lower)
                        || hybrid_text_match(&query_tokens, &haystack)
                } else {
                    // Query without any tokens: plain substring search.
                    haystack_lower.contains(&query_lower)
                }
            }
        };
        // Outside regex mode a supplied regex further narrows the results.
        if include && regex_filter.is_some() && mode != SearchMode::Regex {
            include = regex_filter
                .as_ref()
                .is_some_and(|re| re.is_match(&haystack));
        }
        if !include {
            continue;
        }

        // Per-field check: drop documents where no single field matches, except
        // for an empty query outside regex mode.
        let hits = field_hits(&docs[idx].fields, query, mode, regex_filter.as_ref());
        if hits.is_empty() && !(query.is_empty() && mode != SearchMode::Regex) {
            continue;
        }
        ranked.push((idx, scores[idx], hits));
    }
    ranked
}