Skip to main content

ctx_semantic/
lib.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
/// Errors reported while selecting an embedding backend or running ONNX inference.
#[derive(Debug, Error)]
pub enum SemanticError {
    /// The ONNX model could not be located; `path` is interpolated into the message.
    #[error(
        "ONNX model not found at {path}; set [semantic].model to a local .onnx file or enable allow_fallback"
    )]
    OnnxModelNotFound { path: String },
    /// The configured vocab file could not be located; `path` is interpolated into the message.
    #[error(
        "ONNX vocab not found at {path}; set [semantic].vocab to a local vocab.txt/tokenizer file or enable allow_fallback"
    )]
    OnnxVocabNotFound { path: String },
    /// The ONNX backend was requested in a build compiled without the `onnx` feature.
    #[error(
        "ONNX backend requested but ctx-semantic was built without the `onnx` feature; rebuild with `cargo build --features ctx-semantic/onnx` or enable semantic.allow_fallback"
    )]
    OnnxFeatureDisabled,
    /// Loading or running the ONNX model failed; the payload is the underlying error text.
    #[error("ONNX inference failed: {0}")]
    OnnxInference(String),
}
24
/// Per-candidate feature values that [`score`] combines into a single number.
///
/// The field notes below describe how `rank_chunks` populates them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Features {
    /// Cosine similarity between the query embedding and the candidate embedding.
    pub semantic_similarity: f64,
    /// Jaccard token overlap between the query and the candidate's hint + text.
    pub keyword_overlap: f64,
    /// Candidate recency, clamped to `[0, 1]`.
    pub recency: f64,
    /// `1 / (1 + distance)` for the candidate's graph distance.
    pub graph_distance_bonus: f64,
    /// Candidate failure relevance, clamped to `[0, 1]`.
    pub failure_bonus: f64,
}
33
/// Embedding backend requested by configuration (serialized in `snake_case`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SemanticBackendKind {
    /// Built-in hashed bag-of-tokens embedding; needs no model file.
    LocalHash,
    /// Embedding computed by an ONNX model loaded from disk.
    Onnx,
}
40
41impl SemanticBackendKind {
42    pub fn parse(value: &str) -> Option<Self> {
43        match value.trim().to_lowercase().as_str() {
44            "local" | "local_hash" | "hash" => Some(Self::LocalHash),
45            "onnx" | "onnx_runtime" => Some(Self::Onnx),
46            _ => None,
47        }
48    }
49}
50
/// Minimal ranking options accepted by `rank_chunks_hybrid`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingConfig {
    /// Embedding backend to use.
    pub backend: SemanticBackendKind,
    /// Upper bound on the number of returned chunks (`rank_chunks` treats 0 as 1).
    pub max_chunks: usize,
    /// Whether low-scoring chunks are dropped relative to the best score.
    pub adaptive_threshold: bool,
}
57
/// Full configuration for `rank_chunks`, including ONNX file locations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticEngineConfig {
    /// Embedding backend to use.
    pub backend: SemanticBackendKind,
    /// Location of the ONNX model file; checked for existence when the ONNX backend is resolved.
    pub model_path: Option<PathBuf>,
    /// Optional vocab file; when set it must exist for the ONNX backend to resolve.
    pub vocab_path: Option<PathBuf>,
    /// Upper bound on the number of returned chunks (`rank_chunks` treats 0 as 1).
    pub max_chunks: usize,
    /// Whether low-scoring chunks are dropped relative to the best score.
    pub adaptive_threshold: bool,
    /// When the ONNX backend cannot be resolved, use the local hash backend instead of failing.
    pub allow_fallback: bool,
}
67
68impl SemanticEngineConfig {
69    pub fn local_hash(max_chunks: usize, adaptive_threshold: bool) -> Self {
70        Self {
71            backend: SemanticBackendKind::LocalHash,
72            model_path: None,
73            vocab_path: None,
74            max_chunks,
75            adaptive_threshold,
76            allow_fallback: true,
77        }
78    }
79
80    fn from_ranking_config(config: RankingConfig) -> Self {
81        Self {
82            backend: config.backend,
83            model_path: None,
84            vocab_path: None,
85            max_chunks: config.max_chunks,
86            adaptive_threshold: config.adaptive_threshold,
87            allow_fallback: true,
88        }
89    }
90}
91
/// A chunk of text offered to `rank_chunks` together with its ranking signals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkCandidate {
    /// Caller-defined identifier, copied into the resulting `RankedChunk`.
    pub id: String,
    /// Chunk text; embedded for semantic similarity and used for de-duplication.
    pub text: String,
    /// Extra keywords; when non-blank they are prepended to `text` for keyword overlap.
    pub keyword_hint: String,
    /// Recency signal; `rank_chunks` clamps it to `[0, 1]`.
    pub recency: f64,
    /// Graph distance; negative values are treated as 0 and mapped to `1 / (1 + d)`.
    pub graph_distance: f64,
    /// Failure relevance signal; `rank_chunks` clamps it to `[0, 1]`.
    pub failure_relevance: f64,
}
101
/// A scored candidate as returned by `rank_chunks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedChunk {
    /// Identifier copied from the originating candidate.
    pub id: String,
    /// Weighted combination of `features` as computed by `score`.
    pub score: f64,
    /// The individual feature values behind `score`.
    pub features: Features,
    /// Diagnostic string listing the backend and each feature value.
    pub reason: String,
    /// Text copied from the originating candidate.
    pub text: String,
}
110
/// Descriptive data stored alongside a cached embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct EmbeddingMetadata {
    /// Identifier of the model that produced the vector.
    pub model_id: String,
    /// `stable_text_hash` of the embedded text.
    pub text_hash: u64,
    /// Length of the stored vector.
    pub dimensions: usize,
}
117
/// One cached embedding: the vector plus the metadata describing it.
#[derive(Debug, Clone)]
struct CacheEntry {
    metadata: EmbeddingMetadata,
    vector: Vec<f32>,
}
123
/// Bounded in-memory cache of embedding vectors keyed by model id and text hash.
#[derive(Debug, Clone)]
pub struct EmbeddingCache {
    // Maximum number of entries kept; `new` raises 0 to 1.
    capacity: usize,
    // Keys in first-insertion order; the front is evicted first.
    order: VecDeque<String>,
    // Cached entries indexed by `cache_key(model_id, text_hash)`.
    entries: HashMap<String, CacheEntry>,
}
130
131impl EmbeddingCache {
132    pub fn new(capacity: usize) -> Self {
133        Self {
134            capacity: capacity.max(1),
135            order: VecDeque::new(),
136            entries: HashMap::new(),
137        }
138    }
139
140    pub fn put(&mut self, model_id: &str, text: &str, vector: Vec<f32>) -> EmbeddingMetadata {
141        let metadata = EmbeddingMetadata {
142            model_id: model_id.to_string(),
143            text_hash: stable_text_hash(text),
144            dimensions: vector.len(),
145        };
146        let key = cache_key(&metadata.model_id, metadata.text_hash);
147
148        if !self.entries.contains_key(&key) {
149            self.order.push_back(key.clone());
150        }
151        self.entries.insert(
152            key.clone(),
153            CacheEntry {
154                metadata: metadata.clone(),
155                vector,
156            },
157        );
158        self.evict_if_needed();
159        metadata
160    }
161
162    pub fn get(&self, model_id: &str, text: &str) -> Option<&[f32]> {
163        let key = cache_key(model_id, stable_text_hash(text));
164        self.entries.get(&key).map(|entry| entry.vector.as_slice())
165    }
166
167    pub fn metadata(&self, model_id: &str, text: &str) -> Option<&EmbeddingMetadata> {
168        let key = cache_key(model_id, stable_text_hash(text));
169        self.entries.get(&key).map(|entry| &entry.metadata)
170    }
171
172    fn evict_if_needed(&mut self) {
173        while self.entries.len() > self.capacity {
174            if let Some(oldest) = self.order.pop_front() {
175                self.entries.remove(&oldest);
176            } else {
177                break;
178            }
179        }
180    }
181}
182
183pub fn score(features: Features) -> f64 {
184    0.40 * features.semantic_similarity
185        + 0.20 * features.keyword_overlap
186        + 0.15 * features.recency
187        + 0.15 * features.graph_distance_bonus
188        + 0.10 * features.failure_bonus
189}
190
191pub fn rank_chunks_hybrid(
192    query: &str,
193    candidates: &[ChunkCandidate],
194    config: RankingConfig,
195) -> Vec<RankedChunk> {
196    rank_chunks(
197        query,
198        candidates,
199        SemanticEngineConfig::from_ranking_config(config),
200    )
201    .unwrap_or_default()
202}
203
204pub fn rank_chunks(
205    query: &str,
206    candidates: &[ChunkCandidate],
207    config: SemanticEngineConfig,
208) -> Result<Vec<RankedChunk>, SemanticError> {
209    if candidates.is_empty() {
210        return Ok(Vec::new());
211    }
212
213    let backend = resolve_backend(&config)?;
214    let mut cache = EmbeddingCache::new(candidates.len() + 1);
215    let query_embedding = embed_text(query, &backend, &mut cache)?;
216    let mut seen_fingerprint = HashSet::new();
217    let mut ranked = Vec::new();
218
219    for candidate in candidates {
220        let fingerprint = normalize_text(&candidate.text);
221        if !seen_fingerprint.insert(fingerprint) {
222            continue;
223        }
224
225        let candidate_embedding = embed_text(&candidate.text, &backend, &mut cache)?;
226        let semantic_similarity = cosine_dense(&query_embedding, &candidate_embedding);
227        let keyword_overlap = keyword_similarity(query, &candidate.keyword_hint, &candidate.text);
228        let features = Features {
229            semantic_similarity,
230            keyword_overlap,
231            recency: candidate.recency.clamp(0.0, 1.0),
232            graph_distance_bonus: graph_distance_bonus(candidate.graph_distance),
233            failure_bonus: candidate.failure_relevance.clamp(0.0, 1.0),
234        };
235
236        let total_score = score(features);
237        ranked.push(RankedChunk {
238            id: candidate.id.clone(),
239            score: total_score,
240            features,
241            reason: format_reason(&backend, features),
242            text: candidate.text.clone(),
243        });
244    }
245
246    ranked.sort_by(|a, b| {
247        b.score
248            .partial_cmp(&a.score)
249            .unwrap_or(std::cmp::Ordering::Equal)
250    });
251
252    let thresholded = if config.adaptive_threshold && !ranked.is_empty() {
253        apply_adaptive_threshold(ranked, config.max_chunks.max(1))
254    } else {
255        ranked
256    };
257
258    let mut final_ranked = thresholded;
259    final_ranked.truncate(config.max_chunks.max(1));
260    Ok(final_ranked)
261}
262
/// Backend actually used for embedding after configuration and fallback are applied.
#[derive(Debug, Clone)]
// NOTE(review): presumably needed because the `Onnx` variant is only constructed
// when the `onnx` feature is enabled — confirm before removing.
#[allow(dead_code)]
enum ResolvedBackend {
    /// Built-in hashed embedding.
    LocalHash {
        /// Identifier used in cache keys.
        model_id: String,
        /// Name of the backend that was requested but replaced, if any.
        fallback_from: Option<&'static str>,
    },
    /// ONNX model loaded from `model_path`.
    Onnx {
        /// Identifier used in cache keys.
        model_id: String,
        /// Location of the ONNX model file.
        model_path: PathBuf,
        /// Optional vocab file used to map tokens to ids.
        vocab_path: Option<PathBuf>,
    },
}
276
277fn resolve_backend(config: &SemanticEngineConfig) -> Result<ResolvedBackend, SemanticError> {
278    match config.backend {
279        SemanticBackendKind::LocalHash => Ok(ResolvedBackend::LocalHash {
280            model_id: "local_hash:v1".to_string(),
281            fallback_from: None,
282        }),
283        SemanticBackendKind::Onnx => resolve_onnx_backend(config),
284    }
285}
286
287fn resolve_onnx_backend(config: &SemanticEngineConfig) -> Result<ResolvedBackend, SemanticError> {
288    let Some(model_path) = config.model_path.clone() else {
289        return fallback_or_error(
290            config.allow_fallback,
291            SemanticError::OnnxModelNotFound {
292                path: "semantic.model".to_string(),
293            },
294        );
295    };
296
297    if !model_path.exists() {
298        return fallback_or_error(
299            config.allow_fallback,
300            SemanticError::OnnxModelNotFound {
301                path: format!("{} (semantic.model)", model_path.display()),
302            },
303        );
304    }
305
306    if let Some(vocab_path) = &config.vocab_path {
307        if !vocab_path.exists() {
308            return fallback_or_error(
309                config.allow_fallback,
310                SemanticError::OnnxVocabNotFound {
311                    path: format!("{} (semantic.vocab)", vocab_path.display()),
312                },
313            );
314        }
315    }
316
317    #[cfg(not(feature = "onnx"))]
318    {
319        fallback_or_error(config.allow_fallback, SemanticError::OnnxFeatureDisabled)
320    }
321
322    #[cfg(feature = "onnx")]
323    {
324        Ok(ResolvedBackend::Onnx {
325            model_id: format!("onnx:{}", model_path.display()),
326            model_path,
327            vocab_path: config.vocab_path.clone(),
328        })
329    }
330}
331
332fn fallback_or_error(
333    allow_fallback: bool,
334    error: SemanticError,
335) -> Result<ResolvedBackend, SemanticError> {
336    if allow_fallback {
337        Ok(ResolvedBackend::LocalHash {
338            model_id: "local_hash:v1".to_string(),
339            fallback_from: Some("onnx"),
340        })
341    } else {
342        Err(error)
343    }
344}
345
346fn embed_text(
347    text: &str,
348    backend: &ResolvedBackend,
349    cache: &mut EmbeddingCache,
350) -> Result<Vec<f32>, SemanticError> {
351    let model_id = backend.model_id();
352    if let Some(cached) = cache.get(model_id, text) {
353        return Ok(cached.to_vec());
354    }
355
356    let vector = match backend {
357        ResolvedBackend::LocalHash { .. } => local_hash_embedding(text),
358        ResolvedBackend::Onnx {
359            model_path,
360            vocab_path,
361            ..
362        } => onnx_embedding(model_path, vocab_path.as_ref(), text)?,
363    };
364    cache.put(model_id, text, vector.clone());
365    Ok(vector)
366}
367
368impl ResolvedBackend {
369    fn model_id(&self) -> &str {
370        match self {
371            Self::LocalHash { model_id, .. } | Self::Onnx { model_id, .. } => model_id,
372        }
373    }
374
375    fn label(&self) -> &'static str {
376        match self {
377            Self::LocalHash { .. } => "local_hash",
378            Self::Onnx { .. } => "onnx",
379        }
380    }
381
382    fn fallback_from(&self) -> Option<&'static str> {
383        match self {
384            Self::LocalHash { fallback_from, .. } => *fallback_from,
385            Self::Onnx { .. } => None,
386        }
387    }
388}
389
390fn format_reason(backend: &ResolvedBackend, features: Features) -> String {
391    let mut reason = format!(
392        "backend={} semantic={:.3} keyword={:.3} recency={:.3} graph={:.3} failure={:.3}",
393        backend.label(),
394        features.semantic_similarity,
395        features.keyword_overlap,
396        features.recency,
397        features.graph_distance_bonus,
398        features.failure_bonus
399    );
400    if let Some(source) = backend.fallback_from() {
401        reason.push_str(&format!(" fallback_from={source}"));
402    }
403    reason
404}
405
406fn apply_adaptive_threshold(ranked: Vec<RankedChunk>, max_chunks: usize) -> Vec<RankedChunk> {
407    let top = ranked[0].score;
408    let threshold = (top * 0.35).max(0.15);
409    let mut kept = ranked
410        .iter()
411        .filter(|entry| entry.score >= threshold)
412        .cloned()
413        .collect::<Vec<_>>();
414
415    if kept.len() < 2 && ranked.len() >= 2 && max_chunks >= 2 {
416        kept = ranked.iter().take(2).cloned().collect();
417    }
418
419    kept
420}
421
422fn local_hash_embedding(text: &str) -> Vec<f32> {
423    const DIMS: usize = 256;
424    let mut vector = vec![0.0f32; DIMS];
425    for token in tokenize(text) {
426        let hash = stable_text_hash(&token);
427        let idx = (hash as usize) % DIMS;
428        vector[idx] += 1.0;
429
430        let chars = token.chars().collect::<Vec<_>>();
431        for window in chars.windows(3) {
432            let gram = window.iter().collect::<String>();
433            let gram_idx = (stable_text_hash(&gram) as usize) % DIMS;
434            vector[gram_idx] += 0.25;
435        }
436    }
437    normalize_dense(vector)
438}
439
/// Embeds `text` with the ONNX model at `model_path` (built with the `onnx` feature).
///
/// Token ids come from the vocab file when one is given and from hashing
/// otherwise. The model must declare one to three inputs; they are fed, in
/// order, the token ids, an all-ones attention mask, and all-zero token type
/// ids, each shaped `[1, seq_len]`. A rank-3 first output is averaged over its
/// second axis (using batch index 0); any other rank is flattened. The vector
/// is scaled to unit length before being returned.
///
/// Note that the model is loaded and optimized on every call.
///
/// # Errors
///
/// Any failure while reading the vocab, loading, preparing or running the
/// model is reported as `SemanticError::OnnxInference`.
#[cfg(feature = "onnx")]
fn onnx_embedding(
    model_path: &std::path::Path,
    vocab_path: Option<&PathBuf>,
    text: &str,
) -> Result<Vec<f32>, SemanticError> {
    use tract_onnx::prelude::*;

    /// Wraps any displayable error as an inference error.
    fn inference_error<E: std::fmt::Display>(err: E) -> SemanticError {
        SemanticError::OnnxInference(err.to_string())
    }

    let tokens = match vocab_path {
        Some(vocab_path) => wordpiece_token_ids(vocab_path, text)?,
        None => hashed_token_ids(text),
    };
    let seq_len = tokens.len().max(1);
    let attention = vec![1_i64; seq_len];
    let token_types = vec![0_i64; seq_len];

    let model = tract_onnx::onnx()
        .model_for_path(model_path)
        .map_err(inference_error)?
        .into_optimized()
        .map_err(inference_error)?;
    let input_count = model.input_outlets().map_err(inference_error)?.len();
    if !(1..=3).contains(&input_count) {
        return Err(SemanticError::OnnxInference(format!(
            "expected ONNX text embedding model with 1-3 inputs, got {input_count}"
        )));
    }

    let model = model.into_runnable().map_err(inference_error)?;

    // Feed only as many of the three standard inputs as the model declares.
    let mut inputs = TVec::new();
    for values in [&tokens, &attention, &token_types]
        .into_iter()
        .take(input_count)
    {
        let tensor = Tensor::from_shape(&[1, seq_len], values).map_err(inference_error)?;
        inputs.push(tensor.into());
    }

    let outputs = model.run(inputs).map_err(inference_error)?;
    let first = match outputs.first() {
        Some(output) => output,
        None => {
            return Err(SemanticError::OnnxInference(
                "model returned no outputs".to_string(),
            ));
        }
    };
    let view = first.to_array_view::<f32>().map_err(inference_error)?;
    let shape = view.shape();

    let vector = if shape.len() == 3 {
        // Average over the second axis, reading batch index 0.
        let (steps, dim) = (shape[1], shape[2]);
        let mut pooled = vec![0.0f32; dim];
        for token_idx in 0..steps {
            for (dim_idx, slot) in pooled.iter_mut().enumerate() {
                *slot += view[[0, token_idx, dim_idx]];
            }
        }
        let divisor = steps.max(1) as f32;
        for slot in &mut pooled {
            *slot /= divisor;
        }
        pooled
    } else {
        view.iter().copied().collect::<Vec<_>>()
    };

    Ok(normalize_dense(vector))
}
519
520#[cfg(not(feature = "onnx"))]
521fn onnx_embedding(
522    _model_path: &std::path::Path,
523    _vocab_path: Option<&PathBuf>,
524    _text: &str,
525) -> Result<Vec<f32>, SemanticError> {
526    Err(SemanticError::OnnxFeatureDisabled)
527}
528
/// Maps `text` to token ids using the vocab file at `vocab_path`.
///
/// Each line of the file is a token whose id is its zero-based line number
/// (for duplicate lines the later one wins). The output is `[CLS]`, then up
/// to 254 token ids, then `[SEP]`. Tokens are looked up whole — no sub-word
/// splitting is performed — and unknown tokens map to `[UNK]`. When the
/// special tokens are missing from the vocab, ids 101 (`[CLS]`), 102 (`[SEP]`)
/// and 100 (`[UNK]`) are used.
///
/// # Errors
///
/// Fails with `SemanticError::OnnxInference` when the vocab file cannot be read.
#[cfg(feature = "onnx")]
fn wordpiece_token_ids(vocab_path: &PathBuf, text: &str) -> Result<Vec<i64>, SemanticError> {
    let vocab = match std::fs::read_to_string(vocab_path) {
        Ok(contents) => contents,
        Err(err) => return Err(SemanticError::OnnxInference(err.to_string())),
    };

    let ids: HashMap<String, i64> = vocab
        .lines()
        .enumerate()
        .map(|(idx, token)| (token.trim().to_string(), idx as i64))
        .collect();
    let id_or = |token: &str, fallback: i64| ids.get(token).copied().unwrap_or(fallback);

    let unk = id_or("[UNK]", 100);
    let mut out = Vec::new();
    out.push(id_or("[CLS]", 101));
    out.extend(
        tokenize(text)
            .into_iter()
            .take(254)
            .map(|token| id_or(&token, unk)),
    );
    out.push(id_or("[SEP]", 102));
    Ok(out)
}
548
/// Maps `text` to token ids without a vocab file.
///
/// The output is 101, then up to 254 ids derived from each token's hash
/// (`hash % 30_000 + 1_000`), then 102.
#[cfg(feature = "onnx")]
fn hashed_token_ids(text: &str) -> Vec<i64> {
    let hashed = tokenize(text)
        .into_iter()
        .take(254)
        .map(|token| ((stable_text_hash(&token) % 30_000) + 1_000) as i64);

    std::iter::once(101_i64)
        .chain(hashed)
        .chain(std::iter::once(102_i64))
        .collect()
}
561
562fn keyword_similarity(query: &str, hint: &str, text: &str) -> f64 {
563    let hinted = if hint.trim().is_empty() {
564        text.to_string()
565    } else {
566        format!("{hint} {text}")
567    };
568    jaccard_similarity(query, &hinted)
569}
570
/// Cosine similarity of two dense vectors, clamped to `[0, 1]`.
///
/// Only the leading components common to both slices are compared. Returns
/// 0.0 when either slice is empty or either compared prefix has zero norm.
fn cosine_dense(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // `zip` stops at the shorter slice, i.e. after `min(a.len(), b.len())` pairs.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&left, &right)| {
            let va = left as f64;
            let vb = right as f64;
            (dot + va * vb, norm_a + va * va, norm_b + vb * vb)
        },
    );

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    (dot / (norm_a.sqrt() * norm_b.sqrt())).clamp(0.0, 1.0)
}
592
/// Scales `vector` to unit Euclidean length.
///
/// The norm is accumulated in `f64`. A vector whose norm is not greater than
/// zero (for example all zeros or empty) is returned unchanged.
fn normalize_dense(mut vector: Vec<f32>) -> Vec<f32> {
    let squared_sum: f64 = vector
        .iter()
        .map(|&component| {
            let wide = f64::from(component);
            wide * wide
        })
        .sum();
    let norm = squared_sum.sqrt();

    if norm > 0.0 {
        vector
            .iter_mut()
            .for_each(|component| *component = (f64::from(*component) / norm) as f32);
    }
    vector
}
606
607fn jaccard_similarity(a: &str, b: &str) -> f64 {
608    let sa = tokenize(a).into_iter().collect::<HashSet<_>>();
609    let sb = tokenize(b).into_iter().collect::<HashSet<_>>();
610    if sa.is_empty() || sb.is_empty() {
611        return 0.0;
612    }
613
614    let inter = sa.intersection(&sb).count() as f64;
615    let union = sa.union(&sb).count() as f64;
616    (inter / union).clamp(0.0, 1.0)
617}
618
/// Converts a graph distance into a bonus in `[0, 1]` via `1 / (1 + distance)`.
///
/// Negative (and NaN) distances are treated as 0, which yields the maximum bonus of 1.0.
fn graph_distance_bonus(distance: f64) -> f64 {
    let non_negative = if distance > 0.0 { distance } else { 0.0 };
    let bonus = 1.0 / (1.0 + non_negative);
    bonus.clamp(0.0, 1.0)
}
623
/// Lower-cases `text` and collapses every run of whitespace into a single space.
///
/// Leading and trailing whitespace is removed.
fn normalize_text(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(&word.to_lowercase());
    }
    normalized
}
630
/// Splits `text` into lower-cased tokens.
///
/// A token is a maximal run of alphanumeric characters and underscores.
/// Pieces whose UTF-8 length is one byte or less are dropped, so single ASCII
/// characters are discarded while a single multi-byte character is kept.
fn tokenize(text: &str) -> Vec<String> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if piece.len() > 1 {
            tokens.push(piece.to_lowercase());
        }
    }
    tokens
}
637
/// Builds the cache key `<model_id>:<hash>`, with the hash rendered as 16
/// zero-padded lower-case hex digits.
fn cache_key(model_id: &str, text_hash: u64) -> String {
    let key = format!("{model_id}:{text_hash:016x}");
    key
}
641
642fn stable_text_hash(text: &str) -> u64 {
643    fxhash64(normalize_text(text).as_bytes())
644}
645
/// Deterministic 64-bit hash of `bytes`.
///
/// Despite the name, the constants and the xor-then-multiply step are those
/// of 64-bit FNV-1a.
fn fxhash64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    bytes
        .iter()
        .fold(OFFSET_BASIS, |state, &byte| {
            (state ^ u64::from(byte)).wrapping_mul(PRIME)
        })
}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a feature set in which only the semantic component is non-zero.
    fn semantic_only(semantic_similarity: f64) -> Features {
        Features {
            semantic_similarity,
            keyword_overlap: 0.0,
            recency: 0.0,
            graph_distance_bonus: 0.0,
            failure_bonus: 0.0,
        }
    }

    #[test]
    fn score_is_monotonic_for_semantic_similarity() {
        let low = score(semantic_only(0.1));
        let high = score(semantic_only(0.9));

        assert!(high > low);
    }
}