use crate::error::Result;
use crate::stats::OutcomeStats;
use crate::types::{MemoryRecord, RecordId};
use std::collections::HashMap;
use std::time::Duration;
/// Parameters for one retrieval query.
///
/// Construct with `RetrievalRequest::new`, adjust with the `with_*` builder
/// methods, and check against an index dimension with `validate`.
#[derive(Debug, Clone)]
pub struct RetrievalRequest {
/// Query embedding. `validate` requires its length to match the expected
/// dimension and every component to be finite.
pub embedding: Vec<f32>,
/// Requested result count; `validate` rejects `0`.
/// NOTE(review): presumably the number of nearest candidates to return — confirm.
pub k: usize,
/// Optional filter expression. How it is evaluated is not defined in this file.
pub filter: Option<FilterExpression>,
/// Optional list of index names.
/// NOTE(review): presumably restricts the search to these indexes — confirm.
pub index_names: Option<Vec<String>>,
/// Set to `true` by `new`.
/// NOTE(review): presumably toggles computation of the response's prior — confirm.
pub compute_priors: bool,
/// Optional time limit; enforcement is up to the engine and is not shown here.
pub timeout: Option<Duration>,
}
impl RetrievalRequest {
#[must_use]
pub fn new(embedding: Vec<f32>, k: usize) -> Self {
Self {
embedding,
k,
filter: None,
index_names: None,
compute_priors: true,
timeout: None,
}
}
#[must_use]
pub fn with_filter(mut self, filter: FilterExpression) -> Self {
self.filter = Some(filter);
self
}
#[must_use]
pub fn with_indexes(mut self, names: Vec<String>) -> Self {
self.index_names = Some(names);
self
}
#[must_use]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn validate(&self, expected_dim: usize) -> Result<()> {
if self.embedding.len() != expected_dim {
return Err(crate::error::Error::DimensionMismatch {
expected: expected_dim,
got: self.embedding.len(),
});
}
if self.k == 0 {
return Err(crate::error::Error::InvalidQuery {
reason: "k must be greater than 0".into(),
});
}
if self.embedding.iter().any(|x| !x.is_finite()) {
return Err(crate::error::Error::InvalidQuery {
reason: "embedding contains NaN or Inf".into(),
});
}
Ok(())
}
}
/// A filter tree that can be attached to a `RetrievalRequest`.
///
/// Comparison variants carry a `String` and one or more `FilterValue`s.
/// NOTE(review): the `String` is presumably a metadata field name; evaluation
/// semantics are not defined in this file — confirm against the engine.
#[derive(Debug, Clone)]
pub enum FilterExpression {
/// Equality comparison.
Eq(String, FilterValue),
/// Inequality comparison.
Ne(String, FilterValue),
/// Strictly-greater-than comparison.
Gt(String, FilterValue),
/// Greater-than-or-equal comparison.
Gte(String, FilterValue),
/// Strictly-less-than comparison.
Lt(String, FilterValue),
/// Less-than-or-equal comparison.
Lte(String, FilterValue),
/// Membership test against a list of values.
In(String, Vec<FilterValue>),
/// Conjunction of sub-expressions.
And(Vec<FilterExpression>),
/// Disjunction of sub-expressions.
Or(Vec<FilterExpression>),
/// Negation of a sub-expression (boxed because the type is recursive).
Not(Box<FilterExpression>),
}
/// A literal value used on the right-hand side of a `FilterExpression` comparison.
#[derive(Debug, Clone)]
pub enum FilterValue {
/// Owned text value.
String(String),
/// Signed 64-bit integer value.
Int(i64),
/// 64-bit floating-point value.
Float(f64),
/// Boolean value.
Bool(bool),
}
/// Input record accepted by `Corpus::ingest` and `Corpus::ingest_batch`.
#[derive(Debug, Clone)]
pub struct IngestRecord {
/// Record identifier as text.
/// NOTE(review): relation to the `RecordId` returned by `Corpus::ingest` is not shown here.
pub id: String,
/// Embedding vector for the record.
pub embedding: Vec<f32>,
/// Free-form context text stored with the record.
pub context: String,
/// Scalar outcome associated with the record.
pub outcome: f64,
/// Arbitrary key/value metadata.
pub metadata: HashMap<String, MetadataValue>,
}
/// A value stored in `IngestRecord::metadata`.
///
/// Same scalar shapes as `FilterValue`, plus a list of strings.
#[derive(Debug, Clone)]
pub enum MetadataValue {
/// Owned text value.
String(String),
/// Signed 64-bit integer value.
Int(i64),
/// 64-bit floating-point value.
Float(f64),
/// Boolean value.
Bool(bool),
/// List of text values.
StringList(Vec<String>),
}
/// Value returned by `RetrievalEngine::query`.
///
/// Nothing in this file populates the fields; the notes below are read from
/// names and types and should be confirmed against an engine implementation.
#[derive(Debug, Clone)]
pub struct RetrievalResponse {
/// Aggregate outcome statistics for the query.
pub prior: PriorBundle,
/// Candidate records with ranking information.
pub candidates: Vec<RankedCandidate>,
/// NOTE(review): presumably the time spent serving the query — confirm.
pub latency: Duration,
/// NOTE(review): presumably the names of the indexes that were consulted — confirm.
pub indexes_searched: Vec<String>,
/// NOTE(review): presumably the number of records examined — confirm.
pub records_scanned: usize,
/// NOTE(review): presumably `true` when the response came from a cache — confirm.
pub cache_hit: bool,
}
/// Summary statistics over a set of outcomes, plus a confidence derived from
/// the sample count.
///
/// The derived `Default` (all options `None`, `confidence` 0.0, `count` 0) is
/// what `from_outcomes` returns for an empty slice.
#[derive(Debug, Clone, Default)]
pub struct PriorBundle {
/// Arithmetic mean, when available.
pub mean: Option<f64>,
/// Variance. `from_outcomes` fills this with the sample variance (n - 1
/// denominator) and leaves it `None` for fewer than two outcomes.
pub variance: Option<f64>,
/// Standard deviation; in `from_outcomes` the square root of `variance`.
pub std_dev: Option<f64>,
/// Value computed from `count` by `compute_confidence`; 0.0 when `count` is 0.
pub confidence: f64,
/// Number of outcomes summarised.
pub count: u64,
/// Smallest outcome, when available.
pub min: Option<f64>,
/// Largest outcome, when available.
pub max: Option<f64>,
/// Weighted mean; only set by `from_outcomes` when weights are supplied.
pub weighted_mean: Option<f64>,
}
impl PriorBundle {
    /// Steepness of the logistic curve that maps a sample count to a confidence.
    const CONFIDENCE_STEEPNESS: f64 = 0.15;
    /// Sample count at which the logistic curve yields exactly 0.5.
    const CONFIDENCE_MIDPOINT: f64 = 10.0;
    /// Lowest confidence accepted by `is_reliable`.
    const RELIABLE_CONFIDENCE: f64 = 0.8;

