// rag_plusplus_core/api.rs

//! RAG++ Public API Contracts
//!
//! This module defines the formal input/output contracts for RAG++.
//! These types form the stable interface for external consumers.
5
6use crate::error::Result;
7use crate::stats::OutcomeStats;
8use crate::types::{MemoryRecord, RecordId};
9use std::collections::HashMap;
10use std::time::Duration;
11
// ============================================================================
// INPUT CONTRACTS
// ============================================================================
15
16/// Query request to the RAG++ retrieval engine.
17///
18/// This is the primary input contract for retrieval operations.
19#[derive(Debug, Clone)]
20pub struct RetrievalRequest {
21    /// Query embedding vector (must match index dimension)
22    pub embedding: Vec<f32>,
23
24    /// Number of candidates to retrieve
25    pub k: usize,
26
27    /// Optional metadata filter expression
28    pub filter: Option<FilterExpression>,
29
30    /// Optional: specific indexes to search (None = search all)
31    pub index_names: Option<Vec<String>>,
32
33    /// Whether to compute priors from retrieved records
34    pub compute_priors: bool,
35
36    /// Optional timeout for the query
37    pub timeout: Option<Duration>,
38}
39
40impl RetrievalRequest {
41    /// Create a simple retrieval request.
42    #[must_use]
43    pub fn new(embedding: Vec<f32>, k: usize) -> Self {
44        Self {
45            embedding,
46            k,
47            filter: None,
48            index_names: None,
49            compute_priors: true,
50            timeout: None,
51        }
52    }
53
54    /// Add a metadata filter.
55    #[must_use]
56    pub fn with_filter(mut self, filter: FilterExpression) -> Self {
57        self.filter = Some(filter);
58        self
59    }
60
61    /// Specify which indexes to search.
62    #[must_use]
63    pub fn with_indexes(mut self, names: Vec<String>) -> Self {
64        self.index_names = Some(names);
65        self
66    }
67
68    /// Set query timeout.
69    #[must_use]
70    pub fn with_timeout(mut self, timeout: Duration) -> Self {
71        self.timeout = Some(timeout);
72        self
73    }
74
75    /// Validate the request.
76    pub fn validate(&self, expected_dim: usize) -> Result<()> {
77        if self.embedding.len() != expected_dim {
78            return Err(crate::error::Error::DimensionMismatch {
79                expected: expected_dim,
80                got: self.embedding.len(),
81            });
82        }
83        if self.k == 0 {
84            return Err(crate::error::Error::InvalidQuery {
85                reason: "k must be greater than 0".into(),
86            });
87        }
88        if self.embedding.iter().any(|x| !x.is_finite()) {
89            return Err(crate::error::Error::InvalidQuery {
90                reason: "embedding contains NaN or Inf".into(),
91            });
92        }
93        Ok(())
94    }
95}
96
97/// Simplified filter expression for the public API.
98#[derive(Debug, Clone)]
99pub enum FilterExpression {
100    /// Field equals value
101    Eq(String, FilterValue),
102    /// Field not equals value
103    Ne(String, FilterValue),
104    /// Field greater than value
105    Gt(String, FilterValue),
106    /// Field greater than or equal
107    Gte(String, FilterValue),
108    /// Field less than value
109    Lt(String, FilterValue),
110    /// Field less than or equal
111    Lte(String, FilterValue),
112    /// Field in set of values
113    In(String, Vec<FilterValue>),
114    /// Logical AND of expressions
115    And(Vec<FilterExpression>),
116    /// Logical OR of expressions
117    Or(Vec<FilterExpression>),
118    /// Logical NOT
119    Not(Box<FilterExpression>),
120}
121
/// Filter value types.
///
/// The literal a [`FilterExpression`] compares a field against.
#[derive(Debug, Clone)]
pub enum FilterValue {
    /// UTF-8 string literal
    String(String),
    /// Signed 64-bit integer literal
    Int(i64),
    /// 64-bit floating-point literal
    Float(f64),
    /// Boolean literal
    Bool(bool),
}
130
131/// Record to be ingested into the RAG++ corpus.
132#[derive(Debug, Clone)]
133pub struct IngestRecord {
134    /// Unique identifier (must be unique within corpus)
135    pub id: String,
136
137    /// Embedding vector
138    pub embedding: Vec<f32>,
139
140    /// Human-readable context description
141    pub context: String,
142
143    /// Primary outcome metric
144    pub outcome: f64,
145
146    /// Arbitrary metadata
147    pub metadata: HashMap<String, MetadataValue>,
148}
149
/// Metadata value for ingestion.
///
/// Stored as the values of [`IngestRecord::metadata`].
#[derive(Debug, Clone)]
pub enum MetadataValue {
    /// UTF-8 string
    String(String),
    /// Signed 64-bit integer
    Int(i64),
    /// 64-bit floating-point number
    Float(f64),
    /// Boolean flag
    Bool(bool),
    /// List of strings
    StringList(Vec<String>),
}
159
// ============================================================================
// OUTPUT CONTRACTS
// ============================================================================
163
164/// Response from a RAG++ retrieval query.
165///
166/// This is the primary output contract for retrieval operations.
167#[derive(Debug, Clone)]
168pub struct RetrievalResponse {
169    /// Statistical priors computed from retrieved records
170    pub prior: PriorBundle,
171
172    /// Ranked candidates with scores
173    pub candidates: Vec<RankedCandidate>,
174
175    /// Query execution latency
176    pub latency: Duration,
177
178    /// Which indexes were searched
179    pub indexes_searched: Vec<String>,
180
181    /// Total number of records considered
182    pub records_scanned: usize,
183
184    /// Whether the query hit the cache
185    pub cache_hit: bool,
186}
187
/// Statistical priors from retrieved trajectories.
///
/// This is the core value proposition of RAG++ - surfacing implicit
/// knowledge from past execution outcomes as queryable statistics.
///
/// The `Default` value is the empty bundle: `count` is `0`, `confidence` is
/// `0.0` and every optional statistic is `None`.
#[derive(Debug, Clone, Default)]
pub struct PriorBundle {
    /// Mean outcome of retrieved trajectories
    pub mean: Option<f64>,

