use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Instant;
use crate::error::OxiRagError;
use crate::pipeline::RagPipeline;
use crate::types::{PipelineOutput, Query, SpeculationDecision};
#[cfg(feature = "native")]
use std::pin::Pin;
#[cfg(feature = "native")]
use futures::Stream;
#[cfg(feature = "native")]
use tokio::sync::mpsc;
/// A single incremental update emitted while a pipeline run is being streamed.
///
/// Each chunk carries an identifier, a [`ChunkType`] describing what happened,
/// a textual payload and optional [`ChunkMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineChunk {
/// Identifier assigned by whoever produced the chunk.
pub chunk_id: usize,
/// What kind of pipeline event this chunk represents.
pub chunk_type: ChunkType,
/// Text payload of the chunk.
pub content: String,
/// Timing, layer and confidence information attached to the chunk.
pub metadata: ChunkMetadata,
}
impl PipelineChunk {
#[must_use]
pub fn new(chunk_id: usize, chunk_type: ChunkType, content: impl Into<String>) -> Self {
Self {
chunk_id,
chunk_type,
content: content.into(),
metadata: ChunkMetadata::default(),
}
}
#[must_use]
pub fn with_metadata(mut self, metadata: ChunkMetadata) -> Self {
self.metadata = metadata;
self
}
#[must_use]
pub fn with_timestamp(mut self, timestamp_ms: u64) -> Self {
self.metadata.timestamp_ms = timestamp_ms;
self
}
#[must_use]
pub fn with_layer(mut self, layer: impl Into<String>) -> Self {
self.metadata.layer = Some(layer.into());
self
}
#[must_use]
pub fn with_confidence(mut self, confidence: f32) -> Self {
self.metadata.confidence = Some(confidence);
self
}
#[must_use]
pub fn with_duration(mut self, duration_ms: u64) -> Self {
self.metadata.duration_ms = Some(duration_ms);
self
}
}
/// The kind of event a [`PipelineChunk`] reports.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkType {
/// A search has been started.
SearchStarted,
/// One search hit.
SearchResult {
/// Position of the hit in the result list.
rank: usize,
/// Score of the hit.
score: f32,
},
/// The search finished.
SearchCompleted {
/// Number of hits that were found.
total: usize,
},
/// A draft answer is available.
DraftGenerated,
/// Speculation has been started.
SpeculationStarted,
/// Intermediate speculation progress.
SpeculationProgress {
/// Name of the stage being reported.
stage: String,
/// Confidence reported for that stage.
confidence: f32,
},
/// The decision reached by speculation.
SpeculationDecision(SpeculationDecision),
/// Verification has been started.
VerificationStarted,
/// A claim was extracted for verification.
ClaimExtracted {
/// Identifier of the claim.
claim_id: usize,
},
/// A claim was checked.
ClaimVerified {
/// Identifier of the claim.
claim_id: usize,
/// Verification status rendered as text.
status: String,
},
/// Verification finished.
VerificationCompleted,
/// The final answer of the pipeline run.
FinalAnswer,
/// The run failed; the payload is the error message.
Error(String),
}
/// Optional bookkeeping attached to a [`PipelineChunk`].
///
/// `Default` yields a timestamp of `0` and `None` for every optional field.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChunkMetadata {
/// Timestamp in milliseconds; the reference point is chosen by the producer.
pub timestamp_ms: u64,
/// Name of the pipeline layer that produced the chunk, if known.
pub layer: Option<String>,
/// Confidence associated with the chunk, if any.
pub confidence: Option<f32>,
/// Duration in milliseconds associated with the chunk, if any.
pub duration_ms: Option<u64>,
}
impl ChunkMetadata {
    /// Creates metadata carrying only a timestamp; all optional fields are `None`.
    #[must_use]
    pub fn new(timestamp_ms: u64) -> Self {
        Self {
            timestamp_ms,
            layer: None,
            confidence: None,
            duration_ms: None,
        }
    }

    /// Returns the metadata with `layer` set to `Some(layer)`.
    #[must_use]
    pub fn with_layer(self, layer: impl Into<String>) -> Self {
        Self {
            layer: Some(layer.into()),
            ..self
        }
    }

    /// Returns the metadata with `confidence` set to `Some(confidence)`.
    #[must_use]
    pub fn with_confidence(self, confidence: f32) -> Self {
        Self {
            confidence: Some(confidence),
            ..self
        }
    }

    /// Returns the metadata with `duration_ms` set to `Some(duration_ms)`.
    #[must_use]
    pub fn with_duration(self, duration_ms: u64) -> Self {
        Self {
            duration_ms: Some(duration_ms),
            ..self
        }
    }
}
/// Handle to a streamed pipeline run (available with the `native` feature).
///
/// Chunks are pulled from a bounded `tokio` mpsc channel.
#[cfg(feature = "native")]
pub struct StreamingPipelineResult {
// Receiving end of the channel on which chunks arrive.
receiver: mpsc::Receiver<PipelineChunk>,
// Pre-computed output; only set by `from_output`.
final_output: Option<PipelineOutput>,
// Copies of the chunks received so far through `next`/`collect`.
collected_chunks: Vec<PipelineChunk>,
}
#[cfg(feature = "native")]
impl StreamingPipelineResult {
    /// Wraps the receiving end of a chunk channel.
    #[must_use]
    pub fn new(receiver: mpsc::Receiver<PipelineChunk>) -> Self {
        Self {
            receiver,
            final_output: None,
            collected_chunks: Vec::new(),
        }
    }