    /// Builds a bundle from pre-aggregated [`OutcomeStats`].
    ///
    /// `mean`, `variance` and `std_dev` come from the stats' scalar accessors.
    /// `min` and `max` take the first element of what `stats.min()` /
    /// `stats.max()` return (NOTE(review): presumably per-dimension extrema —
    /// confirm against `OutcomeStats`). `weighted_mean` is always `None`.
    #[must_use]
    pub fn from_stats(stats: &OutcomeStats) -> Self {
        let count = stats.count();
        let confidence = Self::compute_confidence(count);
        Self {
            mean: stats.mean_scalar(),
            variance: stats.variance_scalar(),
            std_dev: stats.std_scalar(),
            confidence,
            count,
            min: stats.min().and_then(|m| m.first().copied().map(f64::from)),
            max: stats.max().and_then(|m| m.first().copied().map(f64::from)),
            weighted_mean: None,
        }
    }

    /// Builds a bundle directly from raw outcomes.
    ///
    /// * An empty `outcomes` slice yields `Self::default()`.
    /// * `variance` is the sample variance (n - 1 denominator) and is `None`
    ///   for a single outcome; `std_dev` is its square root.
    /// * `weighted_mean` is `Some` only when `weights` is supplied. Outcomes
    ///   and weights are paired positionally up to the shorter slice, and the
    ///   result is normalised by the sum of the *paired* weights only, so
    ///   surplus weights cannot skew it. When that sum is not positive, the
    ///   plain mean is used instead.
    #[must_use]
    pub fn from_outcomes(outcomes: &[f64], weights: Option<&[f64]>) -> Self {
        if outcomes.is_empty() {
            return Self::default();
        }

        let n = outcomes.len();
        let count = n as u64;
        let confidence = Self::compute_confidence(count);
        let mean = outcomes.iter().sum::<f64>() / n as f64;

        let variance = if n > 1 {
            let sum_sq: f64 = outcomes.iter().map(|x| (x - mean).powi(2)).sum();
            Some(sum_sq / (n - 1) as f64)
        } else {
            None
        };
        let std_dev = variance.map(f64::sqrt);

        let min = outcomes.iter().copied().fold(f64::INFINITY, f64::min);
        let max = outcomes.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        let weighted_mean = weights.map(|w| {
            // Accumulate numerator and denominator over the same pairs. Summing
            // all of `w` for the denominator would count weights that have no
            // outcome whenever `w` is longer than `outcomes`.
            let (weighted_sum, total_weight) = outcomes.iter().zip(w.iter()).fold(
                (0.0_f64, 0.0_f64),
                |(sum, total), (outcome, weight)| (sum + outcome * weight, total + weight),
            );
            if total_weight > 0.0 {
                weighted_sum / total_weight
            } else {
                mean
            }
        });

        Self {
            mean: Some(mean),
            variance,
            std_dev,
            confidence,
            count,
            min: Some(min),
            max: Some(max),
            weighted_mean,
        }
    }