    /// Variance of outcomes
    pub variance: Option<f64>,

    /// Standard deviation
    pub std_dev: Option<f64>,

    /// Confidence in the estimate (0-1, based on sample count)
    pub confidence: f64,

    /// Number of samples contributing to statistics
    pub count: u64,

    /// Minimum observed outcome
    pub min: Option<f64>,

    /// Maximum observed outcome
    pub max: Option<f64>,

    /// Weighted mean (by retrieval score); `None` when no weights were supplied
    pub weighted_mean: Option<f64>,
}
218
219impl PriorBundle {
220    /// Create from outcome statistics.
221    #[must_use]
222    pub fn from_stats(stats: &OutcomeStats) -> Self {
223        let count = stats.count();
224        let confidence = Self::compute_confidence(count);
225
226        Self {
227            mean: stats.mean_scalar(),
228            variance: stats.variance_scalar(),
229            std_dev: stats.std_scalar(),
230            confidence,
231            count,
232            min: stats.min().and_then(|m| m.first().copied().map(f64::from)),
233            max: stats.max().and_then(|m| m.first().copied().map(f64::from)),
234            weighted_mean: None,
235        }
236    }
237
238    /// Create from a set of outcomes with optional weights.
239    #[must_use]
240    pub fn from_outcomes(outcomes: &[f64], weights: Option<&[f64]>) -> Self {
241        if outcomes.is_empty() {
242            return Self::default();
243        }
244
245        let count = outcomes.len() as u64;
246        let confidence = Self::compute_confidence(count);
247
248        // Simple statistics
249        let mean = outcomes.iter().sum::<f64>() / outcomes.len() as f64;
250        let variance = if outcomes.len() > 1 {
251            let sum_sq: f64 = outcomes.iter().map(|x| (x - mean).powi(2)).sum();
252            Some(sum_sq / (outcomes.len() - 1) as f64)
253        } else {
254            None
255        };
256        let std_dev = variance.map(|v| v.sqrt());
257        let min = outcomes.iter().copied().fold(f64::INFINITY, f64::min);
258        let max = outcomes.iter().copied().fold(f64::NEG_INFINITY, f64::max);
259
260        // Weighted mean
261        let weighted_mean = weights.map(|w| {
262            let total_weight: f64 = w.iter().sum();
263            if total_weight > 0.0 {
264                outcomes
265                    .iter()
266                    .zip(w.iter())
267                    .map(|(o, w)| o * w)
268                    .sum::<f64>()
269                    / total_weight
270            } else {
271                mean
272            }
273        });
274
275        Self {
276            mean: Some(mean),
277            variance,
278            std_dev,
279            confidence,
280            count,
281            min: Some(min),
282            max: Some(max),
283            weighted_mean,
284        }
285    }
286
287    /// Compute confidence based on sample count.
288    ///
289    /// Uses a logistic function that approaches 1.0 as count increases.
290    fn compute_confidence(count: u64) -> f64 {
291        if count == 0 {
292            return 0.0;
293        }
294        // Logistic: 1 / (1 + e^(-k(x-x0)))
295        // Tuned so: count=5 -> ~0.5, count=20 -> ~0.9, count=50 -> ~0.99
296        let k = 0.15;
297        let x0 = 10.0;
298        1.0 / (1.0 + (-(k * (count as f64 - x0))).exp())
299    }
300
301    /// Whether the prior has enough samples to be reliable.
302    #[must_use]
303    pub fn is_reliable(&self) -> bool {
304        self.confidence >= 0.8
305    }
306
307    /// Whether any statistics are available.
308    #[must_use]
309    pub fn is_empty(&self) -> bool {
310        self.count == 0
311    }
312}
313
/// A ranked candidate from retrieval.
///
/// One entry of the `candidates` list in a retrieval response.
#[derive(Debug, Clone)]
pub struct RankedCandidate {
    /// Record identifier
    pub record_id: String,