    /// Builds a result that already holds its final output.
    ///
    /// The internal channel has no sender, so it never yields a chunk.
    #[must_use]
    pub fn from_output(output: PipelineOutput) -> Self {
        let (sender, receiver) = mpsc::channel::<PipelineChunk>(1);
        // Closing the channel right away makes `next` return `None` immediately.
        drop(sender);
        Self {
            final_output: Some(output),
            ..Self::new(receiver)
        }
    }

    /// Waits for the next chunk, remembering a copy of it in `collected_chunks`.
    ///
    /// Returns `None` once the channel is closed and drained.
    pub async fn next(&mut self) -> Option<PipelineChunk> {
        let received = self.receiver.recv().await?;
        self.collected_chunks.push(received.clone());
        Some(received)
    }

    /// Converts the result into a boxed stream of the remaining chunks.
    ///
    /// Any stored final output and the collected chunks are discarded.
    #[must_use]
    pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = PipelineChunk> + Send>> {
        let Self { receiver, .. } = self;
        Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver))
    }

    /// Consumes the result and resolves it to a final output.
    ///
    /// # Errors
    ///
    /// When no final output was stored up front, the channel is drained and an
    /// `ExecutionError` is returned: it carries the message of the last
    /// `ChunkType::Error` chunk seen, or "Stream ended without final output"
    /// when there was none. A channel-backed result therefore never yields `Ok`.
    pub async fn collect(mut self) -> Result<PipelineOutput, OxiRagError> {
        if let Some(ready) = self.final_output.take() {
            return Ok(ready);
        }

        let mut failure: Option<String> = None;
        while let Some(chunk) = self.receiver.recv().await {
            if let ChunkType::Error(message) = &chunk.chunk_type {
                failure = Some(message.clone());
            }
            self.collected_chunks.push(chunk);
        }

        let message =
            failure.unwrap_or_else(|| "Stream ended without final output".to_string());
        Err(OxiRagError::Pipeline(
            crate::error::PipelineError::ExecutionError(message),
        ))
    }

    /// Feeds every remaining chunk to `callback`, in arrival order.
    pub async fn for_each<F>(mut self, mut callback: F)
    where
        F: FnMut(PipelineChunk),
    {
        while let Some(item) = self.next().await {
            callback(item);
        }
    }

    /// Whether a final output was stored when the result was built.
    #[must_use]
    pub fn has_final_output(&self) -> bool {
        self.final_output.is_some()
    }

    /// The chunks received so far.
    #[must_use]
    pub fn collected_chunks(&self) -> &[PipelineChunk] {
        self.collected_chunks.as_slice()
    }

    /// Number of chunks received so far.
    #[must_use]
    pub fn chunk_count(&self) -> usize {
        self.collected_chunks.len()
    }
}
/// Fallback result type used when the `native` feature is disabled.
///
/// It holds either a finished output or the error of the run; there is no
/// chunk channel in this configuration.
#[cfg(not(feature = "native"))]
pub struct StreamingPipelineResult {
// Output of a successful run.
output: Option<PipelineOutput>,
// Error of a failed run.
error: Option<OxiRagError>,
}
#[cfg(not(feature = "native"))]
impl StreamingPipelineResult {
#[must_use]
pub fn from_output(output: PipelineOutput) -> Self {
Self {
output: Some(output),
error: None,
}
}
#[must_use]
pub fn from_error(error: OxiRagError) -> Self {
Self {
output: None,
error: Some(error),
}
}
pub async fn collect(self) -> Result<PipelineOutput, OxiRagError> {
match (self.output, self.error) {
(Some(output), _) => Ok(output),
(None, Some(err)) => Err(err),
(None, None) => Err(OxiRagError::Pipeline(
crate::error::PipelineError::ExecutionError("No output available".to_string()),
)),
}
}
}
/// A [`RagPipeline`] that can additionally report its results as a
/// [`StreamingPipelineResult`].
#[async_trait]
pub trait StreamingPipeline: RagPipeline + Send + Sync {
/// Processes a single query and returns a streaming result handle.
async fn process_streaming(&self, query: Query)
-> Result<StreamingPipelineResult, OxiRagError>;
/// Processes several queries, returning one streaming result per query.
async fn process_batch_streaming(&self, queries: Vec<Query>) -> Vec<StreamingPipelineResult>;
}
/// Adapter that adds [`StreamingPipeline`] support to any [`RagPipeline`].
pub struct StreamingPipelineWrapper<P: RagPipeline> {
// The wrapped pipeline that does the actual work.
inner: P,
// Capacity used for the chunk channel created per streamed query.
chunk_buffer_size: usize,
}
impl<P: RagPipeline> StreamingPipelineWrapper<P> {
    /// Wraps `pipeline` with a chunk buffer size of 32.
    #[must_use]
    pub fn new(pipeline: P) -> Self {
        let chunk_buffer_size = 32;
        Self {
            inner: pipeline,
            chunk_buffer_size,
        }
    }

    /// Overrides the chunk buffer size; values below 1 are raised to 1.
    #[must_use]
    pub fn with_buffer_size(self, size: usize) -> Self {
        Self {
            chunk_buffer_size: std::cmp::max(size, 1),
            ..self
        }
    }

    /// Shared access to the wrapped pipeline.
    #[must_use]
    pub fn inner(&self) -> &P {
        &self.inner
    }

    /// Exclusive access to the wrapped pipeline.
    #[must_use]
    pub fn inner_mut(&mut self) -> &mut P {
        &mut self.inner
    }

