1use crate::error::Result;
7use crate::stats::OutcomeStats;
8use crate::types::{MemoryRecord, RecordId};
9use std::collections::HashMap;
10use std::time::Duration;
11
/// Parameters describing a single retrieval query.
///
/// Create one with [`RetrievalRequest::new`], refine it with the `with_*`
/// methods, and check it with [`RetrievalRequest::validate`].
#[derive(Debug, Clone)]
pub struct RetrievalRequest {
    /// Query embedding. `validate` requires its length to equal the expected
    /// dimension and every component to be finite.
    pub embedding: Vec<f32>,

    /// Requested result count; `validate` rejects `0`.
    /// NOTE(review): presumably the number of nearest neighbours — confirm.
    pub k: usize,

    /// Optional filter expression; `None` unless set via `with_filter`.
    pub filter: Option<FilterExpression>,

    /// Optional list of index names; `None` unless set via `with_indexes`.
    /// NOTE(review): presumably restricts the search to these indexes — confirm.
    pub index_names: Option<Vec<String>>,

    /// Flag initialised to `true` by `new`.
    /// NOTE(review): presumably toggles prior computation for the response —
    /// confirm against the engine implementation.
    pub compute_priors: bool,

    /// Optional duration; `None` unless set via `with_timeout`.
    /// NOTE(review): presumably an upper bound on query time — confirm.
    pub timeout: Option<Duration>,
}
39
40impl RetrievalRequest {
41 #[must_use]
43 pub fn new(embedding: Vec<f32>, k: usize) -> Self {
44 Self {
45 embedding,
46 k,
47 filter: None,
48 index_names: None,
49 compute_priors: true,
50 timeout: None,
51 }
52 }
53
54 #[must_use]
56 pub fn with_filter(mut self, filter: FilterExpression) -> Self {
57 self.filter = Some(filter);
58 self
59 }
60
61 #[must_use]
63 pub fn with_indexes(mut self, names: Vec<String>) -> Self {
64 self.index_names = Some(names);
65 self
66 }
67
68 #[must_use]
70 pub fn with_timeout(mut self, timeout: Duration) -> Self {
71 self.timeout = Some(timeout);
72 self
73 }
74
75 pub fn validate(&self, expected_dim: usize) -> Result<()> {
77 if self.embedding.len() != expected_dim {
78 return Err(crate::error::Error::DimensionMismatch {
79 expected: expected_dim,
80 got: self.embedding.len(),
81 });
82 }
83 if self.k == 0 {
84 return Err(crate::error::Error::InvalidQuery {
85 reason: "k must be greater than 0".into(),
86 });
87 }
88 if self.embedding.iter().any(|x| !x.is_finite()) {
89 return Err(crate::error::Error::InvalidQuery {
90 reason: "embedding contains NaN or Inf".into(),
91 });
92 }
93 Ok(())
94 }
95}
96
/// A filter expression tree: leaf comparisons combined with boolean operators.
///
/// NOTE(review): the evaluation code is not in this file; the variant
/// descriptions below are inferred from the variant names, and the `String` in
/// each leaf is presumably a metadata field name — confirm against the
/// evaluator.
#[derive(Debug, Clone)]
pub enum FilterExpression {
    /// Field equals the value.
    Eq(String, FilterValue),
    /// Field differs from the value.
    Ne(String, FilterValue),
    /// Field is greater than the value.
    Gt(String, FilterValue),
    /// Field is greater than or equal to the value.
    Gte(String, FilterValue),
    /// Field is less than the value.
    Lt(String, FilterValue),
    /// Field is less than or equal to the value.
    Lte(String, FilterValue),
    /// Field matches one of the listed values.
    In(String, Vec<FilterValue>),
    /// Conjunction of the nested expressions.
    And(Vec<FilterExpression>),
    /// Disjunction of the nested expressions.
    Or(Vec<FilterExpression>),
    /// Negation of the nested expression (boxed because the type is recursive).
    Not(Box<FilterExpression>),
}
121
/// A literal value used as the right-hand side of a [`FilterExpression`] leaf.
#[derive(Debug, Clone)]
pub enum FilterValue {
    /// Owned string literal.
    String(String),
    /// Signed 64-bit integer literal.
    Int(i64),
    /// 64-bit floating-point literal.
    Float(f64),
    /// Boolean literal.
    Bool(bool),
}
130
/// A record as handed to [`Corpus::ingest`] / [`Corpus::ingest_batch`].
#[derive(Debug, Clone)]
pub struct IngestRecord {
    /// Identifier supplied with the record.
    /// NOTE(review): relationship to the returned `RecordId` is not visible
    /// here — confirm.
    pub id: String,

    /// Embedding vector of the record.
    pub embedding: Vec<f32>,

