1use std::collections::{HashMap, HashSet, VecDeque};
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
/// Errors reported while resolving a semantic backend or computing embeddings.
#[derive(Debug, Error)]
pub enum SemanticError {
    /// The ONNX backend was requested but no model file could be used.
    ///
    /// `path` is the text inserted into the message: either the missing file's
    /// path or the name of the configuration key that was left unset.
    #[error(
        "ONNX model not found at {path}; set [semantic].model to a local .onnx file or enable allow_fallback"
    )]
    OnnxModelNotFound { path: String },
    /// A vocabulary path was configured for the ONNX backend but does not exist.
    #[error(
        "ONNX vocab not found at {path}; set [semantic].vocab to a local vocab.txt/tokenizer file or enable allow_fallback"
    )]
    OnnxVocabNotFound { path: String },
    /// The ONNX backend was requested in a build without the `onnx` cargo feature.
    #[error(
        "ONNX backend requested but ctx-semantic was built without the `onnx` feature; rebuild with `cargo build --features ctx-semantic/onnx` or enable semantic.allow_fallback"
    )]
    OnnxFeatureDisabled,
    /// Loading the model, reading the vocabulary, or running inference failed;
    /// the payload is the underlying error rendered as text.
    #[error("ONNX inference failed: {0}")]
    OnnxInference(String),
}
24
/// The individual signals that [`score`] combines into a single ranking score.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct Features {
    /// Cosine similarity between the query embedding and the chunk embedding.
    pub semantic_similarity: f64,
    /// Jaccard overlap between query tokens and the chunk's hint + text tokens.
    pub keyword_overlap: f64,
    /// Recency signal of the chunk (the ranking code clamps it to `[0, 1]`).
    pub recency: f64,
    /// Bonus derived from the chunk's graph distance; closer chunks score higher.
    pub graph_distance_bonus: f64,
    /// Failure-relevance signal of the chunk (the ranking code clamps it to `[0, 1]`).
    pub failure_bonus: f64,
}
33
/// Selects which embedding backend the ranking engine should use.
///
/// Serialized with snake_case variant names (`local_hash`, `onnx`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SemanticBackendKind {
    /// Built-in hashing-based embedding that needs no model files.
    LocalHash,
    /// Embedding computed by an ONNX model loaded from disk.
    Onnx,
}
40
41impl SemanticBackendKind {
42 pub fn parse(value: &str) -> Option<Self> {
43 match value.trim().to_lowercase().as_str() {
44 "local" | "local_hash" | "hash" => Some(Self::LocalHash),
45 "onnx" | "onnx_runtime" => Some(Self::Onnx),
46 _ => None,
47 }
48 }
49}
50
/// Minimal ranking configuration accepted by [`rank_chunks_hybrid`].
///
/// It carries no model or vocabulary paths; those are left unset when it is
/// converted into a [`SemanticEngineConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingConfig {
    /// Requested embedding backend.
    pub backend: SemanticBackendKind,
    /// Maximum number of chunks to return (values below 1 are treated as 1).
    pub max_chunks: usize,
    /// Whether low-scoring chunks are filtered relative to the top score.
    pub adaptive_threshold: bool,
}
57
/// Full configuration for [`rank_chunks`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticEngineConfig {
    /// Requested embedding backend.
    pub backend: SemanticBackendKind,
    /// Path to the ONNX model file; must be set and exist for the ONNX backend.
    pub model_path: Option<PathBuf>,
    /// Optional vocabulary file for the ONNX backend; when set it must exist.
    pub vocab_path: Option<PathBuf>,
    /// Maximum number of chunks to return (values below 1 are treated as 1).
    pub max_chunks: usize,
    /// Whether low-scoring chunks are filtered relative to the top score.
    pub adaptive_threshold: bool,
    /// When true, an unusable ONNX setup silently falls back to the local hash
    /// backend instead of producing a [`SemanticError`].
    pub allow_fallback: bool,
}
67
68impl SemanticEngineConfig {
69 pub fn local_hash(max_chunks: usize, adaptive_threshold: bool) -> Self {
70 Self {
71 backend: SemanticBackendKind::LocalHash,
72 model_path: None,
73 vocab_path: None,
74 max_chunks,
75 adaptive_threshold,
76 allow_fallback: true,
77 }
78 }
79
80 fn from_ranking_config(config: RankingConfig) -> Self {
81 Self {
82 backend: config.backend,
83 model_path: None,
84 vocab_path: None,
85 max_chunks: config.max_chunks,
86 adaptive_threshold: config.adaptive_threshold,
87 allow_fallback: true,
88 }
89 }
90}
91
/// A chunk of text offered to the ranker together with its side signals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkCandidate {
    /// Identifier copied unchanged into the resulting [`RankedChunk`].
    pub id: String,
    /// Chunk text; it is embedded, and its normalized form is used to skip duplicates.
    pub text: String,
    /// Extra text prepended to `text` for keyword overlap; ignored when blank.
    pub keyword_hint: String,
    /// Recency signal; clamped to `[0, 1]` during ranking.
    pub recency: f64,
    /// Graph distance; converted into a bonus that shrinks as distance grows.
    pub graph_distance: f64,
    /// Failure-relevance signal; clamped to `[0, 1]` during ranking.
    pub failure_relevance: f64,
}
101
/// A candidate after scoring, as returned by [`rank_chunks`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankedChunk {
    /// Identifier of the originating [`ChunkCandidate`].
    pub id: String,
    /// Weighted combination of `features`, as computed by [`score`].
    pub score: f64,
    /// The individual signals that produced `score`.
    pub features: Features,
    /// One-line summary of the backend and feature values used for this chunk.
    pub reason: String,
    /// Copy of the candidate's text.
    pub text: String,
}
110
/// Describes an embedding stored in an [`EmbeddingCache`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct EmbeddingMetadata {
    /// Identifier of the model that produced the embedding.
    pub model_id: String,
    /// Hash of the whitespace- and case-normalized source text.
    pub text_hash: u64,
    /// Number of components in the embedding vector.
    pub dimensions: usize,
}
117
/// A cached embedding together with the metadata describing it.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Model id, text hash and dimensionality of `vector`.
    metadata: EmbeddingMetadata,
    /// The embedding itself.
    vector: Vec<f32>,
}
123
/// Bounded in-memory cache of embeddings keyed by model id and text hash.
///
/// Entries are evicted in insertion order once the capacity is exceeded;
/// lookups do not change that order.
#[derive(Debug, Clone)]
pub struct EmbeddingCache {
    /// Maximum number of entries kept (always at least one).
    capacity: usize,
    /// Cache keys in the order they were first inserted; the front is evicted first.
    order: VecDeque<String>,
    /// Stored embeddings indexed by cache key.
    entries: HashMap<String, CacheEntry>,
}
130
131impl EmbeddingCache {
132 pub fn new(capacity: usize) -> Self {
133 Self {
134 capacity: capacity.max(1),
135 order: VecDeque::new(),
136 entries: HashMap::new(),
137 }
138 }
139
140 pub fn put(&mut self, model_id: &str, text: &str, vector: Vec<f32>) -> EmbeddingMetadata {
141 let metadata = EmbeddingMetadata {
142 model_id: model_id.to_string(),
143 text_hash: stable_text_hash(text),
144 dimensions: vector.len(),
145 };
146 let key = cache_key(&metadata.model_id, metadata.text_hash);
147
148 if !self.entries.contains_key(&key) {
149 self.order.push_back(key.clone());
150 }
151 self.entries.insert(
152 key.clone(),
153 CacheEntry {
154 metadata: metadata.clone(),
155 vector,
156 },
157 );
158 self.evict_if_needed();
159 metadata
160 }
161
162 pub fn get(&self, model_id: &str, text: &str) -> Option<&[f32]> {
163 let key = cache_key(model_id, stable_text_hash(text));
164 self.entries.get(&key).map(|entry| entry.vector.as_slice())
165 }
166
167 pub fn metadata(&self, model_id: &str, text: &str) -> Option<&EmbeddingMetadata> {
168 let key = cache_key(model_id, stable_text_hash(text));
169 self.entries.get(&key).map(|entry| &entry.metadata)
170 }
171
172 fn evict_if_needed(&mut self) {
173 while self.entries.len() > self.capacity {
174 if let Some(oldest) = self.order.pop_front() {
175 self.entries.remove(&oldest);
176 } else {
177 break;
178 }
179 }
180 }
181}
182
183pub fn score(features: Features) -> f64 {
184 0.40 * features.semantic_similarity
185 + 0.20 * features.keyword_overlap
186 + 0.15 * features.recency
187 + 0.15 * features.graph_distance_bonus
188 + 0.10 * features.failure_bonus
189}
190
191pub fn rank_chunks_hybrid(
192 query: &str,
193 candidates: &[ChunkCandidate],
194 config: RankingConfig,
195) -> Vec<RankedChunk> {
196 rank_chunks(
197 query,
198 candidates,
199 SemanticEngineConfig::from_ranking_config(config),
200 )
201 .unwrap_or_default()
202}
203
204pub fn rank_chunks(
205 query: &str,
206 candidates: &[ChunkCandidate],
207 config: SemanticEngineConfig,
208) -> Result<Vec<RankedChunk>, SemanticError> {
209 if candidates.is_empty() {
210 return Ok(Vec::new());
211 }
212
213 let backend = resolve_backend(&config)?;
214 let mut cache = EmbeddingCache::new(candidates.len() + 1);
215 let query_embedding = embed_text(query, &backend, &mut cache)?;
216 let mut seen_fingerprint = HashSet::new();
217 let mut ranked = Vec::new();
218
219 for candidate in candidates {
220 let fingerprint = normalize_text(&candidate.text);
221 if !seen_fingerprint.insert(fingerprint) {
222 continue;
223 }
224
225 let candidate_embedding = embed_text(&candidate.text, &backend, &mut cache)?;
226 let semantic_similarity = cosine_dense(&query_embedding, &candidate_embedding);
227 let keyword_overlap = keyword_similarity(query, &candidate.keyword_hint, &candidate.text);
228 let features = Features {
229 semantic_similarity,
230 keyword_overlap,
231 recency: candidate.recency.clamp(0.0, 1.0),
232 graph_distance_bonus: graph_distance_bonus(candidate.graph_distance),
233 failure_bonus: candidate.failure_relevance.clamp(0.0, 1.0),
234 };
235
236 let total_score = score(features);
237 ranked.push(RankedChunk {
238 id: candidate.id.clone(),
239 score: total_score,
240 features,
241 reason: format_reason(&backend, features),
242 text: candidate.text.clone(),
243 });
244 }
245
246 ranked.sort_by(|a, b| {
247 b.score
248 .partial_cmp(&a.score)
249 .unwrap_or(std::cmp::Ordering::Equal)
250 });
251
252 let thresholded = if config.adaptive_threshold && !ranked.is_empty() {
253 apply_adaptive_threshold(ranked, config.max_chunks.max(1))
254 } else {
255 ranked
256 };
257
258 let mut final_ranked = thresholded;
259 final_ranked.truncate(config.max_chunks.max(1));
260 Ok(final_ranked)
261}
262
/// The backend actually used for embedding after configuration checks.
// NOTE(review): `dead_code` is presumably allowed because the `Onnx` variant is
// only constructed when the `onnx` feature is enabled — confirm.
#[derive(Debug, Clone)]
#[allow(dead_code)]
enum ResolvedBackend {
    /// Built-in hashing embedding.
    LocalHash {
        /// Identifier used in cache keys for embeddings from this backend.
        model_id: String,
        /// Name of the backend that was requested but replaced by this one, if any.
        fallback_from: Option<&'static str>,
    },
    /// ONNX model loaded from disk.
    Onnx {
        /// Identifier used in cache keys for embeddings from this backend.
        model_id: String,
        /// Location of the ONNX model file.
        model_path: PathBuf,
        /// Optional vocabulary file used to map tokens to ids.
        vocab_path: Option<PathBuf>,
    },
}
276
277fn resolve_backend(config: &SemanticEngineConfig) -> Result<ResolvedBackend, SemanticError> {
278 match config.backend {
279 SemanticBackendKind::LocalHash => Ok(ResolvedBackend::LocalHash {
280 model_id: "local_hash:v1".to_string(),
281 fallback_from: None,
282 }),
283 SemanticBackendKind::Onnx => resolve_onnx_backend(config),
284 }
285}
286
287fn resolve_onnx_backend(config: &SemanticEngineConfig) -> Result<ResolvedBackend, SemanticError> {
288 let Some(model_path) = config.model_path.clone() else {
289 return fallback_or_error(
290 config.allow_fallback,
291 SemanticError::OnnxModelNotFound {
292 path: "semantic.model".to_string(),
293 },
294 );
295 };
296
297 if !model_path.exists() {
298 return fallback_or_error(
299 config.allow_fallback,
300 SemanticError::OnnxModelNotFound {
301 path: format!("{} (semantic.model)", model_path.display()),
302 },
303 );
304 }
305
306 if let Some(vocab_path) = &config.vocab_path {
307 if !vocab_path.exists() {
308 return fallback_or_error(
309 config.allow_fallback,
310 SemanticError::OnnxVocabNotFound {
311 path: format!("{} (semantic.vocab)", vocab_path.display()),
312 },
313 );
314 }
315 }
316
317 #[cfg(not(feature = "onnx"))]
318 {
319 fallback_or_error(config.allow_fallback, SemanticError::OnnxFeatureDisabled)
320 }
321
322 #[cfg(feature = "onnx")]
323 {
324 Ok(ResolvedBackend::Onnx {
325 model_id: format!("onnx:{}", model_path.display()),
326 model_path,
327 vocab_path: config.vocab_path.clone(),
328 })
329 }
330}
331
332fn fallback_or_error(
333 allow_fallback: bool,
334 error: SemanticError,
335) -> Result<ResolvedBackend, SemanticError> {
336 if allow_fallback {
337 Ok(ResolvedBackend::LocalHash {
338 model_id: "local_hash:v1".to_string(),
339 fallback_from: Some("onnx"),
340 })
341 } else {
342 Err(error)
343 }
344}
345
346fn embed_text(
347 text: &str,
348 backend: &ResolvedBackend,
349 cache: &mut EmbeddingCache,
350) -> Result<Vec<f32>, SemanticError> {
351 let model_id = backend.model_id();
352 if let Some(cached) = cache.get(model_id, text) {
353 return Ok(cached.to_vec());
354 }
355
356 let vector = match backend {
357 ResolvedBackend::LocalHash { .. } => local_hash_embedding(text),
358 ResolvedBackend::Onnx {
359 model_path,
360 vocab_path,
361 ..
362 } => onnx_embedding(model_path, vocab_path.as_ref(), text)?,
363 };
364 cache.put(model_id, text, vector.clone());
365 Ok(vector)
366}
367
368impl ResolvedBackend {
369 fn model_id(&self) -> &str {
370 match self {
371 Self::LocalHash { model_id, .. } | Self::Onnx { model_id, .. } => model_id,
372 }
373 }
374
375 fn label(&self) -> &'static str {
376 match self {
377 Self::LocalHash { .. } => "local_hash",
378 Self::Onnx { .. } => "onnx",
379 }
380 }
381
382 fn fallback_from(&self) -> Option<&'static str> {
383 match self {
384 Self::LocalHash { fallback_from, .. } => *fallback_from,
385 Self::Onnx { .. } => None,
386 }
387 }
388}
389
390fn format_reason(backend: &ResolvedBackend, features: Features) -> String {
391 let mut reason = format!(
392 "backend={} semantic={:.3} keyword={:.3} recency={:.3} graph={:.3} failure={:.3}",
393 backend.label(),
394 features.semantic_similarity,
395 features.keyword_overlap,
396 features.recency,
397 features.graph_distance_bonus,
398 features.failure_bonus
399 );
400 if let Some(source) = backend.fallback_from() {
401 reason.push_str(&format!(" fallback_from={source}"));
402 }
403 reason
404}
405
406fn apply_adaptive_threshold(ranked: Vec<RankedChunk>, max_chunks: usize) -> Vec<RankedChunk> {
407 let top = ranked[0].score;
408 let threshold = (top * 0.35).max(0.15);
409 let mut kept = ranked
410 .iter()
411 .filter(|entry| entry.score >= threshold)
412 .cloned()
413 .collect::<Vec<_>>();
414
415 if kept.len() < 2 && ranked.len() >= 2 && max_chunks >= 2 {
416 kept = ranked.iter().take(2).cloned().collect();
417 }
418
419 kept
420}
421
422fn local_hash_embedding(text: &str) -> Vec<f32> {
423 const DIMS: usize = 256;
424 let mut vector = vec![0.0f32; DIMS];
425 for token in tokenize(text) {
426 let hash = stable_text_hash(&token);
427 let idx = (hash as usize) % DIMS;
428 vector[idx] += 1.0;
429
430 let chars = token.chars().collect::<Vec<_>>();
431 for window in chars.windows(3) {
432 let gram = window.iter().collect::<String>();
433 let gram_idx = (stable_text_hash(&gram) as usize) % DIMS;
434 vector[gram_idx] += 0.25;
435 }
436 }
437 normalize_dense(vector)
438}
439
/// Embeds `text` by running the ONNX model at `model_path` with `tract_onnx`.
///
/// Token ids come from `wordpiece_token_ids` when a vocabulary path is given
/// and from `hashed_token_ids` otherwise. The model must declare between one
/// and three inputs; they are fed, in order, the token ids, an all-ones vector
/// and an all-zeros vector of the same length, each shaped `[1, seq_len]`.
/// The first model output is read as `f32`: a rank-3 output is averaged over
/// its second axis for the first batch element, any other rank is flattened
/// as-is. The result is scaled to unit length.
///
/// # Errors
///
/// Every failure (loading, optimizing, tensor construction, inference, output
/// conversion, unsupported input count) is reported as
/// [`SemanticError::OnnxInference`].
// NOTE(review): the model is loaded and optimized on every call, i.e. once per
// embedded text that misses the cache — consider loading it once per ranking run.
#[cfg(feature = "onnx")]
fn onnx_embedding(
    model_path: &std::path::Path,
    vocab_path: Option<&PathBuf>,
    text: &str,
) -> Result<Vec<f32>, SemanticError> {
    use tract_onnx::prelude::*;

    let tokens = if let Some(vocab_path) = vocab_path {
        wordpiece_token_ids(vocab_path, text)?
    } else {
        hashed_token_ids(text)
    };
    let seq_len = tokens.len().max(1);
    // Presumably the attention mask and token-type ids expected by BERT-style
    // models — confirm against the models actually deployed.
    let attention = vec![1_i64; seq_len];
    let token_types = vec![0_i64; seq_len];

    let model = tract_onnx::onnx()
        .model_for_path(model_path)
        .map_err(|err| SemanticError::OnnxInference(err.to_string()))?
        .into_optimized()
        .map_err(|err| SemanticError::OnnxInference(err.to_string()))?;
    let input_count = model
        .input_outlets()
        .map_err(|err| SemanticError::OnnxInference(err.to_string()))?
        .len();
    if input_count == 0 || input_count > 3 {
        return Err(SemanticError::OnnxInference(format!(
            "expected ONNX text embedding model with 1-3 inputs, got {input_count}"
        )));
    }

    let model = model
        .into_runnable()
        .map_err(|err| SemanticError::OnnxInference(err.to_string()))?;

    // Supply only as many inputs as the model declares, in a fixed order.
    let mut inputs = TVec::new();
    for values in [&tokens, &attention, &token_types]
        .into_iter()
        .take(input_count)
    {
        inputs.push(
            Tensor::from_shape(&[1, seq_len], values)
                .map_err(|err| SemanticError::OnnxInference(err.to_string()))?
                .into(),
        );
    }

    let outputs = model
        .run(inputs)
        .map_err(|err| SemanticError::OnnxInference(err.to_string()))?;
    let first = outputs
        .first()
        .ok_or_else(|| SemanticError::OnnxInference("model returned no outputs".to_string()))?;
    let view = first
        .to_array_view::<f32>()
        .map_err(|err| SemanticError::OnnxInference(err.to_string()))?;
    let shape = view.shape();

    let vector = match shape.len() {
        // Rank 2: all values are taken in iteration order.
        2 => view.iter().copied().collect::<Vec<_>>(),
        // Rank 3: average over the second axis for batch element 0.
        3 => {
            let dim = shape[2];
            let mut pooled = vec![0.0f32; dim];
            for token_idx in 0..shape[1] {
                for dim_idx in 0..dim {
                    pooled[dim_idx] += view[[0, token_idx, dim_idx]];
                }
            }
            for value in &mut pooled {
                *value /= shape[1].max(1) as f32;
            }
            pooled
        }
        // Any other rank: flatten every value.
        _ => view.iter().copied().collect::<Vec<_>>(),
    };

    Ok(normalize_dense(vector))
}
519
/// Stand-in used when the crate is built without the `onnx` feature.
///
/// All arguments are ignored and the call always fails with
/// [`SemanticError::OnnxFeatureDisabled`].
#[cfg(not(feature = "onnx"))]
fn onnx_embedding(
    _model_path: &std::path::Path,
    _vocab_path: Option<&PathBuf>,
    _text: &str,
) -> Result<Vec<f32>, SemanticError> {
    Err(SemanticError::OnnxFeatureDisabled)
}
528
/// Maps `text` to token ids using the vocabulary file at `vocab_path`.
///
/// Each line of the file is one token whose id is its zero-based line number;
/// when a token occurs on several lines the last one wins. The output starts
/// with the `[CLS]` id and ends with the `[SEP]` id (defaults 101 and 102 when
/// the vocabulary lacks them). At most 254 tokens of `text` are looked up as
/// whole tokens, and tokens missing from the vocabulary become the `[UNK]` id
/// (default 100). No sub-word splitting is performed here.
///
/// # Errors
///
/// Returns [`SemanticError::OnnxInference`] when the vocabulary file cannot be read.
#[cfg(feature = "onnx")]
fn wordpiece_token_ids(vocab_path: &PathBuf, text: &str) -> Result<Vec<i64>, SemanticError> {
    let vocab = match std::fs::read_to_string(vocab_path) {
        Ok(contents) => contents,
        Err(err) => return Err(SemanticError::OnnxInference(err.to_string())),
    };

    let ids: HashMap<String, i64> = vocab
        .lines()
        .enumerate()
        .map(|(line_no, entry)| (entry.trim().to_string(), line_no as i64))
        .collect();
    let id_or = |token: &str, fallback: i64| ids.get(token).copied().unwrap_or(fallback);

    let cls = id_or("[CLS]", 101);
    let sep = id_or("[SEP]", 102);
    let unk = id_or("[UNK]", 100);

    let body = tokenize(text)
        .into_iter()
        .take(254)
        .map(|token| id_or(&token, unk));
    let out = std::iter::once(cls)
        .chain(body)
        .chain(std::iter::once(sep))
        .collect();
    Ok(out)
}
548
/// Produces token ids for `text` without a vocabulary.
///
/// The output starts with 101 and ends with 102. In between, each of the first
/// 254 tokens is mapped to `hash % 30_000 + 1_000`, i.e. into the range
/// `1_000..31_000`.
#[cfg(feature = "onnx")]
fn hashed_token_ids(text: &str) -> Vec<i64> {
    let hashed = tokenize(text).into_iter().take(254).map(|token| {
        let bucket = stable_text_hash(&token) % 30_000;
        (bucket + 1_000) as i64
    });

    let mut ids = Vec::new();
    ids.push(101_i64);
    ids.extend(hashed);
    ids.push(102_i64);
    ids
}
561
562fn keyword_similarity(query: &str, hint: &str, text: &str) -> f64 {
563 let hinted = if hint.trim().is_empty() {
564 text.to_string()
565 } else {
566 format!("{hint} {text}")
567 };
568 jaccard_similarity(query, &hinted)
569}
570
/// Cosine similarity of two dense vectors, limited to the range `[0, 1]`.
///
/// Only the leading components that both vectors share are considered.
/// Returns 0.0 when either vector is empty or has zero magnitude over those
/// components; negative similarities are raised to 0.0.
fn cosine_dense(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }

    // `zip` stops at the shorter vector, so extra components are ignored.
    let (dot, squares_a, squares_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, squares_a, squares_b), (&x, &y)| {
            let (x, y) = (x as f64, y as f64);
            (dot + x * y, squares_a + x * x, squares_b + y * y)
        },
    );

    if squares_a == 0.0 || squares_b == 0.0 {
        return 0.0;
    }
    let cosine = dot / (squares_a.sqrt() * squares_b.sqrt());
    cosine.clamp(0.0, 1.0)
}
592
/// Scales `vector` to unit Euclidean length.
///
/// The norm is accumulated in `f64`. A vector whose norm is not greater than
/// zero (for example all zeros or empty) is returned unchanged.
fn normalize_dense(mut vector: Vec<f32>) -> Vec<f32> {
    let squared_sum: f64 = vector
        .iter()
        .map(|&component| {
            let wide = component as f64;
            wide * wide
        })
        .sum();
    let norm = squared_sum.sqrt();

    if norm > 0.0 {
        vector
            .iter_mut()
            .for_each(|component| *component = (*component as f64 / norm) as f32);
    }
    vector
}
606
607fn jaccard_similarity(a: &str, b: &str) -> f64 {
608 let sa = tokenize(a).into_iter().collect::<HashSet<_>>();
609 let sb = tokenize(b).into_iter().collect::<HashSet<_>>();
610 if sa.is_empty() || sb.is_empty() {
611 return 0.0;
612 }
613
614 let inter = sa.intersection(&sb).count() as f64;
615 let union = sa.union(&sb).count() as f64;
616 (inter / union).clamp(0.0, 1.0)
617}
618
/// Converts a graph distance into a bonus in `[0, 1]`.
///
/// The bonus is `1 / (1 + distance)`: 1.0 at distance zero, approaching 0.0 as
/// the distance grows. Negative distances are treated as zero. A NaN distance
/// carries no usable information and therefore earns no bonus.
fn graph_distance_bonus(distance: f64) -> f64 {
    // `f64::max` ignores NaN, so without this guard a NaN distance would be
    // treated as distance 0 and receive the maximum bonus.
    if distance.is_nan() {
        return 0.0;
    }
    let d = distance.max(0.0);
    (1.0 / (1.0 + d)).clamp(0.0, 1.0)
}
623
/// Lower-cases `text` and collapses every run of whitespace to a single space.
///
/// Leading and trailing whitespace is removed.
fn normalize_text(text: &str) -> String {
    let mut normalized = String::new();
    for (position, word) in text.split_whitespace().enumerate() {
        if position > 0 {
            normalized.push(' ');
        }
        normalized.push_str(&word.to_lowercase());
    }
    normalized
}
630
/// Splits `text` into lower-cased tokens.
///
/// A token is a maximal run of alphanumeric characters and underscores.
/// Tokens whose UTF-8 encoding is a single byte are dropped; because the
/// length test counts bytes, a lone multi-byte character is kept.
fn tokenize(text: &str) -> Vec<String> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        if piece.len() > 1 {
            tokens.push(piece.to_lowercase());
        }
    }
    tokens
}
637
/// Builds the cache key for an embedding.
///
/// The key is the model id, a colon, and the text hash as 16 zero-padded
/// lower-case hexadecimal digits.
fn cache_key(model_id: &str, text_hash: u64) -> String {
    let hash_hex = format!("{text_hash:016x}");
    [model_id, hash_hex.as_str()].join(":")
}
641
642fn stable_text_hash(text: &str) -> u64 {
643 fxhash64(normalize_text(text).as_bytes())
644}
645
/// Deterministic 64-bit hash of `bytes`.
///
/// Despite the name, the constants and the xor-then-multiply update are those
/// of 64-bit FNV-1a. The empty input hashes to the offset basis.
fn fxhash64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    bytes.iter().fold(OFFSET_BASIS, |state, &byte| {
        (state ^ byte as u64).wrapping_mul(PRIME)
    })
}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a feature set in which only the semantic similarity is non-zero.
    fn semantic_only(semantic_similarity: f64) -> Features {
        Features {
            semantic_similarity,
            keyword_overlap: 0.0,
            recency: 0.0,
            graph_distance_bonus: 0.0,
            failure_bonus: 0.0,
        }
    }

    /// A higher semantic similarity must yield a higher score, all else equal.
    #[test]
    fn score_is_monotonic_for_semantic_similarity() {
        let low = score(semantic_only(0.1));
        let high = score(semantic_only(0.9));

        assert!(high > low);
    }
}