    /// The configured chunk buffer size.
    #[must_use]
    pub fn buffer_size(&self) -> usize {
        self.chunk_buffer_size
    }
}
/// Makes the wrapper usable wherever a plain [`RagPipeline`] is expected by
/// forwarding every call to the wrapped pipeline unchanged.
#[async_trait]
impl<P: RagPipeline + Send + Sync> RagPipeline for StreamingPipelineWrapper<P> {
/// Forwards to the wrapped pipeline's `process`.
async fn process(&self, query: Query) -> Result<PipelineOutput, OxiRagError> {
self.inner.process(query).await
}
/// Forwards to the wrapped pipeline's `process_batch`.
async fn process_batch(&self, queries: Vec<Query>) -> Vec<Result<PipelineOutput, OxiRagError>> {
self.inner.process_batch(queries).await
}
/// Forwards to the wrapped pipeline's `index`.
async fn index(&mut self, document: crate::types::Document) -> Result<(), OxiRagError> {
self.inner.index(document).await
}
/// Forwards to the wrapped pipeline's `index_batch`.
async fn index_batch(
&mut self,
documents: Vec<crate::types::Document>,
) -> Result<(), OxiRagError> {
self.inner.index_batch(documents).await
}
/// Returns the wrapped pipeline's configuration.
fn config(&self) -> &crate::pipeline::PipelineConfig {
self.inner.config()
}
}
/// Milliseconds elapsed since `start`, saturating at `u64::MAX`.
#[cfg(feature = "native")]
#[inline]
fn elapsed_ms(start: &Instant) -> u64 {
    // `as_millis` yields a `u128`; the checked conversion saturates instead of truncating.
    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
}
/// Sends one `SearchResult` chunk per search hit in `output`, followed by a
/// `SearchCompleted` chunk.
///
/// `chunk_id` is advanced by one for every chunk; send failures are ignored.
#[cfg(feature = "native")]
async fn emit_search_chunks(
    tx: &mpsc::Sender<PipelineChunk>,
    output: &PipelineOutput,
    chunk_id: &mut usize,
    start: &Instant,
) {
    let hits = &output.search_results;

    for (rank, hit) in hits.iter().enumerate() {
        let score = hit.score;
        let kind = ChunkType::SearchResult { rank, score };
        // Only a short preview of the document is streamed.
        let preview = truncate_content(&hit.document.content, 100);
        let hit_chunk = PipelineChunk::new(*chunk_id, kind, preview)
            .with_layer("Echo")
            .with_confidence(score)
            .with_timestamp(elapsed_ms(start));
        let _ = tx.send(hit_chunk).await;
        *chunk_id += 1;
    }

    let total = hits.len();
    let done_chunk = PipelineChunk::new(
        *chunk_id,
        ChunkType::SearchCompleted { total },
        format!("Found {total} results"),
    )
    .with_layer("Echo")
    .with_timestamp(elapsed_ms(start));
    let _ = tx.send(done_chunk).await;
    *chunk_id += 1;
}
/// Sends the three speculation chunks — started, progress, decision — for
/// `speculation`.
///
/// `chunk_id` is advanced by one for every chunk; send failures are ignored.
#[cfg(feature = "native")]
async fn emit_speculation_chunks(
    tx: &mpsc::Sender<PipelineChunk>,
    speculation: &crate::types::SpeculationResult,
    chunk_id: &mut usize,
    start: &Instant,
) {
    let confidence = speculation.confidence;

    let started = PipelineChunk::new(
        *chunk_id,
        ChunkType::SpeculationStarted,
        "Starting speculation verification",
    )
    .with_layer("Speculator")
    .with_timestamp(elapsed_ms(start));
    let _ = tx.send(started).await;
    *chunk_id += 1;

    let progress_kind = ChunkType::SpeculationProgress {
        stage: String::from("verification"),
        confidence,
    };
    let progress = PipelineChunk::new(*chunk_id, progress_kind, &speculation.explanation)
        .with_layer("Speculator")
        .with_confidence(confidence)
        .with_timestamp(elapsed_ms(start));
    let _ = tx.send(progress).await;
    *chunk_id += 1;

    let decision = &speculation.decision;
    let verdict = PipelineChunk::new(
        *chunk_id,
        ChunkType::SpeculationDecision(decision.clone()),
        format!("Decision: {decision:?}"),
    )
    .with_layer("Speculator")
    .with_confidence(confidence)
    .with_timestamp(elapsed_ms(start));
    let _ = tx.send(verdict).await;
    *chunk_id += 1;
}
/// Sends the verification chunks for `verification`: a start marker, an
/// extracted/verified pair per claim, and a completion summary.
///
/// `chunk_id` is advanced by one for every chunk; send failures are ignored.
#[cfg(feature = "native")]
async fn emit_verification_chunks(
    tx: &mpsc::Sender<PipelineChunk>,
    verification: &crate::types::VerificationResult,
    chunk_id: &mut usize,
    start: &Instant,
) {
    let opening = PipelineChunk::new(
        *chunk_id,
        ChunkType::VerificationStarted,
        "Starting logic verification",
    )
    .with_layer("Judge")
    .with_timestamp(elapsed_ms(start));
    let _ = tx.send(opening).await;
    *chunk_id += 1;

    for (claim_id, outcome) in verification.claim_results.iter().enumerate() {
        let extracted = PipelineChunk::new(
            *chunk_id,
            ChunkType::ClaimExtracted { claim_id },
            &outcome.claim.text,
        )
        .with_layer("Judge")
        .with_confidence(outcome.claim.confidence)
        .with_timestamp(elapsed_ms(start));
        let _ = tx.send(extracted).await;
        *chunk_id += 1;

        let status = format!("{:?}", outcome.status);
        let explanation = outcome.explanation.as_deref().unwrap_or("No explanation");
        let verified = PipelineChunk::new(
            *chunk_id,
            ChunkType::ClaimVerified { claim_id, status },
            explanation,
        )
        .with_layer("Judge")
        .with_duration(outcome.duration_ms)
        .with_timestamp(elapsed_ms(start));
        let _ = tx.send(verified).await;
        *chunk_id += 1;
    }

    let closing = PipelineChunk::new(
        *chunk_id,
        ChunkType::VerificationCompleted,
        &verification.summary,
    )
    .with_layer("Judge")
    .with_confidence(verification.confidence)
    .with_duration(verification.total_duration_ms)
    .with_timestamp(elapsed_ms(start));
    let _ = tx.send(closing).await;
    *chunk_id += 1;
}
#[cfg(feature = "native")]
#[async_trait]
impl<P: RagPipeline + Send + Sync> StreamingPipeline for StreamingPipelineWrapper<P> {
    /// Runs the wrapped pipeline for `query` and replays its output as chunks.
    ///
    /// The wrapped pipeline is awaited to completion first; only then is a
    /// background task spawned that turns the finished output (or the error)
    /// into a sequence of [`PipelineChunk`]s with consecutive ids starting at 0.
    /// Chunk timestamps are therefore measured from the moment the wrapped
    /// pipeline finished, not from the start of the query.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`; a pipeline failure is reported
    /// as a `ChunkType::Error` chunk on the stream instead.
    async fn process_streaming(
        &self,
        query: Query,
    ) -> Result<StreamingPipelineResult, OxiRagError> {
        let (tx, rx) = mpsc::channel(self.chunk_buffer_size);
        let query_text = query.text.clone();
        // `query` is not used again, so it is moved into the pipeline instead of cloned.
        let result = self.inner.process(query).await;
        let start = Instant::now();
        tokio::spawn(async move {
            let mut chunk_id = 0;
            let chunk = PipelineChunk::new(
                chunk_id,
                ChunkType::SearchStarted,
                format!("Searching for: {query_text}"),
            )
            .with_layer("Echo")
            .with_timestamp(elapsed_ms(&start));
            // Send failures only mean the receiver went away; emission is best effort.
            let _ = tx.send(chunk).await;
            chunk_id += 1;
            match result {
                Ok(output) => {
                    emit_search_chunks(&tx, &output, &mut chunk_id, &start).await;
                    let chunk = PipelineChunk::new(
                        chunk_id,
                        ChunkType::DraftGenerated,
                        truncate_content(&output.draft.content, 200),
                    )
                    .with_confidence(output.draft.confidence)
                    .with_timestamp(elapsed_ms(&start));
                    let _ = tx.send(chunk).await;
                    chunk_id += 1;
                    if let Some(ref speculation) = output.speculation {
                        emit_speculation_chunks(&tx, speculation, &mut chunk_id, &start).await;
                    }
                    if let Some(ref verification) = output.verification {
                        emit_verification_chunks(&tx, verification, &mut chunk_id, &start).await;
                    }
                    let chunk =
                        PipelineChunk::new(chunk_id, ChunkType::FinalAnswer, &output.final_answer)
                            .with_confidence(output.confidence)
                            .with_duration(output.total_duration_ms)
                            .with_timestamp(elapsed_ms(&start));
                    let _ = tx.send(chunk).await;
                }
                Err(err) => {
                    let chunk = PipelineChunk::new(
                        chunk_id,
                        ChunkType::Error(err.to_string()),
                        format!("Pipeline error: {err}"),
                    )
                    .with_timestamp(elapsed_ms(&start));
                    let _ = tx.send(chunk).await;
                }
            }
        });
        Ok(StreamingPipelineResult::new(rx))
    }