    /// Free-form context text attached to the record.
    pub context: String,

    /// Scalar outcome associated with the record.
    pub outcome: f64,

    /// Additional key/value metadata.
    /// NOTE(review): presumably what `FilterExpression` leaves match against —
    /// confirm.
    pub metadata: HashMap<String, MetadataValue>,
}
149
/// A value stored in [`IngestRecord::metadata`].
#[derive(Debug, Clone)]
pub enum MetadataValue {
    /// Owned string value.
    String(String),
    /// Signed 64-bit integer value.
    Int(i64),
    /// 64-bit floating-point value.
    Float(f64),
    /// Boolean value.
    Bool(bool),
    /// List of owned strings.
    StringList(Vec<String>),
}
159
/// Result of [`RetrievalEngine::query`].
///
/// NOTE(review): the code that fills this struct is not in this file; the
/// field descriptions follow the field names and types — confirm against the
/// engine implementation.
#[derive(Debug, Clone)]
pub struct RetrievalResponse {
    /// Aggregated outcome statistics for the query.
    pub prior: PriorBundle,

    /// Candidates returned for the query.
    pub candidates: Vec<RankedCandidate>,

    /// Time reported for the query.
    pub latency: Duration,

    /// Names of the indexes that were searched.
    pub indexes_searched: Vec<String>,

    /// Number of records scanned.
    pub records_scanned: usize,

    /// Whether the response was served from a cache.
    pub cache_hit: bool,
}
187
/// Summary statistics over a set of scalar outcomes.
///
/// The derived `Default` is the empty bundle: every optional statistic is
/// `None`, `confidence` is `0.0` and `count` is `0`.
#[derive(Debug, Clone, Default)]
pub struct PriorBundle {
    /// Mean of the outcomes, if available.
    pub mean: Option<f64>,

    /// Variance of the outcomes, if available (`from_outcomes` leaves it
    /// `None` when there is only one outcome).
    pub variance: Option<f64>,

    /// Standard deviation of the outcomes, if available.
    pub std_dev: Option<f64>,

    /// Confidence score derived from `count`; compared against `0.8` by
    /// `is_reliable`.
    pub confidence: f64,

    /// Number of outcomes summarised.
    pub count: u64,

    /// Smallest outcome, if available.
    pub min: Option<f64>,

    /// Largest outcome, if available.
    pub max: Option<f64>,

    /// Weighted mean of the outcomes; only set when weights were supplied to
    /// `from_outcomes`.
    pub weighted_mean: Option<f64>,
}
218
219impl PriorBundle {
220 #[must_use]
222 pub fn from_stats(stats: &OutcomeStats) -> Self {
223 let count = stats.count();
224 let confidence = Self::compute_confidence(count);
225
226 Self {
227 mean: stats.mean_scalar(),
228 variance: stats.variance_scalar(),
229 std_dev: stats.std_scalar(),
230 confidence,
231 count,
232 min: stats.min().and_then(|m| m.first().copied().map(f64::from)),
233 max: stats.max().and_then(|m| m.first().copied().map(f64::from)),
234 weighted_mean: None,
235 }
236 }
237
238 #[must_use]
240 pub fn from_outcomes(outcomes: &[f64], weights: Option<&[f64]>) -> Self {
241 if outcomes.is_empty() {
242 return Self::default();
243 }
244
245 let count = outcomes.len() as u64;
246 let confidence = Self::compute_confidence(count);
247
248 let mean = outcomes.iter().sum::<f64>() / outcomes.len() as f64;
250 let variance = if outcomes.len() > 1 {
251 let sum_sq: f64 = outcomes.iter().map(|x| (x - mean).powi(2)).sum();
252 Some(sum_sq / (outcomes.len() - 1) as f64)
253 } else {
254 None
255 };
256 let std_dev = variance.map(|v| v.sqrt());
257 let min = outcomes.iter().copied().fold(f64::INFINITY, f64::min);
258 let max = outcomes.iter().copied().fold(f64::NEG_INFINITY, f64::max);
259
260 let weighted_mean = weights.map(|w| {
262 let total_weight: f64 = w.iter().sum();
263 if total_weight > 0.0 {
264 outcomes
265 .iter()
266 .zip(w.iter())
267 .map(|(o, w)| o * w)
268 .sum::<f64>()
269 / total_weight
270 } else {
271 mean
272 }
273 });
274
275 Self {
276 mean: Some(mean),
277 variance,
278 std_dev,
279 confidence,
280 count,
281 min: Some(min),
282 max: Some(max),
283 weighted_mean,
284 }
285 }
286
287 fn compute_confidence(count: u64) -> f64 {
291 if count == 0 {
292 return 0.0;
293 }
294 let k = 0.15;
297 let x0 = 10.0;
298 1.0 / (1.0 + (-(k * (count as f64 - x0))).exp())
299 }
300
301 #[must_use]
303 pub fn is_reliable(&self) -> bool {
304 self.confidence >= 0.8
305 }
306
307 #[must_use]
309 pub fn is_empty(&self) -> bool {
310 self.count == 0
311 }
312}
313
/// One retrieved record as listed in [`RetrievalResponse::candidates`].
///
/// NOTE(review): the code that fills this struct is not in this file; the
/// field descriptions follow the field names — confirm against the engine.
#[derive(Debug, Clone)]
pub struct RankedCandidate {
    /// Identifier of the record.
    pub record_id: String,

