Skip to main content

oxirag/
lib.rs

1//! `OxiRAG` - A three-layer RAG engine with SMT-based logic verification.
2//!
3//! `OxiRAG` provides a robust Retrieval-Augmented Generation (RAG) pipeline with:
4//!
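//! - **Layer 1 (Echo)**: Semantic search using vector embeddings
//! - **Layer 2 (Speculator)**: Draft verification using small language models
//! - **Layer 3 (Judge)**: Logic verification using SMT solvers
//!
//! An optional graph layer (`layer4_graph`, enabled by the `graphrag` feature) provides
//! entity/relationship extraction and graph traversal.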
8//!
9//! # Quick Start
10//!
11//! ```rust,ignore
12//! use oxirag::prelude::*;
13//!
14//! #[tokio::main]
15//! async fn main() -> Result<(), OxiRagError> {
//!     // Create the Echo layer with a mock embedding provider and an in-memory vector store
17//!     let echo = EchoLayer::new(
18//!         MockEmbeddingProvider::new(384),
19//!         InMemoryVectorStore::new(384),
20//!     );
21//!
22//!     // Create the Speculator layer
23//!     let speculator = RuleBasedSpeculator::default();
24//!
25//!     // Create the Judge layer
26//!     let judge = JudgeImpl::new(
27//!         AdvancedClaimExtractor::new(),
28//!         MockSmtVerifier::default(),
29//!         JudgeConfig::default(),
30//!     );
31//!
32//!     // Build the pipeline
33//!     let mut pipeline = PipelineBuilder::new()
34//!         .with_echo(echo)
35//!         .with_speculator(speculator)
36//!         .with_judge(judge)
37//!         .build()?;
38//!
39//!     // Index documents
40//!     pipeline.index(Document::new("The capital of France is Paris.")).await?;
41//!
42//!     // Query the pipeline
43//!     let query = Query::new("What is the capital of France?");
44//!     let result = pipeline.process(query).await?;
45//!
46//!     println!("Answer: {}", result.final_answer);
47//!     println!("Confidence: {:.2}", result.confidence);
48//!
49//!     Ok(())
50//! }
51//! ```
52//!
53//! # Features
54//!
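//! - `echo` (default): Enable Layer 1 with numrs2 for SIMD similarity
//! - `speculator` (default): Enable Layer 2 with Candle for SLM inference (also exposes the
//!   Candle-based `CandleEmbeddingProvider` for Layer 1)
//! - `judge` (default): Enable Layer 3 with `OxiZ` for SMT verification
//! - `cuda`: Enable CUDA acceleration for Candle models
//! - `metal`: Enable Metal acceleration for Candle models
//! - `graphrag`: Enable the graph layer (`layer4_graph`) for entity/relationship retrieval
//! - `distillation`: Enable the `distillation` module (Q&A pair collection, candidate tracking)
//! - `prefix-cache`: Enable the `prefix_cache` module (KV-cache entries keyed by context
//!   fingerprints)
//! - `hidden-states`: Enable the `hidden_states` module (hidden-state extraction and caching)
//! - `quantization`: Enable the `quantization` module (int8, int4 and binary quantizers)
//! - `native`: Enable the `load_testing` module and the streaming `ProgressReporter`
//! - `wasm`: Enable the `wasm` module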
60//!
61//! # Architecture
62//!
63//! ```text
64//! Query
65//!   │
66//!   ▼
67//! ┌─────────────────┐
68//! │  Layer 1: Echo  │  ← Semantic search with embeddings
69//! │  (Vector Store) │
70//! └────────┬────────┘
71//!          │
72//!          ▼
73//! ┌─────────────────────┐
74//! │ Layer 2: Speculator │  ← Draft verification with SLM
75//! │  (Draft Checker)    │
76//! └─────────┬───────────┘
77//!           │
78//!           ▼
79//! ┌─────────────────┐
80//! │  Layer 3: Judge │  ← Logic verification with SMT
81//! │  (SMT Solver)   │
82//! └────────┬────────┘
83//!          │
84//!          ▼
85//!       Response
86//! ```
87
// Lint policy: every public item must be documented, and the `clippy::all` and
// `clippy::pedantic` lint groups are reported as warnings.
#![warn(missing_docs)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
// `module_name_repetitions` (part of `clippy::pedantic`) would flag item names that repeat
// their module's name, e.g. `pipeline::PipelineBuilder` or `reranker::RerankerConfig`; the
// crate uses that naming widely, so the lint is switched off. Keep this after the
// `pedantic` line: the later lint attribute takes precedence.
#![allow(clippy::module_name_repetitions)]
// NOTE(review): blanket allow for the `unexpected_cfgs` lint. Presumably some `cfg` name or
// feature used in the crate is not declared to Cargo; confirm which one and prefer declaring
// it (in `[features]`, or via `[lints.rust] unexpected_cfgs = { check-cfg = [...] }`) over
// silencing the lint crate-wide.
#![allow(unexpected_cfgs)]

// Public modules, declared in alphabetical order. Modules behind `#[cfg(feature = "...")]`
// are only compiled when that Cargo feature is enabled.
pub mod circuit_breaker;
pub mod config;
pub mod connection_pool;
#[cfg(feature = "distillation")]
pub mod distillation;
pub mod error;
#[cfg(feature = "hidden-states")]
pub mod hidden_states;
pub mod hybrid_search;
pub mod index_management;
// Pipeline layers from the crate-level architecture overview: Echo (1), Speculator (2),
// Judge (3), plus the optional graph layer (4) behind the `graphrag` feature.
pub mod layer1_echo;
pub mod layer2_speculator;
pub mod layer3_judge;
#[cfg(feature = "graphrag")]
pub mod layer4_graph;
#[cfg(feature = "native")]
pub mod load_testing;
pub mod memory;
pub mod metrics;
pub mod pipeline;
pub mod pipeline_debug;
#[cfg(feature = "prefix-cache")]
pub mod prefix_cache;
#[cfg(feature = "quantization")]
pub mod quantization;
pub mod query_builder;
pub mod query_expansion;
pub mod relevance_feedback;
pub mod reranker;
pub mod retry;
pub mod simd_similarity;
pub mod streaming;
pub mod types;

// Only compiled with the `wasm` feature; unlike the other feature-gated modules, nothing
// from it is re-exported through `prelude`.
#[cfg(feature = "wasm")]
pub mod wasm;
130
131/// Convenient re-exports for common usage.
132pub mod prelude {
133    pub use crate::circuit_breaker::{
134        CircuitBreaker, CircuitBreakerConfig, CircuitBreakerOrOperationError,
135        CircuitBreakerRegistry, CircuitBreakerStats, CircuitPermit, CircuitState,
136        with_circuit_breaker, with_service_circuit_breaker,
137    };
138    pub use crate::config::{
139        EchoConfig, JudgeConfig as JudgeCfg, OxiRagConfig, PipelineConfig as PipelineCfg,
140        RetryConfig, SimilarityMetricConfig, SpeculatorConfig as SpeculatorCfg,
141    };
142    pub use crate::connection_pool::{
143        Connection, ConnectionError, ConnectionPool, MockConnection, PoolConfig, PoolError,
144        PoolStats, PooledConnection,
145    };
146    pub use crate::error::{
147        EmbeddingError, JudgeError, OxiRagError, PipelineError, SpeculatorError, VectorStoreError,
148    };
149    pub use crate::layer1_echo::{
150        Echo, EchoLayer, EmbeddingProvider, InMemoryVectorStore, IndexedDocument, MetadataFilter,
151        MockEmbeddingProvider, SimilarityMetric, VectorStore,
152    };
153    pub use crate::layer2_speculator::{RuleBasedSpeculator, Speculator, SpeculatorConfig};
154    pub use crate::layer3_judge::{
155        AdvancedClaimExtractor, ClaimExtractor, Judge, JudgeConfig, JudgeImpl, MockSmtVerifier,
156        SmtVerifier,
157    };
158    pub use crate::memory::{
159        MemoryBreakdown, MemoryBudget, MemoryComponent, MemoryError, MemoryGuard, MemoryMonitor,
160        MemoryStats,
161    };
162    pub use crate::metrics::{LayerTiming, MetricsCollector, PipelineMetrics, TimedOperation};
163    pub use crate::pipeline::{Pipeline, PipelineBuilder, PipelineConfig, RagPipeline};
164    pub use crate::query_builder::{ExtendedQuery, LayerHints, QueryBuilder};
165    pub use crate::retry::RetryPolicy;
166    pub use crate::simd_similarity::{
167        SimdBackend, SimilarityEngine, detect_backend, simd_batch_cosine, simd_cosine_similarity,
168        simd_dot_product, simd_euclidean_distance, simd_l2_norm,
169    };
170    pub use crate::types::{
171        ClaimStructure, ClaimVerificationResult, ComparisonOp, Document, DocumentId, Draft,
172        LogicalClaim, PipelineOutput, Quantifier, Query, SearchResult, SpeculationDecision,
173        SpeculationResult, VerificationResult, VerificationStatus,
174    };
175
176    // Index management exports
177    pub use crate::index_management::{
178        IndexManagement, IndexManager, IndexSnapshot, IndexStats, MergeResult, OptimizeConfig,
179        OptimizeResult, SerializedDocument, SerializedIndex, VacuumResult,
180    };
181
182    // Streaming pipeline exports
183    #[cfg(feature = "native")]
184    pub use crate::streaming::ProgressReporter;
185    pub use crate::streaming::{
186        ChunkMetadata, ChunkType, PipelineChunk, StreamingPipeline, StreamingPipelineResult,
187        StreamingPipelineWrapper,
188    };
189
190    #[cfg(feature = "speculator")]
191    pub use crate::layer1_echo::CandleEmbeddingProvider;
192    #[cfg(feature = "speculator")]
193    pub use crate::layer2_speculator::CandleSlmSpeculator;
194    #[cfg(feature = "judge")]
195    pub use crate::layer3_judge::OxizVerifier;
196
197    // Graph layer exports
198    #[cfg(feature = "graphrag")]
199    pub use crate::config::GraphConfig;
200    #[cfg(feature = "graphrag")]
201    pub use crate::error::GraphError;
202    #[cfg(feature = "graphrag")]
203    pub use crate::layer4_graph::{
204        Direction, EntityExtractor, EntityId, EntityType, Graph, GraphEntity, GraphLayer,
205        GraphLayerBuilder, GraphPath, GraphQuery, GraphRelationship, GraphStore,
206        HybridSearchResult, InMemoryGraphStore, MockEntityExtractor, MockRelationshipExtractor,
207        PatternEntityExtractor, PatternRelationshipExtractor, RelationshipExtractor,
208        RelationshipType, bfs_traverse, find_entities_within_hops, find_shortest_path,
209    };
210    #[cfg(feature = "graphrag")]
211    pub use crate::query_builder::GraphContext;
212
213    // Distillation layer exports
214    #[cfg(feature = "distillation")]
215    pub use crate::distillation::{
216        CandidateDetector, CandidateEvaluation, CollectorStatistics, DistillationCandidate,
217        DistillationConfig, DistillationStats, DistillationTracker, InMemoryDistillationTracker,
218        NearReadyReason, QAPair, QAPairCollector, QueryFrequencyTracker, QueryPattern,
219        TrainingExample,
220    };
221    #[cfg(feature = "distillation")]
222    pub use crate::error::DistillationError;
223
224    // Prefix cache exports
225    #[cfg(feature = "prefix-cache")]
226    pub use crate::error::PrefixCacheError;
227    #[cfg(feature = "prefix-cache")]
228    pub use crate::prefix_cache::{
229        CacheKey, CacheLookupResult, CacheStats, ContextFingerprint, ContextFingerprintGenerator,
230        Fingerprintable, InMemoryPrefixCache, KVCacheEntry, PrefixCacheConfig, PrefixCacheExt,
231        PrefixCacheStore, RollingHasher,
232    };
233
234    // Hidden states exports
235    #[cfg(feature = "hidden-states")]
236    pub use crate::error::HiddenStateError;
237    #[cfg(feature = "hidden-states")]
238    pub use crate::hidden_states::{
239        AdaptiveReuseStrategy, CachedHiddenState, DType, Device, HiddenStateCache,
240        HiddenStateCacheConfig, HiddenStateCacheStats, HiddenStateConfig, HiddenStateProvider,
241        HiddenStateProviderExt, HiddenStateTensor, HybridReuseStrategy, KVCache, LayerExtractor,
242        LayerHiddenState, LengthAwareReuseStrategy, MockHiddenStateProvider, ModelHiddenStates,
243        ModelKVCache, PrefixReuseStrategy, SemanticReuseStrategy, StatePooling, StateReuseStrategy,
244        StateSimilarity, TensorShape,
245    };
246
247    // Load testing exports
248    #[cfg(feature = "native")]
249    pub use crate::load_testing::{
250        LoadTest, LoadTestBuilder, LoadTestConfig, LoadTestResult, LoadTestStats,
251        MockQueryExecutor, MockQueryGenerator, QueryExecutor, QueryGenerator, RequestResult,
252    };
253
254    // Reranker exports
255    pub use crate::reranker::{
256        CrossEncoderReranker, FusionStrategy, HybridReranker, KeywordReranker,
257        MockCrossEncoderReranker, MockReranker, Reranker, RerankerConfig, RerankerPipeline,
258        RerankerPipelineBuilder, SemanticReranker,
259    };
260
261    // Pipeline debug exports
262    pub use crate::pipeline_debug::{
263        DebugConfig, GanttTraceFormatter, JsonTraceFormatter, LayerTraceGuard,
264        MermaidTraceFormatter, PipelineDebugger, PipelineTrace, SharedPipelineDebugger,
265        TextTraceFormatter, TraceEntry, TraceFormatter, TraceId, create_shared_debugger,
266    };
267
268    // Relevance feedback exports
269    pub use crate::relevance_feedback::{
270        FeedbackAdjuster, FeedbackConfig, FeedbackEntry, FeedbackStore, InMemoryFeedbackStore,
271        RelevanceFeedback, RelevanceModel, RocchioFeedbackAdjuster, SimpleBoostAdjuster,
272    };
273
274    // Query expansion exports
275    pub use crate::query_expansion::{
276        CompositeExpander, ExpandedQuery, ExpansionConfig, ExpansionMethod, NGramExpander,
277        PseudoRelevanceFeedback, QueryExpander, QueryReformulator, StemExpander, SynonymExpander,
278    };
279
280    // Hybrid search exports
281    pub use crate::hybrid_search::{
282        BM25Encoder, BM25Params, FusionStrategy as HybridFusionStrategy, HybridConfig,
283        HybridResult, HybridSearcher, InMemorySparseStore, SparseVector, SparseVectorStore,
284    };
285
286    // Quantization exports
287    #[cfg(feature = "quantization")]
288    pub use crate::quantization::{
289        BinaryQuantizer, Int4Quantizer, Int8Quantizer, MockQuantizedVectorStore,
290        QuantizationConfig, QuantizationType, QuantizedDocument, QuantizedTensor,
291        QuantizedVectorStore, Quantizer, compute_quantization_error, compute_snr_db,
292        hamming_distance, int4_dot_product, int8_dot_product,
293    };
294}
295
296pub use error::{OxiRagError, Result};
297
#[cfg(test)]
mod tests {
    use super::prelude::*;

    /// End-to-end check: Echo + Speculator + Judge built from mock backends, a small indexed
    /// corpus, and one query through the assembled pipeline.
    #[tokio::test]
    async fn test_full_pipeline_integration() {
        // Create all layers with mock implementations
        let echo = EchoLayer::new(MockEmbeddingProvider::new(64), InMemoryVectorStore::new(64));

        let speculator = RuleBasedSpeculator::default();

        let judge = JudgeImpl::new(
            AdvancedClaimExtractor::new(),
            MockSmtVerifier::default(),
            JudgeConfig::default(),
        );

        // Build the pipeline with the fast path disabled.
        // NOTE(review): presumably the fast path can short-circuit after Echo; the
        // `layers_used` assertion below expects more than one layer to run — confirm against
        // `PipelineConfig::enable_fast_path`.
        let mut pipeline = PipelineBuilder::new()
            .with_echo(echo)
            .with_speculator(speculator)
            .with_judge(judge)
            .with_config(PipelineConfig {
                enable_fast_path: false,
                ..Default::default()
            })
            .build()
            .expect("Failed to build pipeline");

        // Index some documents
        let documents = vec![
            Document::new(
                "Rust is a systems programming language focused on safety and performance.",
            ),
            Document::new("The Rust compiler prevents data races at compile time."),
            Document::new("Cargo is Rust's package manager and build system."),
        ];

        pipeline
            .index_batch(documents)
            .await
            .expect("Failed to index documents");

        // Process a query
        let query = Query::new("What is Rust?").with_top_k(3);
        let result = pipeline
            .process(query)
            .await
            .expect("Failed to process query");

        // Verify results
        assert!(
            !result.search_results.is_empty(),
            "Should have search results"
        );
        assert!(
            !result.final_answer.is_empty(),
            "Should have a final answer"
        );
        assert!(result.confidence > 0.0, "Should have positive confidence");
        assert!(
            result.layers_used.len() >= 2,
            "Should use at least Echo and Speculator"
        );
    }

    /// Index → get → search → delete → get round trip on a standalone Echo layer.
    #[tokio::test]
    async fn test_document_lifecycle() {
        let mut echo = EchoLayer::new(MockEmbeddingProvider::new(32), InMemoryVectorStore::new(32));

        // Index
        let doc = Document::new("Test document content").with_title("Test");
        let id = echo.index(doc).await.expect("Failed to index");

        // Retrieve
        let retrieved = echo
            .get(&id)
            .await
            .expect("Failed to get")
            .expect("Document not found");
        assert_eq!(retrieved.title, Some("Test".to_string()));

        // Search
        let results = echo
            .search("test document", 5, None)
            .await
            .expect("Failed to search");
        assert!(!results.is_empty());

        // Delete
        let deleted = echo.delete(&id).await.expect("Failed to delete");
        assert!(deleted);

        // Verify deleted
        let retrieved = echo.get(&id).await.expect("Failed to get");
        assert!(retrieved.is_none());
    }

    /// `min_score` filtering: every result returned by `search` must meet the threshold.
    #[tokio::test]
    async fn test_query_filtering() {
        let mut echo = EchoLayer::new(MockEmbeddingProvider::new(32), InMemoryVectorStore::new(32));

        for content in ["High relevance content", "Medium relevance", "Low relevance"] {
            echo.index(Document::new(content)).await.expect("Failed to index");
        }

        // Search with min_score filter
        let results = echo
            .search("high relevance", 10, Some(0.8))
            .await
            .expect("Failed to search");

        // Only checks that nothing below the threshold comes back; whether any document
        // clears 0.8 depends on the mock embedding provider, so emptiness is not asserted.
        for result in &results {
            assert!(
                result.score >= 0.8,
                "result score {} is below the min_score threshold 0.8",
                result.score
            );
        }
    }

    /// JSON round trip of a `Document` through serde; content and title must survive.
    #[test]
    fn test_types_serialization() {
        let doc = Document::new("Test content")
            .with_title("Title")
            .with_metadata("key", "value");

        let json = serde_json::to_string(&doc).expect("Failed to serialize");
        let parsed: Document = serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(parsed.content, doc.content);
        assert_eq!(parsed.title, doc.title);
    }

    /// An `a > b` comparison claim renders to SMT-LIB containing an `assert` and `>`.
    #[test]
    fn test_claim_structure_smtlib() {
        let claim = LogicalClaim::new(
            "test",
            ClaimStructure::Comparison {
                left: "a".to_string(),
                operator: ComparisonOp::GreaterThan,
                right: "b".to_string(),
            },
        );

        let extractor = AdvancedClaimExtractor::new();
        let smt = extractor
            .to_smtlib(&claim)
            .expect("Failed to generate SMT-LIB");

        assert!(smt.contains("assert"));
        assert!(smt.contains('>'));
    }
}