    /// Streams every query in `queries`, driving them together via `join_all`.
    ///
    /// If starting a stream fails, the corresponding result contains a single
    /// `ChunkType::Error` chunk describing the failure.
    async fn process_batch_streaming(&self, queries: Vec<Query>) -> Vec<StreamingPipelineResult> {
        use futures::future::join_all;
        let futures: Vec<_> = queries
            .into_iter()
            .map(|q| async move {
                match self.process_streaming(q).await {
                    Ok(result) => result,
                    Err(err) => {
                        // Capacity 1 is enough: exactly one chunk is queued before `tx` is dropped.
                        let (tx, rx) = mpsc::channel(1);
                        let chunk = PipelineChunk::new(
                            0,
                            ChunkType::Error(err.to_string()),
                            format!("Failed to start streaming: {err}"),
                        );
                        let _ = tx.send(chunk).await;
                        StreamingPipelineResult::new(rx)
                    }
                }
            })
            .collect();
        join_all(futures).await
    }
}
#[cfg(not(feature = "native"))]
#[async_trait]
impl<P: RagPipeline + Send + Sync> StreamingPipeline for StreamingPipelineWrapper<P> {
    /// Runs the wrapped pipeline and wraps its outcome.
    ///
    /// Always returns `Ok`; a pipeline failure is stored inside the result and
    /// surfaces from `collect`.
    async fn process_streaming(
        &self,
        query: Query,
    ) -> Result<StreamingPipelineResult, OxiRagError> {
        let outcome = self.inner.process(query).await;
        Ok(outcome.map_or_else(
            StreamingPipelineResult::from_error,
            StreamingPipelineResult::from_output,
        ))
    }