    /// Maps a sample count to a confidence with a logistic curve.
    ///
    /// Returns exactly `0.0` for `count == 0`; otherwise
    /// `1 / (1 + exp(-k * (count - x0)))` with `k = 0.15` and `x0 = 10`.
    fn compute_confidence(count: u64) -> f64 {
        if count == 0 {
            return 0.0;
        }
        let offset = count as f64 - Self::CONFIDENCE_MIDPOINT;
        1.0 / (1.0 + (-(Self::CONFIDENCE_STEEPNESS * offset)).exp())
    }

    /// Returns `true` when `confidence` is at least 0.8.
    #[must_use]
    pub fn is_reliable(&self) -> bool {
        self.confidence >= Self::RELIABLE_CONFIDENCE
    }

    /// Returns `true` when the bundle summarises no outcomes (`count == 0`).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
}
/// One entry of `RetrievalResponse::candidates`.
///
/// Nothing in this file populates the fields; the notes below are read from
/// names and types and should be confirmed against an engine implementation.
#[derive(Debug, Clone)]
pub struct RankedCandidate {
/// Identifier of the matched record, as text.
pub record_id: String,
/// NOTE(review): presumably a similarity score where higher is better — confirm.
pub score: f64,
/// NOTE(review): presumably the distance to the query embedding — confirm metric.
pub distance: f64,
/// NOTE(review): position in the result list; whether it is 0- or 1-based is not shown here.
pub rank: u32,
/// Outcome value carried by the record.
pub outcome: f64,
/// Context text carried by the record.
pub context: String,
}
/// Read-side query interface. Implementors must be `Send + Sync`.
pub trait RetrievalEngine: Send + Sync {
/// Runs `request` and returns the response, or an error.
fn query(&self, request: &RetrievalRequest) -> Result<RetrievalResponse>;
/// Embedding dimension of the engine.
/// NOTE(review): presumably the value to pass to `RetrievalRequest::validate` — confirm.
fn dimension(&self) -> usize;
/// Number of records in the underlying corpus.
fn corpus_size(&self) -> usize;
/// Names of the indexes the engine exposes.
fn index_names(&self) -> Vec<String>;
}
/// Write-side storage interface for records. Implementors must be `Send + Sync`.
pub trait Corpus: Send + Sync {
/// Stores one record and returns its `RecordId`, or an error.
fn ingest(&mut self, record: IngestRecord) -> Result<RecordId>;
/// Stores several records and returns their `RecordId`s, or an error.
/// NOTE(review): whether a failure leaves a partial batch stored is implementation-defined.
fn ingest_batch(&mut self, records: Vec<IngestRecord>) -> Result<Vec<RecordId>>;
/// Replaces the outcome of the record identified by `id`.
fn update_outcome(&mut self, id: &RecordId, outcome: f64) -> Result<()>;
/// Removes the record identified by `id`.
/// NOTE(review): the `bool` presumably reports whether a record was present — confirm.
fn remove(&mut self, id: &RecordId) -> Result<bool>;
/// Returns the record identified by `id`, or `None`.
fn get(&self, id: &RecordId) -> Option<MemoryRecord>;
/// Number of records held.
fn size(&self) -> usize;
}
/// Vector index interface. Implementors must be `Send + Sync`.
pub trait VectorSearcher: Send + Sync {
/// Adds `vector` under `id`.
fn add(&mut self, id: &str, vector: &[f32]) -> Result<()>;
/// Searches for `query` and returns hits.
/// NOTE(review): presumably at most `k` hits, nearest first — confirm.
fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchHit>>;
/// Removes the vector stored under `id`.
/// NOTE(review): the `bool` presumably reports whether an entry was present — confirm.
fn remove(&mut self, id: &str) -> Result<bool>;
/// Vector dimension of the index.
fn dimension(&self) -> usize;
/// Number of vectors stored.
fn len(&self) -> usize;
/// Returns `true` when `len()` is 0. Provided method; implementors may override.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// One result returned by `VectorSearcher::search`.
#[derive(Debug, Clone)]
pub struct SearchHit {
/// Identifier the vector was stored under.
pub id: String,
/// NOTE(review): presumably the distance to the query vector — confirm metric.
pub distance: f32,
/// NOTE(review): presumably a similarity score derived from `distance` — confirm.
pub score: f32,
}
/// Builder holding configuration values.
///
/// Fields are private and are set through the builder methods; the `Default`
/// impl supplies 512 / `IndexType::Flat` / `true` / 10000 / 10 respectively.
#[derive(Debug, Clone)]
pub struct RAGBuilder {
// Embedding dimension; the only value `new` takes explicitly.
dimension: usize,
// Which index implementation to use.
index_type: IndexType,
// Whether caching is turned on.
cache_enabled: bool,
// Cache capacity. NOTE(review): unit (entries vs. bytes) is not shown here.
cache_size: usize,
// NOTE(review): presumably the `k` used when a query does not specify one — confirm.
default_k: usize,
}
/// Index implementation selected through `RAGBuilder::index_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
/// The default chosen by `RAGBuilder::default`.
/// NOTE(review): presumably an exhaustive (brute-force) index — confirm.
Flat,
/// NOTE(review): presumably a Hierarchical Navigable Small World graph index — confirm.
Hnsw,
}
impl Default for RAGBuilder {
fn default() -> Self {
Self {
dimension: 512,
index_type: IndexType::Flat,
cache_enabled: true,
cache_size: 10000,
default_k: 10,
}
}
}
impl RAGBuilder {
    /// Creates a builder for `dimension`-dimensional embeddings; every other
    /// setting comes from `Default`.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            ..Self::default()
        }
    }

    /// Returns the builder with its index type replaced.
    #[must_use]
    pub fn index_type(self, index_type: IndexType) -> Self {
        Self { index_type, ..self }
    }

    /// Returns the builder with caching switched on or off.
    #[must_use]
    pub fn cache(self, enabled: bool) -> Self {
        Self {
            cache_enabled: enabled,
            ..self
        }
    }

    /// Returns the builder with its cache size replaced.
    #[must_use]
    pub fn cache_size(self, size: usize) -> Self {
        Self {
            cache_size: size,
            ..self
        }
    }

    /// Returns the builder with its default `k` replaced.
    #[must_use]
    pub fn default_k(self, k: usize) -> Self {
        Self {
            default_k: k,
            ..self
        }
    }

    /// Returns the configured embedding dimension.
    #[must_use]
    pub fn get_dimension(&self) -> usize {
        let Self { dimension, .. } = self;
        *dimension
    }

    /// Returns the configured index type.
    #[must_use]
    pub fn get_index_type(&self) -> IndexType {
        let Self { index_type, .. } = self;
        *index_type
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `validate` accepts a well-formed request and rejects a wrong dimension,
    /// a zero `k`, and a NaN component.
    #[test]
    fn test_retrieval_request_validation() {
        let dim = 3;

        let well_formed = RetrievalRequest::new(vec![1.0, 2.0, 3.0], 10);
        assert!(well_formed.validate(dim).is_ok());

        let rejected = [
            RetrievalRequest::new(vec![1.0, 2.0], 10),
            RetrievalRequest::new(vec![1.0, 2.0, 3.0], 0),
            RetrievalRequest::new(vec![1.0, f32::NAN, 3.0], 10),
        ];
        for request in rejected.iter() {
            assert!(request.validate(dim).is_err());
        }
    }

    /// Mean, count and confidence for a small unweighted sample.
    #[test]
    fn test_prior_bundle_from_outcomes() {
        let prior = PriorBundle::from_outcomes(&[0.8, 0.9, 0.7, 0.85], None);

        let mean = prior.mean.expect("non-empty outcomes must produce a mean");
        assert!((mean - 0.8125).abs() < 1e-6);
        assert_eq!(prior.count, 4);
        assert!(prior.confidence > 0.0);
    }

    /// An empty sample is empty and not reliable.
    #[test]
    fn test_prior_bundle_empty() {
        let no_outcomes: [f64; 0] = [];
        let prior = PriorBundle::from_outcomes(&no_outcomes, None);

        assert!(prior.is_empty());
        assert!(!prior.is_reliable());
    }

    /// The weighted mean follows the supplied weights.
    #[test]
    fn test_prior_bundle_weighted() {
        let outcomes = [1.0, 0.0];
        let weights = [0.8, 0.2];

        let prior = PriorBundle::from_outcomes(&outcomes, Some(&weights));

        let weighted = prior
            .weighted_mean
            .expect("weights were supplied, so a weighted mean is expected");
        assert!((weighted - 0.8).abs() < 1e-6);
    }

    /// Confidence is zero without samples and clears fixed floors as the count grows.
    #[test]
    fn test_confidence_scaling() {
        assert!(PriorBundle::compute_confidence(0) == 0.0);

        for &(count, floor) in &[(5_u64, 0.3_f64), (20, 0.8), (100, 0.99)] {
            assert!(PriorBundle::compute_confidence(count) > floor);
        }
    }

    /// Builder setters chain, and the getters report what was configured.
    #[test]
    fn test_builder() {
        let configured = RAGBuilder::new(768)
            .index_type(IndexType::Hnsw)
            .cache(true)
            .cache_size(5000)
            .default_k(20);

        assert_eq!(configured.get_dimension(), 768);
        assert_eq!(configured.get_index_type(), IndexType::Hnsw);
    }
}