    /// Score assigned to the record.
    pub score: f64,

    /// Distance between the record and the query.
    pub distance: f64,

    /// Position of the record in the result list.
    /// NOTE(review): whether ranks start at 0 or 1 is not visible here.
    pub rank: u32,

    /// Outcome stored with the record.
    pub outcome: f64,

    /// Context text stored with the record.
    pub context: String,
}
335
/// Read-side interface of a retrieval engine.
///
/// Implementations must be `Send + Sync`.
pub trait RetrievalEngine: Send + Sync {
    /// Runs `request` and returns the response, or an error.
    /// NOTE(review): whether implementations call
    /// `RetrievalRequest::validate` themselves is not visible here.
    fn query(&self, request: &RetrievalRequest) -> Result<RetrievalResponse>;

    /// Returns the engine's dimension.
    /// NOTE(review): presumably the value to pass to
    /// `RetrievalRequest::validate` — confirm.
    fn dimension(&self) -> usize;

    /// Returns the size of the corpus (presumably the number of records).
    fn corpus_size(&self) -> usize;

    /// Returns the names of the engine's indexes.
    fn index_names(&self) -> Vec<String>;
}
356
/// Write-side interface of a record store.
///
/// Implementations must be `Send + Sync`; mutating methods take `&mut self`.
pub trait Corpus: Send + Sync {
    /// Adds a single record and returns its `RecordId`, or an error.
    fn ingest(&mut self, record: IngestRecord) -> Result<RecordId>;

    /// Adds several records and returns their `RecordId`s, or an error.
    /// NOTE(review): ordering of the returned ids and partial-failure
    /// behaviour are implementation-defined — confirm.
    fn ingest_batch(&mut self, records: Vec<IngestRecord>) -> Result<Vec<RecordId>>;

    /// Sets the outcome of the record identified by `id`.
    fn update_outcome(&mut self, id: &RecordId, outcome: f64) -> Result<()>;

    /// Removes the record identified by `id`.
    /// NOTE(review): the `bool` presumably reports whether a record was
    /// actually removed — confirm.
    fn remove(&mut self, id: &RecordId) -> Result<bool>;

    /// Looks up the record identified by `id`; `None` when there is no result.
    fn get(&self, id: &RecordId) -> Option<MemoryRecord>;

    /// Returns the size of the corpus (presumably the number of records).
    fn size(&self) -> usize;
}
377
/// Interface of a vector index keyed by string ids.
///
/// Implementations must be `Send + Sync`.
pub trait VectorSearcher: Send + Sync {
    /// Adds `vector` under `id`.
    /// NOTE(review): behaviour for duplicate ids or wrong-length vectors is
    /// implementation-defined — confirm.
    fn add(&mut self, id: &str, vector: &[f32]) -> Result<()>;

    /// Searches for `query` with result count `k` and returns the hits.
    /// NOTE(review): presumably at most `k` hits, nearest first — confirm.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchHit>>;

    /// Removes the vector stored under `id`.
    /// NOTE(review): the `bool` presumably reports whether an entry existed.
    fn remove(&mut self, id: &str) -> Result<bool>;

    /// Returns the dimension of the index.
    fn dimension(&self) -> usize;

    /// Returns the length of the index (presumably the number of vectors).
    fn len(&self) -> usize;