    /// Processes the queries one after another, in order.
    async fn process_batch_streaming(&self, queries: Vec<Query>) -> Vec<StreamingPipelineResult> {
        let mut collected = Vec::with_capacity(queries.len());
        for query in queries {
            let entry = self
                .process_streaming(query)
                .await
                .unwrap_or_else(StreamingPipelineResult::from_error);
            collected.push(entry);
        }
        collected
    }
}
/// Helper for emitting [`PipelineChunk`]s onto a channel with automatically
/// assigned ids and timestamps (available with the `native` feature).
#[cfg(feature = "native")]
pub struct ProgressReporter {
// Channel on which chunks are sent.
sender: mpsc::Sender<PipelineChunk>,
// Id that will be given to the next chunk; also the number of chunks created so far.
chunk_counter: usize,
// Moment the reporter was created; timestamps are relative to it.
start_time: Instant,
}
#[cfg(feature = "native")]
impl ProgressReporter {
    /// Creates a reporter that sends on `sender`; its clock starts now.
    #[must_use]
    pub fn new(sender: mpsc::Sender<PipelineChunk>) -> Self {
        let start_time = Instant::now();
        Self {
            sender,
            chunk_counter: 0,
            start_time,
        }
    }

    /// Milliseconds since the reporter was created.
    fn timestamp_ms(&self) -> u64 {
        elapsed_ms(&self.start_time)
    }

    /// Hands out the current chunk id and advances the counter.
    fn next_chunk_id(&mut self) -> usize {
        let id = self.chunk_counter;
        self.chunk_counter = id + 1;
        id
    }

    /// Stamps `chunk` with the current elapsed time and sends it.
    ///
    /// A failed send (receiver gone) is deliberately ignored.
    async fn stamp_and_send(&mut self, chunk: PipelineChunk) {
        let stamped = chunk.with_timestamp(self.timestamp_ms());
        let _ = self.sender.send(stamped).await;
    }

    /// Sends a chunk of an arbitrary type with the given content.
    pub async fn report(&mut self, chunk_type: ChunkType, content: &str) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(id, chunk_type, content);
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `SearchResult` chunk for one hit.
    pub async fn report_search_result(&mut self, rank: usize, score: f32, doc_preview: &str) {
        let id = self.next_chunk_id();
        let kind = ChunkType::SearchResult { rank, score };
        let chunk = PipelineChunk::new(id, kind, doc_preview)
            .with_layer("Echo")
            .with_confidence(score);
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `SpeculationProgress` chunk for `stage`.
    pub async fn report_speculation(&mut self, stage: &str, confidence: f32) {
        let id = self.next_chunk_id();
        let kind = ChunkType::SpeculationProgress {
            stage: stage.to_string(),
            confidence,
        };
        let chunk = PipelineChunk::new(id, kind, format!("Speculation stage: {stage}"))
            .with_layer("Speculator")
            .with_confidence(confidence);
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `ClaimExtracted` chunk carrying the claim text.
    pub async fn report_claim(&mut self, claim_id: usize, claim_text: &str) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(id, ChunkType::ClaimExtracted { claim_id }, claim_text)
            .with_layer("Judge");
        self.stamp_and_send(chunk).await;
    }

    /// Sends an `Error` chunk; `error` is used as both payload and content.
    pub async fn report_error(&mut self, error: &str) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(id, ChunkType::Error(error.to_string()), error);
        self.stamp_and_send(chunk).await;
    }

