use super::backpressure::{BackpressureController, FlowAction as BackpressureFlowAction};
use super::chunking::ResponseChunker;
use super::coordinator::{GlobalStreamingMetrics, StreamingCoordinator};
use super::quality_analyzer::{QualityAnalyzer, StreamingQuality};
use super::state_management::StreamStateManager;
use super::types::*;
use super::typing_simulation::TypingSimulator;
use crate::core::error::Result;
use crate::core::traits::{Model, Tokenizer};
use crate::pipeline::conversational::types::*;
use async_stream::stream;
use futures::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use trustformers_models::common_patterns::{
GenerationConfig as ModelsGenerationConfig, GenerativeModel,
};
use uuid::Uuid;
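/// Streaming pipeline that wraps a generative model and tokenizer and emits a
/// reply as a paced sequence of [`ExtendedStreamingResponse`] chunks,
/// coordinating chunking, simulated typing, stream state, backpressure, and
/// per-chunk quality analysis.
///
/// A minimal usage sketch (`my_model` and `my_tokenizer` are hypothetical
/// values of any types implementing the required traits; `StreamExt` comes
/// from the `futures` crate):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let pipeline = ConversationalStreamingPipeline::new(
///     my_model,
///     my_tokenizer,
///     AdvancedStreamingConfig::default(),
/// );
/// let mut stream = pipeline
///     .generate_streaming_response(input, &conversation_state)
///     .await?;
/// while let Some(item) = stream.next().await {
///     print!("{}", item?.base_response.chunk);
/// }
/// ```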
pub struct ConversationalStreamingPipeline<M, T>
where
M: Model + Send + Sync + GenerativeModel,
T: Tokenizer + Send + Sync,
{
model: Arc<M>,
tokenizer: Arc<T>,
coordinator: StreamingCoordinator,
chunker: ResponseChunker,
typing_simulator: TypingSimulator,
state_manager: StreamStateManager,
backpressure_controller: BackpressureController,
quality_analyzer: QualityAnalyzer,
config: AdvancedStreamingConfig,
}
impl<M, T> ConversationalStreamingPipeline<M, T>
where
M: Model + Send + Sync + GenerativeModel,
T: Tokenizer + Send + Sync,
{
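    /// Builds a pipeline from a model, tokenizer, and streaming config; each
    /// sub-component receives its own clone of the config, and the chunker
    /// defaults to [`ChunkingStrategy::Adaptive`].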
pub fn new(model: M, tokenizer: T, config: AdvancedStreamingConfig) -> Self {
let coordinator = StreamingCoordinator::new(config.clone());
let chunker = ResponseChunker::new(ChunkingStrategy::Adaptive, config.clone());
let typing_simulator = TypingSimulator::new(config.clone());
let state_manager = StreamStateManager::new(config.clone());
let backpressure_controller = BackpressureController::new(config.clone());
let quality_analyzer = QualityAnalyzer::new();
Self {
model: Arc::new(model),
tokenizer: Arc::new(tokenizer),
coordinator,
chunker,
typing_simulator,
state_manager,
backpressure_controller,
quality_analyzer,
config,
}
}
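    /// Generates a complete reply and returns it as a paced chunk stream.
    ///
    /// Flow: register a session with the coordinator, mark the stream as
    /// streaming, build a prompt from recent conversation context, generate
    /// the full response in one pass, chunk it, and hand the chunks to
    /// `create_chunk_stream` for delivery.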
pub async fn generate_streaming_response(
&self,
input: ConversationalInput,
conversation_state: &ConversationState,
) -> Result<Pin<Box<dyn Stream<Item = Result<ExtendedStreamingResponse>> + Send + '_>>> {
let session_id = self
.coordinator
.create_session(
input.conversation_id.clone().unwrap_or_else(|| Uuid::new_v4().to_string()),
)
.await?;
self.state_manager
.update_state(
StreamConnection::Streaming,
"Starting streaming response".to_string(),
)
.await?;
let context = self.build_streaming_context(&input, conversation_state)?;
let full_response = self.generate_full_response(&context).await?;
let metadata = self.analyze_response_metadata(&full_response, &input);
let chunks = self.chunker.chunk_response(&full_response, &metadata);
let stream = self.create_chunk_stream(chunks, session_id, metadata).await?;
Ok(Box::pin(stream))
}
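    /// Runs a single full generation pass over the prompt. Chunked delivery
    /// is presentational: the whole response exists before the first chunk is
    /// emitted, trading time-to-first-chunk for simpler chunking and quality
    /// analysis.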
async fn generate_full_response(&self, context: &str) -> Result<String> {
        // Encoding up front surfaces tokenizer errors early, even though
        // `generate` consumes the raw prompt string itself.
        let _tokenized = self.tokenizer.encode(context)?;
        let gen_config = ModelsGenerationConfig {
            // Rough output budget: about 20 tokens per configured chunk.
            max_new_tokens: self.config.base_config.chunk_size * 20,
            temperature: 0.7,
            top_p: 0.9,
            top_k: Some(50),
            do_sample: true,
            early_stopping: true,
            repetition_penalty: 1.1,
            length_penalty: 1.0,
            ..ModelsGenerationConfig::default()
        };
let response = self.model.generate(context, &gen_config)?;
Ok(self.clean_response(&response))
}
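    /// Strips end-of-text markers left by the model and appends a period when
    /// the reply does not already end in terminal punctuation.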
fn clean_response(&self, response: &str) -> String {
let mut cleaned = response.trim().to_string();
cleaned = cleaned.replace("<|endoftext|>", "");
cleaned = cleaned.replace("<|end|>", "");
if !cleaned.ends_with(['.', '!', '?']) && !cleaned.is_empty() {
cleaned.push('.');
}
cleaned
}
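    /// Renders recent turns as `Role: content` lines followed by the new user
    /// message and a trailing `Assistant:` cue. The `2000` passed to
    /// `get_recent_context` is a context budget whose unit is defined by that
    /// method.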
fn build_streaming_context(
&self,
input: &ConversationalInput,
conversation_state: &ConversationState,
) -> Result<String> {
let mut context = String::new();
let recent_turns = conversation_state.get_recent_context(2000);
for turn in recent_turns {
let role_str = match turn.role {
ConversationRole::User => "User",
ConversationRole::Assistant => "Assistant",
ConversationRole::System => "System",
};
context.push_str(&format!("{}: {}\n", role_str, turn.content));
}
context.push_str(&format!("User: {}\nAssistant:", input.message));
Ok(context)
}
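    /// Placeholder analysis: returns fixed neutral sentiment, a generic topic,
    /// and constant confidence/quality scores instead of inspecting the
    /// response text.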
fn analyze_response_metadata(
&self,
response: &str,
_input: &ConversationalInput,
) -> ConversationMetadata {
ConversationMetadata {
sentiment: Some("neutral".to_string()),
intent: Some("response".to_string()),
confidence: 0.8,
topics: vec!["conversation".to_string()],
safety_flags: Vec::new(),
entities: Vec::new(),
quality_score: 0.8,
engagement_level: EngagementLevel::Medium,
reasoning_type: None,
}
}
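    /// Wraps the pre-computed chunks in an async stream. Per chunk it sleeps
    /// for a simulated typing delay, consults the backpressure controller
    /// (pausing or slowing on its advice), measures quality, assembles
    /// metrics, yields the response, then sleeps for any natural pause. The
    /// coordinator session is closed after the last chunk.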
async fn create_chunk_stream(
&self,
chunks: Vec<StreamChunk>,
session_id: String,
metadata: ConversationMetadata,
) -> Result<impl Stream<Item = Result<ExtendedStreamingResponse>> + Send + '_> {
let total_chunks = chunks.len();
let stream = stream! {
            let mut chunk_index = 0;
            // Running total of bytes emitted across all chunks so far.
            let mut bytes_streamed: usize = 0;
            let start_time = Instant::now();
for chunk in chunks {
let typing_delay = self.typing_simulator.calculate_typing_delay(&chunk);
let natural_pause = self.typing_simulator.calculate_natural_pause(&chunk);
sleep(typing_delay).await;
                // Approximate buffer occupancy: assumes ~50 bytes per chunk
                // already emitted rather than measuring a real buffer.
                let buffer_state = BufferState {
                    current_size: chunk_index * 50,
                    max_size: self.config.max_buffer_size,
                    utilization: (chunk_index * 50) as f32 / self.config.max_buffer_size as f32,
                    pending_chunks: total_chunks - chunk_index,
                };
                let enhanced_buffer_state =
                    super::backpressure::EnhancedBufferState::from(buffer_state.clone());
                if let Ok(actions) = self
                    .backpressure_controller
                    .monitor_and_adjust(&enhanced_buffer_state)
                    .await
                {
for action in actions {
match action {
BackpressureFlowAction::PauseFlow => {
sleep(Duration::from_millis(100)).await;
},
BackpressureFlowAction::DecreaseRate(_) => {
sleep(Duration::from_millis(50)).await;
},
_ => {},
}
}
}
                let quality_measurement = self
                    .quality_analyzer
                    .analyze_chunk_quality(&chunk, typing_delay)
                    .await;
                bytes_streamed += chunk.content.len();
                let elapsed = start_time.elapsed();
                let metrics = StreamingMetrics {
                    // Fractional seconds keep sub-second streams from
                    // reporting a rate of 0.
                    chunks_per_second: if elapsed.as_secs_f32() > 0.0 {
                        (chunk_index + 1) as f32 / elapsed.as_secs_f32()
                    } else {
                        0.0
                    },
                    // Size of the current chunk; a running average is not
                    // tracked here.
                    avg_chunk_size: chunk.content.len() as f32,
                    total_chunks: chunk_index + 1,
                    bytes_streamed,
                    duration_ms: elapsed.as_millis() as u64,
                    buffer_utilization: buffer_state.utilization,
                    error_count: 0,
                    retry_count: 0,
                };
let quality = super::types::StreamingQuality {
smoothness: quality_measurement.smoothness,
naturalness: quality_measurement.naturalness,
responsiveness: quality_measurement.responsiveness,
coherence: quality_measurement.coherence,
overall_quality: (quality_measurement.smoothness +
quality_measurement.naturalness +
quality_measurement.responsiveness +
quality_measurement.coherence) / 4.0,
};
let extended_response = ExtendedStreamingResponse {
base_response: StreamingResponse {
chunk: chunk.content.clone(),
is_final: chunk_index == total_chunks - 1,
chunk_index,
total_chunks: Some(total_chunks),
metadata: Some(metadata.clone()),
},
state: if chunk_index == total_chunks - 1 {
StreamingState::Completed
} else {
StreamingState::Streaming
},
timestamp: chrono::Utc::now(),
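                    // Estimate assumes the remaining chunks take the same
                    // typing delay as the current one.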
estimated_completion: if chunk_index < total_chunks - 1 {
let remaining_chunks = total_chunks - chunk_index - 1;
let estimated_remaining_ms = remaining_chunks as u64 * typing_delay.as_millis() as u64;
Some(chrono::Utc::now() + chrono::Duration::milliseconds(estimated_remaining_ms as i64))
} else {
None
},
metrics,
quality,
};
yield Ok(extended_response);
if !natural_pause.is_zero() {
sleep(natural_pause).await;
}
chunk_index += 1;
}
let _ = self.coordinator.close_session(&session_id).await;
};
Ok(stream)
}
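    /// Snapshot of the coordinator's aggregate metrics across all sessions.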
pub async fn get_streaming_stats(&self) -> Result<GlobalStreamingMetrics> {
Ok(self.coordinator.get_global_metrics().await)
}
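    /// Overall streaming quality as judged by the quality analyzer.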
pub async fn get_quality_metrics(&self) -> Result<StreamingQuality> {
Ok(self.quality_analyzer.calculate_overall_quality().await)
}
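    /// Removes sessions that have exceeded `max_age_minutes`, returning the
    /// number cleaned up.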
pub async fn cleanup_sessions(&self, max_age_minutes: u64) -> Result<usize> {
Ok(self.coordinator.cleanup_expired_sessions(max_age_minutes).await)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::pipeline::conversational::streaming::chunking::ResponseChunker;
use crate::pipeline::conversational::streaming::types::{
AdvancedStreamingConfig, ChunkingStrategy, StreamingQuality,
};
    use crate::pipeline::conversational::types::{ConversationMetadata, EngagementLevel};
fn default_adv_config() -> AdvancedStreamingConfig {
AdvancedStreamingConfig::default()
}
fn empty_metadata() -> ConversationMetadata {
ConversationMetadata {
sentiment: None,
intent: None,
confidence: 0.0,
topics: vec![],
safety_flags: vec![],
entities: vec![],
quality_score: 0.0,
engagement_level: EngagementLevel::Medium,
reasoning_type: None,
}
}
#[test]
fn test_advanced_config_default_values() {
let config = default_adv_config();
assert!(
config.adaptive_chunking,
"adaptive_chunking should be true by default"
);
assert!(
config.natural_pausing,
"natural_pausing should be true by default"
);
assert!(
config.enable_backpressure,
"enable_backpressure should be true by default"
);
assert!(
config.enable_error_recovery,
"enable_error_recovery should be true by default"
);
}
#[test]
fn test_advanced_config_max_buffer_size_positive() {
let config = default_adv_config();
assert!(
config.max_buffer_size > 0,
"max_buffer_size must be positive"
);
}
#[test]
fn test_advanced_config_typing_speed_positive() {
let config = default_adv_config();
assert!(
config.base_typing_speed > 0.0,
"base_typing_speed must be positive"
);
}
#[test]
fn test_chunker_fixed_size_produces_chunks() {
let chunker = ResponseChunker::new(ChunkingStrategy::FixedSize(5), default_adv_config());
let chunks = chunker.chunk_response("Hello World!", &empty_metadata());
assert!(
!chunks.is_empty(),
"chunker must produce at least one chunk"
);
}
#[test]
fn test_chunker_word_boundary_preserves_content() {
let chunker = ResponseChunker::new(ChunkingStrategy::WordBoundary, default_adv_config());
let text = "The quick brown fox";
let chunks = chunker.chunk_response(text, &empty_metadata());
let reassembled: String = chunks
.iter()
.map(|c| c.content.clone())
.collect::<Vec<_>>()
.join(" ")
.trim()
.to_string();
for word in ["The", "quick", "brown", "fox"] {
assert!(
reassembled.contains(word),
"reassembled text must contain original word '{}'",
word
);
}
}
#[test]
fn test_chunker_sentence_boundary_splits_on_period() {
let chunker =
ResponseChunker::new(ChunkingStrategy::SentenceBoundary, default_adv_config());
let text = "First sentence. Second sentence. Third sentence.";
let chunks = chunker.chunk_response(text, &empty_metadata());
assert!(
!chunks.is_empty(),
"sentence boundary chunker must produce chunks"
);
}
#[test]
fn test_chunker_adaptive_produces_nonempty_chunks_for_nonempty_text() {
let chunker = ResponseChunker::new(ChunkingStrategy::Adaptive, default_adv_config());
let text = "This is a test of the adaptive chunking strategy.";
let chunks = chunker.chunk_response(text, &empty_metadata());
assert!(
!chunks.is_empty(),
"adaptive chunker must produce chunks for nonempty text"
);
}
#[test]
fn test_chunker_fixed_size_chunks_have_sequential_indices() {
let chunker = ResponseChunker::new(ChunkingStrategy::FixedSize(3), default_adv_config());
let chunks = chunker.chunk_response("abcdefghij", &empty_metadata());
for (i, chunk) in chunks.iter().enumerate() {
assert_eq!(
chunk.index, i,
"chunk index must be sequential starting from 0"
);
}
}
#[test]
fn test_chunker_empty_text_produces_no_chunks() {
let chunker = ResponseChunker::new(ChunkingStrategy::FixedSize(10), default_adv_config());
let chunks = chunker.chunk_response("", &empty_metadata());
assert!(chunks.is_empty(), "empty text should yield no chunks");
}
#[test]
fn test_streaming_quality_default_all_ones() {
let quality = StreamingQuality::default();
assert!(
(quality.smoothness - 1.0).abs() < f32::EPSILON,
"default smoothness must be 1.0"
);
assert!(
(quality.naturalness - 1.0).abs() < f32::EPSILON,
"default naturalness must be 1.0"
);
assert!(
(quality.responsiveness - 1.0).abs() < f32::EPSILON,
"default responsiveness must be 1.0"
);
assert!(
(quality.coherence - 1.0).abs() < f32::EPSILON,
"default coherence must be 1.0"
);
assert!(
(quality.overall_quality - 1.0).abs() < f32::EPSILON,
"default overall_quality must be 1.0"
);
}
#[tokio::test]
async fn test_pipeline_coordinator_session_lifecycle() {
let coord = StreamingCoordinator::new(default_adv_config());
assert_eq!(
coord.get_active_session_count().await,
0,
"coordinator should start with 0 sessions"
);
let id = coord
.create_session("conv-test".to_string())
.await
.expect("create_session must succeed");
assert_eq!(
coord.get_active_session_count().await,
1,
"one session must be active after creation"
);
coord.close_session(&id).await.expect("close_session must succeed");
assert_eq!(
coord.get_active_session_count().await,
0,
"count must be 0 after close"
);
}
#[tokio::test]
async fn test_pipeline_coordinator_cleanup_empty_returns_zero() {
let coord = StreamingCoordinator::new(default_adv_config());
let removed = coord.cleanup_expired_sessions(1).await;
assert_eq!(removed, 0, "cleanup on empty coordinator must return 0");
}
#[tokio::test]
async fn test_pipeline_get_global_metrics_initial() {
let coord = StreamingCoordinator::new(default_adv_config());
let metrics = coord.get_global_metrics().await;
assert_eq!(
metrics.active_streams, 0,
"initial active_streams must be 0"
);
assert_eq!(
metrics.total_streams_created, 0,
"initial total_streams_created must be 0"
);
}
#[tokio::test]
async fn test_pipeline_metrics_update_after_session_create() {
let coord = StreamingCoordinator::new(default_adv_config());
let _id = coord
.create_session("c1".to_string())
.await
.expect("create_session must succeed");
let metrics = coord.get_global_metrics().await;
assert_eq!(
metrics.total_streams_created, 1,
"total_streams_created must be 1 after one session"
);
}
#[tokio::test]
async fn test_pipeline_sessions_by_conversation() {
let coord = StreamingCoordinator::new(default_adv_config());
let _id1 = coord
.create_session("conv-a".to_string())
.await
.expect("session 1 must succeed");
let _id2 = coord
.create_session("conv-a".to_string())
.await
.expect("session 2 must succeed");
let sessions = coord.get_sessions_by_conversation("conv-a").await;
assert_eq!(sessions.len(), 2, "two sessions for conv-a must be found");
}
}