    /// Returns `true` when `len()` is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
400
/// A single result returned by [`VectorSearcher::search`].
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Identifier of the matching vector.
    pub id: String,
    /// Distance between the matching vector and the query.
    pub distance: f32,
    /// Score of the match.
    /// NOTE(review): how it relates to `distance` is implementation-defined.
    pub score: f32,
}
411
/// Builder holding configuration values for a retrieval setup.
///
/// The `Default` implementation uses dimension 512, a flat index, caching
/// enabled with a cache size of 10000, and a default `k` of 10.
#[derive(Debug, Clone)]
pub struct RAGBuilder {
    /// Embedding dimension.
    dimension: usize,
    /// Kind of index to use.
    index_type: IndexType,
    /// Whether caching is enabled.
    cache_enabled: bool,
    /// Configured cache size (unit not visible here; presumably entries).
    cache_size: usize,
    /// Default result count.
    default_k: usize,
}
425
/// Kind of vector index selected through [`RAGBuilder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// Flat index; the value used by `RAGBuilder::default`.
    /// NOTE(review): presumably an exhaustive (brute-force) scan — confirm.
    Flat,
    /// HNSW index.
    /// NOTE(review): presumably "Hierarchical Navigable Small World" — confirm.
    Hnsw,
}
434
435impl Default for RAGBuilder {
436 fn default() -> Self {
437 Self {
438 dimension: 512,
439 index_type: IndexType::Flat,
440 cache_enabled: true,
441 cache_size: 10000,
442 default_k: 10,
443 }
444 }
445}
446
447impl RAGBuilder {
448 #[must_use]
450 pub fn new(dimension: usize) -> Self {
451 Self {
452 dimension,
453 ..Default::default()
454 }
455 }
456
457 #[must_use]
459 pub fn index_type(mut self, index_type: IndexType) -> Self {
460 self.index_type = index_type;
461 self
462 }
463
464 #[must_use]
466 pub fn cache(mut self, enabled: bool) -> Self {
467 self.cache_enabled = enabled;
468 self
469 }
470
471 #[must_use]
473 pub fn cache_size(mut self, size: usize) -> Self {
474 self.cache_size = size;
475 self
476 }
477
478 #[must_use]
480 pub fn default_k(mut self, k: usize) -> Self {
481 self.default_k = k;
482 self
483 }
484
485 #[must_use]
487 pub fn get_dimension(&self) -> usize {
488 self.dimension
489 }
490
491 #[must_use]
493 pub fn get_index_type(&self) -> IndexType {
494 self.index_type
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// `validate` accepts a well-formed request and rejects each malformed one.
    #[test]
    fn test_retrieval_request_validation() {
        const DIM: usize = 3;

        let well_formed = RetrievalRequest::new(vec![1.0, 2.0, 3.0], 10);
        assert!(well_formed.validate(DIM).is_ok());

        // One component short of the expected dimension.
        let too_short = RetrievalRequest::new(vec![1.0, 2.0], 10);
        assert!(too_short.validate(DIM).is_err());

        // Asking for zero results is rejected.
        let no_results = RetrievalRequest::new(vec![1.0, 2.0, 3.0], 0);
        assert!(no_results.validate(DIM).is_err());

        // Non-finite components are rejected.
        let not_finite = RetrievalRequest::new(vec![1.0, f32::NAN, 3.0], 10);
        assert!(not_finite.validate(DIM).is_err());
    }

    /// Unweighted statistics are computed from raw outcomes.
    #[test]
    fn test_prior_bundle_from_outcomes() {
        let samples = [0.8, 0.9, 0.7, 0.85];
        let bundle = PriorBundle::from_outcomes(&samples[..], None);

        let mean = bundle.mean.expect("mean is set for non-empty outcomes");
        assert!((mean - 0.8125).abs() < 1e-6);
        assert_eq!(bundle.count, 4);
        assert!(bundle.confidence > 0.0);
    }

    /// No outcomes yields an empty, unreliable bundle.
    #[test]
    fn test_prior_bundle_empty() {
        let bundle = PriorBundle::from_outcomes(&[], None);
        assert!(bundle.is_empty());
        assert!(!bundle.is_reliable());
    }

    /// Supplying weights populates the weighted mean.
    #[test]
    fn test_prior_bundle_weighted() {
        let samples = [1.0, 0.0];
        let sample_weights = [0.8, 0.2];
        let bundle = PriorBundle::from_outcomes(&samples[..], Some(&sample_weights[..]));

        let weighted = bundle
            .weighted_mean
            .expect("weighted mean is set when weights are given");
        assert!((weighted - 0.8).abs() < 1e-6);
    }

    /// Confidence is zero for no samples and grows with the sample count.
    #[test]
    fn test_confidence_scaling() {
        let confidence = PriorBundle::compute_confidence;

        assert!(confidence(0) == 0.0);
        assert!(confidence(5) > 0.3);
        assert!(confidence(20) > 0.8);
        assert!(confidence(100) > 0.99);
    }

    /// Builder setters are chainable and the getters reflect them.
    #[test]
    fn test_builder() {
        let configured = RAGBuilder::new(768)
            .index_type(IndexType::Hnsw)
            .cache(true)
            .cache_size(5000)
            .default_k(20);

        assert_eq!(configured.get_dimension(), 768);
        assert_eq!(configured.get_index_type(), IndexType::Hnsw);
    }
}