    /// Retrieval score (higher = more relevant)
    pub score: f64,

    /// Raw distance from query
    pub distance: f64,

    /// Rank position (1-indexed)
    pub rank: u32,

    /// Outcome value from the record
    pub outcome: f64,

    /// Record context string
    pub context: String,
}
335
// ============================================================================
// TRAIT DEFINITIONS
// ============================================================================
339
340/// Core trait for RAG++ retrieval engines.
341///
342/// Implementations must provide thread-safe retrieval operations.
343pub trait RetrievalEngine: Send + Sync {
344    /// Execute a retrieval query.
345    fn query(&self, request: &RetrievalRequest) -> Result<RetrievalResponse>;
346
347    /// Get the embedding dimension.
348    fn dimension(&self) -> usize;
349
350    /// Get the number of records in the corpus.
351    fn corpus_size(&self) -> usize;
352
353    /// Get available index names.
354    fn index_names(&self) -> Vec<String>;
355}
356
357/// Trait for record storage.
358pub trait Corpus: Send + Sync {
359    /// Ingest a record into the corpus.
360    fn ingest(&mut self, record: IngestRecord) -> Result<RecordId>;
361
362    /// Ingest multiple records.
363    fn ingest_batch(&mut self, records: Vec<IngestRecord>) -> Result<Vec<RecordId>>;
364
365    /// Update outcome statistics for a record.
366    fn update_outcome(&mut self, id: &RecordId, outcome: f64) -> Result<()>;
367
368    /// Remove a record from the corpus.
369    fn remove(&mut self, id: &RecordId) -> Result<bool>;
370
371    /// Get a record by ID.
372    fn get(&self, id: &RecordId) -> Option<MemoryRecord>;
373
374    /// Get corpus size.
375    fn size(&self) -> usize;
376}
377
378/// Trait for vector indexes.
379pub trait VectorSearcher: Send + Sync {
380    /// Add a vector to the index.
381    fn add(&mut self, id: &str, vector: &[f32]) -> Result<()>;
382
383    /// Search for nearest neighbors.
384    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchHit>>;
385
386    /// Remove a vector from the index.
387    fn remove(&mut self, id: &str) -> Result<bool>;
388
389    /// Get the dimension.
390    fn dimension(&self) -> usize;
391
392    /// Get the number of vectors.
393    fn len(&self) -> usize;
394
395    /// Check if empty.
396    fn is_empty(&self) -> bool {
397        self.len() == 0
398    }
399}
400
/// A search hit from vector search.
///
/// Element type of the list returned by [`VectorSearcher::search`].
#[derive(Debug, Clone)]
pub struct SearchHit {
    /// Record ID
    pub id: String,
    /// Distance from query
    pub distance: f32,
    /// Score (typically 1 / (1 + distance) or similar)
    pub score: f32,
}
411
// ============================================================================
// BUILDER PATTERN
// ============================================================================
415
416/// Builder for constructing RAG++ instances.
417#[derive(Debug, Clone)]
418pub struct RAGBuilder {
419    dimension: usize,
420    index_type: IndexType,
421    cache_enabled: bool,
422    cache_size: usize,
423    default_k: usize,
424}
425
/// Index type selection.
///
/// A plain `Copy` tag with equality, so it can be passed by value and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// Exact search (brute force)
    Flat,
    /// Approximate search (HNSW)
    Hnsw,
}
434
435impl Default for RAGBuilder {
436    fn default() -> Self {
437        Self {
438            dimension: 512,
439            index_type: IndexType::Flat,
440            cache_enabled: true,
441            cache_size: 10000,
442            default_k: 10,
443        }
444    }
445}
446
447impl RAGBuilder {
448    /// Create a new builder with the given embedding dimension.
449    #[must_use]
450    pub fn new(dimension: usize) -> Self {
451        Self {
452            dimension,
453            ..Default::default()
454        }
455    }
456
457    /// Set the index type.
458    #[must_use]
459    pub fn index_type(mut self, index_type: IndexType) -> Self {
460        self.index_type = index_type;
461        self
462    }
463
464    /// Enable or disable caching.
465    #[must_use]
466    pub fn cache(mut self, enabled: bool) -> Self {
467        self.cache_enabled = enabled;
468        self
469    }
470
471    /// Set cache size.
472    #[must_use]
473    pub fn cache_size(mut self, size: usize) -> Self {
474        self.cache_size = size;
475        self
476    }
477
478    /// Set default k for queries.
479    #[must_use]
480    pub fn default_k(mut self, k: usize) -> Self {
481        self.default_k = k;
482        self
483    }
484
485    /// Get the configured dimension.
486    #[must_use]
487    pub fn get_dimension(&self) -> usize {
488        self.dimension
489    }
490
491    /// Get the configured index type.
492    #[must_use]
493    pub fn get_index_type(&self) -> IndexType {
494        self.index_type
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// `validate` accepts a well-formed request and rejects a wrong dimension,
    /// `k == 0`, and non-finite embedding components.
    #[test]
    fn test_retrieval_request_validation() {
        let valid = RetrievalRequest::new(vec![1.0, 2.0, 3.0], 10);
        assert!(valid.validate(3).is_ok());

        let wrong_dim = RetrievalRequest::new(vec![1.0, 2.0], 10);
        assert!(wrong_dim.validate(3).is_err());

        let zero_k = RetrievalRequest::new(vec![1.0, 2.0, 3.0], 0);
        assert!(zero_k.validate(3).is_err());

        let nan = RetrievalRequest::new(vec![1.0, f32::NAN, 3.0], 10);
        assert!(nan.validate(3).is_err());

        let inf = RetrievalRequest::new(vec![1.0, f32::INFINITY, 3.0], 10);
        assert!(inf.validate(3).is_err());
    }

    /// Unweighted statistics over a small sample.
    #[test]
    fn test_prior_bundle_from_outcomes() {
        let outcomes = vec![0.8, 0.9, 0.7, 0.85];
        let prior = PriorBundle::from_outcomes(&outcomes, None);

        assert!(prior.mean.is_some());
        assert!((prior.mean.unwrap() - 0.8125).abs() < 1e-6);
        assert_eq!(prior.count, 4);
        assert!(prior.confidence > 0.0);
        assert_eq!(prior.min, Some(0.7));
        assert_eq!(prior.max, Some(0.9));
        assert!(prior.variance.is_some());
        assert!(prior.weighted_mean.is_none());
    }

    /// A single outcome has a mean but no sample variance.
    #[test]
    fn test_prior_bundle_single_outcome() {
        let prior = PriorBundle::from_outcomes(&[1.5], None);
        assert_eq!(prior.mean, Some(1.5));
        assert!(prior.variance.is_none());
        assert!(prior.std_dev.is_none());
    }

    /// No outcomes yields the empty, unreliable bundle.
    #[test]
    fn test_prior_bundle_empty() {
        let prior = PriorBundle::from_outcomes(&[], None);
        assert!(prior.is_empty());
        assert!(!prior.is_reliable());
    }

    #[test]
    fn test_prior_bundle_weighted() {
        let outcomes = vec![1.0, 0.0];
        let weights = vec![0.8, 0.2];
        let prior = PriorBundle::from_outcomes(&outcomes, Some(&weights));

        // Weighted mean: (1.0 * 0.8 + 0.0 * 0.2) / 1.0 = 0.8
        assert!(prior.weighted_mean.is_some());
        assert!((prior.weighted_mean.unwrap() - 0.8).abs() < 1e-6);
    }

    /// Spot-checks of the logistic confidence curve (midpoint at count = 10).
    #[test]
    fn test_confidence_scaling() {
        assert_eq!(PriorBundle::compute_confidence(0), 0.0);
        assert!(PriorBundle::compute_confidence(5) > 0.3);
        assert!((PriorBundle::compute_confidence(10) - 0.5).abs() < 1e-12);
        assert!(PriorBundle::compute_confidence(20) > 0.8);
        assert!(PriorBundle::compute_confidence(100) > 0.99);
    }

    #[test]
    fn test_builder() {
        let builder = RAGBuilder::new(768)
            .index_type(IndexType::Hnsw)
            .cache(true)
            .cache_size(5000)
            .default_k(20);

        assert_eq!(builder.get_dimension(), 768);
        assert_eq!(builder.get_index_type(), IndexType::Hnsw);
    }
}