    /// Sends the `FinalAnswer` chunk.
    pub async fn report_final(&mut self, answer: &str, confidence: f32) {
        let id = self.next_chunk_id();
        let chunk =
            PipelineChunk::new(id, ChunkType::FinalAnswer, answer).with_confidence(confidence);
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `SearchStarted` chunk mentioning the query.
    pub async fn report_search_started(&mut self, query: &str) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(
            id,
            ChunkType::SearchStarted,
            format!("Searching for: {query}"),
        )
        .with_layer("Echo");
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `SearchCompleted` chunk with the number of hits.
    pub async fn report_search_completed(&mut self, total: usize) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(
            id,
            ChunkType::SearchCompleted { total },
            format!("Found {total} results"),
        )
        .with_layer("Echo");
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `DraftGenerated` chunk with a preview of the draft.
    pub async fn report_draft(&mut self, draft_preview: &str, confidence: f32) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(id, ChunkType::DraftGenerated, draft_preview)
            .with_confidence(confidence);
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `SpeculationStarted` chunk.
    pub async fn report_speculation_started(&mut self) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(
            id,
            ChunkType::SpeculationStarted,
            "Starting speculation verification",
        )
        .with_layer("Speculator");
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `SpeculationDecision` chunk for `decision`.
    pub async fn report_speculation_decision(
        &mut self,
        decision: SpeculationDecision,
        confidence: f32,
    ) {
        let id = self.next_chunk_id();
        // Render the text first so the decision itself can be moved into the chunk type.
        let content = format!("Decision: {decision:?}");
        let chunk = PipelineChunk::new(id, ChunkType::SpeculationDecision(decision), content)
            .with_layer("Speculator")
            .with_confidence(confidence);
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `VerificationStarted` chunk.
    pub async fn report_verification_started(&mut self) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(
            id,
            ChunkType::VerificationStarted,
            "Starting logic verification",
        )
        .with_layer("Judge");
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `ClaimVerified` chunk with the status and its explanation.
    pub async fn report_claim_verified(
        &mut self,
        claim_id: usize,
        status: &str,
        explanation: &str,
    ) {
        let id = self.next_chunk_id();
        let kind = ChunkType::ClaimVerified {
            claim_id,
            status: status.to_string(),
        };
        let chunk = PipelineChunk::new(id, kind, explanation).with_layer("Judge");
        self.stamp_and_send(chunk).await;
    }

    /// Sends a `VerificationCompleted` chunk with summary, confidence and duration.
    pub async fn report_verification_completed(
        &mut self,
        summary: &str,
        confidence: f32,
        duration_ms: u64,
    ) {
        let id = self.next_chunk_id();
        let chunk = PipelineChunk::new(id, ChunkType::VerificationCompleted, summary)
            .with_layer("Judge")
            .with_confidence(confidence)
            .with_duration(duration_ms);
        self.stamp_and_send(chunk).await;
    }

    /// Number of chunks created so far (sent or not).
    #[must_use]
    pub fn chunk_count(&self) -> usize {
        self.chunk_counter
    }

    /// Milliseconds since the reporter was created.
    #[must_use]
    pub fn elapsed_ms(&self) -> u64 {
        self.timestamp_ms()
    }
}
/// Shortens `content` to at most `max_len` bytes for use as a preview.
///
/// Content that already fits is returned unchanged. Otherwise the first
/// `max_len - 3` bytes are kept and `"..."` is appended. If that cut would
/// fall inside a multi-byte UTF-8 character, it is moved back to the previous
/// character boundary so the function never panics on non-ASCII text.
fn truncate_content(content: &str, max_len: usize) -> String {
    if content.len() <= max_len {
        return content.to_string();
    }

    // `cut` is below `content.len()` here, and index 0 is always a char
    // boundary, so the loop terminates.
    let mut cut = max_len.saturating_sub(3);
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &content[..cut])
}
// Unit tests for the streaming wrapper, the chunk types and the progress
// reporter. They need the `native` feature because they rely on the
// channel-backed implementation.
#[cfg(all(test, feature = "native"))]
mod tests {
use super::*;
use crate::layer1_echo::{EchoLayer, InMemoryVectorStore, MockEmbeddingProvider};
use crate::layer2_speculator::RuleBasedSpeculator;
use crate::layer3_judge::{AdvancedClaimExtractor, JudgeConfig, JudgeImpl, MockSmtVerifier};
use crate::pipeline::{Pipeline, PipelineConfig};
use crate::types::Document;
use tokio::sync::mpsc;
// Concrete pipeline type assembled from the mock/in-memory layers.
type TestPipeline = Pipeline<
EchoLayer<MockEmbeddingProvider, InMemoryVectorStore>,
RuleBasedSpeculator,
JudgeImpl<AdvancedClaimExtractor, MockSmtVerifier>,
>;
// Builds a pipeline with the fast path disabled.
fn create_test_pipeline() -> TestPipeline {
let echo = EchoLayer::new(MockEmbeddingProvider::new(64), InMemoryVectorStore::new(64));
let speculator = RuleBasedSpeculator::default();
let judge = JudgeImpl::new(
AdvancedClaimExtractor::new(),
MockSmtVerifier::default(),
JudgeConfig::default(),
);
Pipeline::new(
echo,
speculator,
judge,
PipelineConfig {
enable_fast_path: false,
..Default::default()
},
)
}
// The chunk builders store the values they are given.
#[tokio::test]
async fn test_pipeline_chunk_creation() {
let chunk = PipelineChunk::new(0, ChunkType::SearchStarted, "test content")
.with_timestamp(100)
.with_layer("Echo")
.with_confidence(0.9);
assert_eq!(chunk.chunk_id, 0);
assert_eq!(chunk.content, "test content");
assert_eq!(chunk.metadata.timestamp_ms, 100);
assert_eq!(chunk.metadata.layer, Some("Echo".to_string()));
assert_eq!(chunk.metadata.confidence, Some(0.9));
}
// The metadata builders store the values they are given.
#[tokio::test]
async fn test_chunk_metadata_builder() {
let metadata = ChunkMetadata::new(50)
.with_layer("Speculator")
.with_confidence(0.85)
.with_duration(200);
assert_eq!(metadata.timestamp_ms, 50);
assert_eq!(metadata.layer, Some("Speculator".to_string()));
assert_eq!(metadata.confidence, Some(0.85));
assert_eq!(metadata.duration_ms, Some(200));
}
// `ChunkType` compares by variant and payload.
#[tokio::test]
async fn test_chunk_type_equality() {
assert_eq!(ChunkType::SearchStarted, ChunkType::SearchStarted);
assert_eq!(
ChunkType::SearchResult {
rank: 0,
score: 0.9
},
ChunkType::SearchResult {
rank: 0,
score: 0.9
}
);
assert_ne!(ChunkType::SearchStarted, ChunkType::DraftGenerated);
}
// A custom buffer size is kept.
#[tokio::test]
async fn test_streaming_wrapper_creation() {
let pipeline = create_test_pipeline();
let wrapper = StreamingPipelineWrapper::new(pipeline).with_buffer_size(64);
assert_eq!(wrapper.buffer_size(), 64);
}
// A buffer size of zero is raised to one.
#[tokio::test]
async fn test_streaming_wrapper_buffer_minimum() {
let pipeline = create_test_pipeline();
let wrapper = StreamingPipelineWrapper::new(pipeline).with_buffer_size(0);
assert_eq!(wrapper.buffer_size(), 1);
}
// Streaming against an empty index still yields chunks, starting with `SearchStarted`.
#[tokio::test]
async fn test_streaming_empty_index() {
let pipeline = create_test_pipeline();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let query = Query::new("What is the meaning of life?");
let mut result = wrapper.process_streaming(query).await.unwrap();
let mut chunks = Vec::new();
while let Some(chunk) = result.next().await {
chunks.push(chunk);
}
assert!(chunks.len() >= 4);
assert!(matches!(chunks[0].chunk_type, ChunkType::SearchStarted));
}
// With an indexed document the stream contains a search result and a final answer.
#[tokio::test]
async fn test_streaming_with_documents() {
let mut pipeline = create_test_pipeline();
pipeline
.index(Document::new("The capital of France is Paris."))
.await
.unwrap();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let query = Query::new("What is the capital of France?");
let mut result = wrapper.process_streaming(query).await.unwrap();
let mut has_search_result = false;
let mut has_final_answer = false;
while let Some(chunk) = result.next().await {
if matches!(chunk.chunk_type, ChunkType::SearchResult { .. }) {
has_search_result = true;
}
if matches!(chunk.chunk_type, ChunkType::FinalAnswer) {
has_final_answer = true;
}
}
assert!(has_search_result);
assert!(has_final_answer);
}
// Chunk ids are consecutive and start at zero.
#[tokio::test]
async fn test_streaming_chunk_ordering() {
let mut pipeline = create_test_pipeline();
pipeline
.index(Document::new("Test document content."))
.await
.unwrap();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let query = Query::new("test");
let mut result = wrapper.process_streaming(query).await.unwrap();
let mut chunks = Vec::new();
while let Some(chunk) = result.next().await {
chunks.push(chunk);
}
for (i, chunk) in chunks.iter().enumerate() {
assert_eq!(chunk.chunk_id, i);
}
}
// The result can be consumed through the `Stream` interface.
#[tokio::test]
async fn test_streaming_into_stream() {
use futures::StreamExt;
let mut pipeline = create_test_pipeline();
pipeline
.index(Document::new("Stream test document."))
.await
.unwrap();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let query = Query::new("stream");
let result = wrapper.process_streaming(query).await.unwrap();
let mut stream = result.into_stream();
let mut count = 0;
while let Some(chunk) = stream.next().await {
count += 1;
assert!(chunk.chunk_id < 100); }
assert!(count > 0);
}
// Batch streaming returns one result per query.
#[tokio::test]
async fn test_streaming_batch() {
let mut pipeline = create_test_pipeline();
pipeline
.index(Document::new("Alpha document."))
.await
.unwrap();
pipeline
.index(Document::new("Beta document."))
.await
.unwrap();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let queries = vec![Query::new("alpha"), Query::new("beta")];
let results = wrapper.process_batch_streaming(queries).await;
assert_eq!(results.len(), 2);
}
// Chunks pulled with `next` are recorded on the result.
#[tokio::test]
async fn test_streaming_collected_chunks() {
let pipeline = create_test_pipeline();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let query = Query::new("test");
let mut result = wrapper.process_streaming(query).await.unwrap();
let _ = result.next().await;
let _ = result.next().await;
assert!(result.collected_chunks().len() >= 2);
assert!(result.chunk_count() >= 2);
}
// `for_each` invokes the callback for the streamed chunks.
#[tokio::test]
async fn test_streaming_for_each() {
let pipeline = create_test_pipeline();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let query = Query::new("test");
let result = wrapper.process_streaming(query).await.unwrap();
let mut count = 0;
result
.for_each(|_chunk| {
count += 1;
})
.await;
assert!(count > 0);
}
// The reporter emits the search chunks in call order.
#[tokio::test]
async fn test_progress_reporter_basic() {
let (tx, mut rx) = mpsc::channel(32);
let mut reporter = ProgressReporter::new(tx);
reporter.report_search_started("test query").await;
reporter.report_search_result(0, 0.9, "Test document").await;
reporter.report_search_completed(1).await;
let chunk1 = rx.recv().await.unwrap();
assert!(matches!(chunk1.chunk_type, ChunkType::SearchStarted));
let chunk2 = rx.recv().await.unwrap();
assert!(matches!(
chunk2.chunk_type,
ChunkType::SearchResult { rank: 0, .. }
));
let chunk3 = rx.recv().await.unwrap();
assert!(matches!(
chunk3.chunk_type,
ChunkType::SearchCompleted { total: 1 }
));
}
// The reporter emits the speculation chunks in call order.
#[tokio::test]
async fn test_progress_reporter_speculation() {
let (tx, mut rx) = mpsc::channel(32);
let mut reporter = ProgressReporter::new(tx);
reporter.report_speculation_started().await;
reporter.report_speculation("verification", 0.8).await;
reporter
.report_speculation_decision(SpeculationDecision::Accept, 0.9)
.await;
let chunk1 = rx.recv().await.unwrap();
assert!(matches!(chunk1.chunk_type, ChunkType::SpeculationStarted));
let chunk2 = rx.recv().await.unwrap();
assert!(matches!(
chunk2.chunk_type,
ChunkType::SpeculationProgress { .. }
));
let chunk3 = rx.recv().await.unwrap();
assert!(matches!(
chunk3.chunk_type,
ChunkType::SpeculationDecision(SpeculationDecision::Accept)
));
}
// The reporter emits the verification chunks in call order.
#[tokio::test]
async fn test_progress_reporter_verification() {
let (tx, mut rx) = mpsc::channel(32);
let mut reporter = ProgressReporter::new(tx);
reporter.report_verification_started().await;
reporter.report_claim(0, "Test claim").await;
reporter
.report_claim_verified(0, "Verified", "Explanation")
.await;
reporter
.report_verification_completed("Summary", 0.85, 100)
.await;
let chunk1 = rx.recv().await.unwrap();
assert!(matches!(chunk1.chunk_type, ChunkType::VerificationStarted));
let chunk2 = rx.recv().await.unwrap();
assert!(matches!(
chunk2.chunk_type,
ChunkType::ClaimExtracted { claim_id: 0 }
));
let chunk3 = rx.recv().await.unwrap();
assert!(matches!(
chunk3.chunk_type,
ChunkType::ClaimVerified { claim_id: 0, .. }
));
let chunk4 = rx.recv().await.unwrap();
assert!(matches!(
chunk4.chunk_type,
ChunkType::VerificationCompleted
));
}
// `report_error` produces an `Error` chunk.
#[tokio::test]
async fn test_progress_reporter_error() {
let (tx, mut rx) = mpsc::channel(32);
let mut reporter = ProgressReporter::new(tx);
reporter.report_error("Test error").await;
let chunk = rx.recv().await.unwrap();
assert!(matches!(chunk.chunk_type, ChunkType::Error(_)));
}
// `report_final` carries the answer text and the confidence.
#[tokio::test]
async fn test_progress_reporter_final_answer() {
let (tx, mut rx) = mpsc::channel(32);
let mut reporter = ProgressReporter::new(tx);
reporter.report_final("The answer is 42", 0.95).await;
let chunk = rx.recv().await.unwrap();
assert!(matches!(chunk.chunk_type, ChunkType::FinalAnswer));
assert_eq!(chunk.content, "The answer is 42");
assert_eq!(chunk.metadata.confidence, Some(0.95));
}
// The chunk counter grows by one per reported chunk.
#[tokio::test]
async fn test_progress_reporter_chunk_counter() {
let (tx, _rx) = mpsc::channel(32);
let mut reporter = ProgressReporter::new(tx);
assert_eq!(reporter.chunk_count(), 0);
reporter.report(ChunkType::SearchStarted, "test").await;
assert_eq!(reporter.chunk_count(), 1);
reporter.report(ChunkType::DraftGenerated, "test").await;
assert_eq!(reporter.chunk_count(), 2);
}
// Short input is returned unchanged; long input is cut and suffixed with "...".
#[tokio::test]
async fn test_truncate_content() {
assert_eq!(truncate_content("short", 10), "short");
assert_eq!(
truncate_content("a longer string that needs truncation", 10),
"a longe..."
);
assert_eq!(truncate_content("exactly10c", 10), "exactly10c");
}
// A result built from an output reports that it has one.
#[tokio::test]
async fn test_streaming_result_from_output() {
let query = Query::new("test");
let draft = crate::types::Draft::new("Test answer", "test");
let output = PipelineOutput::new(query, draft);
let result = StreamingPipelineResult::from_output(output);
assert!(result.has_final_output());
}
// `with_metadata` replaces the whole metadata record.
#[tokio::test]
async fn test_chunk_with_metadata() {
let metadata = ChunkMetadata::new(100)
.with_layer("Test")
.with_confidence(0.5)
.with_duration(50);
let chunk = PipelineChunk::new(0, ChunkType::SearchStarted, "test").with_metadata(metadata);
assert_eq!(chunk.metadata.timestamp_ms, 100);
assert_eq!(chunk.metadata.layer, Some("Test".to_string()));
assert_eq!(chunk.metadata.confidence, Some(0.5));
assert_eq!(chunk.metadata.duration_ms, Some(50));
}
// Several tasks can stream from one shared wrapper at the same time.
#[tokio::test]
async fn test_concurrent_streaming_queries() {
let mut pipeline = create_test_pipeline();
pipeline
.index(Document::new("First document content."))
.await
.unwrap();
pipeline
.index(Document::new("Second document content."))
.await
.unwrap();
let wrapper = std::sync::Arc::new(StreamingPipelineWrapper::new(pipeline));
let handles: Vec<_> = (0..3)
.map(|i| {
let wrapper = wrapper.clone();
tokio::spawn(async move {
let query = Query::new(format!("query {i}"));
let mut result = wrapper.process_streaming(query).await.unwrap();
let mut count = 0;
while let Some(_chunk) = result.next().await {
count += 1;
}
count
})
})
.collect();
for handle in handles {
let count = handle.await.unwrap();
assert!(count > 0);
}
}
// `inner` exposes the wrapped pipeline and its configuration.
#[tokio::test]
async fn test_streaming_wrapper_inner_access() {
let pipeline = create_test_pipeline();
let wrapper = StreamingPipelineWrapper::new(pipeline);
let _ = wrapper.inner();
let config = wrapper.inner().config();
assert!(!config.enable_fast_path);
}
// `inner_mut` hands out a mutable reference to the wrapped pipeline.
#[tokio::test]
async fn test_streaming_wrapper_inner_mut() {
let pipeline = create_test_pipeline();
let mut wrapper = StreamingPipelineWrapper::new(pipeline);
let _inner = wrapper.inner_mut();
}
}