use super::{
analysis::{types::EnhancedAnalysisConfig, ConversationAnalyzer},
memory::MemoryManager,
safety::SafetyFilter,
summarization::ContextSummarizer,
types::*,
};
use crate::core::traits::{Model, Tokenizer};
use crate::error::{Result, TrustformersError};
use crate::pipeline::{BasePipeline, Pipeline};
use async_stream;
use async_trait::async_trait;
use futures::Stream;
use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::RwLock;
use trustformers_core::generation::GenerationStrategy;
use trustformers_models::common_patterns::{
GenerationConfig as ModelsGenerationConfig, GenerativeModel,
};
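/// A multi-turn conversational pipeline that wraps a generative model and
/// tokenizer with per-conversation state, optional safety filtering, context
/// summarization, long-term conversation memory, and turn-level analysis.
///
/// Conversation states are keyed by ID inside an `Arc<RwLock<HashMap<..>>>`,
/// so a single pipeline can be shared across tasks.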
pub struct ConversationalPipeline<M, T> {
base: BasePipeline<M, T>,
config: ConversationalConfig,
conversations: Arc<RwLock<HashMap<String, ConversationState>>>,
safety_filter: Option<SafetyFilter>,
context_summarizer: Option<Arc<RwLock<ContextSummarizer>>>,
memory_manager: MemoryManager,
conversation_analyzer: ConversationAnalyzer,
}
impl<M, T> ConversationalPipeline<M, T>
where
M: Model + Send + Sync + GenerativeModel,
T: Tokenizer + Send + Sync,
{
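    /// Creates a pipeline with the default `ConversationalConfig`: safety
    /// filtering enabled and a context summarizer built from the default
    /// summarization settings.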
pub fn new(model: M, tokenizer: T) -> Result<Self> {
let config = ConversationalConfig::default();
Ok(Self {
base: BasePipeline::new(model, tokenizer),
conversations: Arc::new(RwLock::new(HashMap::new())),
safety_filter: Some(SafetyFilter::new()),
context_summarizer: Some(Arc::new(RwLock::new(ContextSummarizer::new(
config.summarization_config.clone(),
)))),
memory_manager: MemoryManager::new(config.memory_config.clone()),
conversation_analyzer: ConversationAnalyzer::new(EnhancedAnalysisConfig::default()),
config,
})
}
    pub fn with_config(mut self, config: ConversationalConfig) -> Self {
        self.context_summarizer = Some(Arc::new(RwLock::new(ContextSummarizer::new(
            config.summarization_config.clone(),
        ))));
        self.memory_manager = MemoryManager::new(config.memory_config.clone());
        // Honor the config's safety flag: disable the filter when the flag is
        // off; an already-installed custom filter is kept when it is on.
        if !config.enable_safety_filter {
            self.safety_filter = None;
        }
        self.config = config;
        self
    }
pub fn with_safety_filter(mut self, enable: bool) -> Self {
if enable {
self.safety_filter = Some(SafetyFilter::new());
} else {
self.safety_filter = None;
}
self
}
pub fn with_custom_safety_filter(mut self, filter: SafetyFilter) -> Self {
self.safety_filter = Some(filter);
self
}
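    /// Processes a single user message end to end: applies the safety filter,
    /// analyzes the turn, records memories, summarizes and trims history when
    /// thresholds are exceeded, generates and filters a response, and persists
    /// the updated conversation state.
    ///
    /// A minimal usage sketch (the crate and module paths are assumptions, and
    /// the model name is only illustrative):
    ///
    /// ```no_run
    /// # async fn demo() -> trustformers::error::Result<()> {
    /// use trustformers::pipeline::conversational::{
    ///     conversational_pipeline, ConversationalInput,
    /// };
    /// let pipeline = conversational_pipeline(Some("microsoft/DialoGPT-medium"), None).await?;
    /// let output = pipeline
    ///     .process_conversation(ConversationalInput {
    ///         message: "Hello!".to_string(),
    ///         conversation_id: None, // a fresh UUID is assigned
    ///         context: None,
    ///         config_override: None,
    ///     })
    ///     .await?;
    /// println!("{}", output.response);
    /// # Ok(())
    /// # }
    /// ```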
pub async fn process_conversation(
&self,
input: ConversationalInput,
) -> Result<ConversationalOutput> {
let start_time = std::time::Instant::now();
        debug!(
            "Starting conversation processing for input of length {}",
            input.message.len()
        );
let conversation_id = input
.conversation_id
.as_ref()
.cloned()
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
debug!("Processing conversation ID: {}", conversation_id);
let mut conversations = self.conversations.write().await;
let mut state = conversations
.remove(&conversation_id)
.unwrap_or_else(|| ConversationState::new(conversation_id.clone()));
let config = input.config_override.as_ref().unwrap_or(&self.config);
if let Some(filter) = &self.safety_filter {
if !filter.is_safe(&input.message) {
warn!(
"Safety filter triggered for conversation {}",
conversation_id
);
let safe_response =
"I can't assist with that request. Let's talk about something else."
.to_string();
let metadata = ConversationMetadata {
sentiment: Some("neutral".to_string()),
intent: Some("safety_filtered".to_string()),
confidence: 1.0,
topics: vec!["safety".to_string()],
safety_flags: vec!["inappropriate_content".to_string()],
entities: Vec::new(),
quality_score: 0.5,
engagement_level: EngagementLevel::Low,
reasoning_type: None,
};
                let generation_stats = GenerationStats {
                    generation_time_ms: start_time.elapsed().as_secs_f64() * 1000.0,
                    // Rough heuristic: ~4 characters per token.
                    tokens_generated: safe_response.len() / 4,
                    tokens_per_second: 0.0,
                    confidence: 1.0,
                    truncated: false,
                };
                // Re-insert the untouched state so an existing conversation is
                // not silently dropped when the safety filter fires.
                conversations.insert(conversation_id.clone(), state.clone());
                return Ok(ConversationalOutput {
                    response: safe_response,
                    conversation_id,
                    conversation_state: state,
                    metadata,
                    generation_stats,
                });
}
}
let user_metadata = self
.conversation_analyzer
.analyze_turn(&ConversationTurn {
role: ConversationRole::User,
content: input.message.clone(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 0,
})
.await?;
let user_conversation_metadata = ConversationMetadata {
sentiment: user_metadata.sentiment.clone(),
intent: user_metadata.intent.clone(),
confidence: user_metadata.confidence,
topics: user_metadata.topics.clone(),
            safety_flags: vec![],
            entities: vec![],
            quality_score: user_metadata.quality_score,
            engagement_level: user_metadata.engagement_level.clone(),
            reasoning_type: None,
        };
let user_turn = ConversationTurn {
role: ConversationRole::User,
content: input.message.clone(),
timestamp: chrono::Utc::now(),
metadata: Some(user_conversation_metadata.clone()),
token_count: self.estimate_token_count(&input.message),
};
state.add_turn(user_turn.clone());
if let Some(memory) = self.memory_manager.create_memory(&user_turn) {
state.add_memory(memory);
}
if config.summarization_config.enabled
&& state.total_tokens > config.summarization_config.trigger_threshold
{
self.summarize_conversation(&mut state, config).await?;
}
state.trim_history(config.max_history_turns, config.max_context_tokens);
let context = self.build_enhanced_context(&state, config, &input.message)?;
        if state.needs_repair() && config.repair_config.enabled {
            let output = self.attempt_conversation_repair(&mut state, &input, config).await?;
            // Persist the repaired state; the early return must not drop it.
            conversations.insert(conversation_id.clone(), state);
            return Ok(output);
        }
debug!("Generating response with context length: {}", context.len());
let response = self.generate_response_with_model(&context, config).await?;
info!("Generated response of length: {}", response.len());
let filtered_response = if let Some(filter) = &self.safety_filter {
if filter.is_safe(&response) {
response
} else {
warn!(
"Response safety filter triggered for conversation {}",
conversation_id
);
"I apologize, but I can't provide that response. Let me try a different approach."
.to_string()
}
} else {
response
};
let response_metadata = self
.conversation_analyzer
.analyze_turn(&ConversationTurn {
role: ConversationRole::Assistant,
content: filtered_response.clone(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 0,
})
.await?;
let response_conversation_metadata = ConversationMetadata {
sentiment: response_metadata.sentiment.clone(),
intent: response_metadata.intent.clone(),
confidence: response_metadata.confidence,
topics: response_metadata.topics.clone(),
            safety_flags: vec![],
            entities: vec![],
            quality_score: response_metadata.quality_score,
            engagement_level: response_metadata.engagement_level.clone(),
            reasoning_type: None,
        };
let assistant_turn = ConversationTurn {
role: ConversationRole::Assistant,
content: filtered_response.clone(),
timestamp: chrono::Utc::now(),
metadata: Some(response_conversation_metadata.clone()),
token_count: self.estimate_token_count(&filtered_response),
};
state.add_turn(assistant_turn.clone());
if let Some(memory) = self.memory_manager.create_memory(&assistant_turn) {
state.add_memory(memory);
}
self.update_conversation_health(
&mut state,
&user_conversation_metadata,
&response_conversation_metadata,
);
self.memory_manager.decay_memories(&mut state.memories);
conversations.insert(conversation_id.clone(), state.clone());
debug!("Updated conversation state for ID: {}", conversation_id);
        let generation_time = start_time.elapsed().as_secs_f64() * 1000.0;
let tokens_generated = self.estimate_token_count(&filtered_response);
let tokens_per_second = if generation_time > 0.0 {
(tokens_generated as f64) / (generation_time / 1000.0)
} else {
0.0
};
let generation_stats = GenerationStats {
generation_time_ms: generation_time,
tokens_generated,
tokens_per_second,
confidence: response_metadata.confidence,
            // Heuristic: assume ~4 characters per token when checking for truncation.
            truncated: filtered_response.len() >= config.max_response_tokens * 4,
        };
let output = ConversationalOutput {
response: filtered_response,
conversation_id: conversation_id.clone(),
conversation_state: state,
metadata: response_conversation_metadata,
generation_stats,
};
        info!(
            "Completed conversation processing for ID: {} in {:.2}ms",
            conversation_id, generation_time
        );
Ok(output)
}
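    /// Builds the generation prompt from, in order: the system prompt, the
    /// persona, the prior-conversation summary, up to three relevant memories,
    /// as many recent turns as fit the remaining token budget, and a
    /// mode-specific instruction, ending with an "Assistant:" cue.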
fn build_enhanced_context(
&self,
state: &ConversationState,
config: &ConversationalConfig,
current_input: &str,
) -> Result<String> {
let mut context = String::new();
if let Some(system_prompt) = &config.system_prompt {
context.push_str(&format!("System: {}\n\n", system_prompt));
}
if let Some(persona) = &config.persona {
context.push_str(&format!(
"You are {}. {}\n\nBackground: {}\n\nSpeaking style: {}\n\n",
persona.name, persona.personality, persona.background, persona.speaking_style
));
}
if let Some(summary) = &state.context_summary {
context.push_str(&format!("Previous conversation summary:\n{}\n\n", summary));
}
let relevant_memories = state.get_relevant_memories(current_input, 3);
if !relevant_memories.is_empty() {
context.push_str("Relevant context from previous conversations:\n");
for memory in relevant_memories {
context.push_str(&format!("- {}\n", memory.content));
}
context.push('\n');
}
        // `context.len()` is a byte count used as a rough token proxy here;
        // saturating subtraction avoids an underflow panic when the preamble is long.
        let recent_turns =
            state.get_recent_context(config.max_context_tokens.saturating_sub(context.len()));
for turn in recent_turns {
let role_str = match turn.role {
ConversationRole::User => "User",
ConversationRole::Assistant => "Assistant",
ConversationRole::System => "System",
};
context.push_str(&format!("{}> {}\n", role_str, turn.content));
}
match config.conversation_mode {
ConversationMode::Chat => {
context.push_str("\nContinue the conversation naturally and helpfully.\n");
},
ConversationMode::Assistant => {
context.push_str("\nProvide helpful assistance with the user's request.\n");
},
ConversationMode::InstructionFollowing => {
context.push_str("\nFollow the user's instructions carefully and accurately.\n");
},
ConversationMode::QuestionAnswering => {
context.push_str("\nAnswer the user's question accurately and concisely.\n");
},
ConversationMode::RolePlay => {
context
.push_str("\nStay in character and respond appropriately to the scenario.\n");
},
ConversationMode::Educational => {
context.push_str(
"\nProvide educational and informative responses to help the user learn.\n",
);
},
}
context.push_str("\nAssistant:");
Ok(context)
}
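    /// Maps the pipeline config onto a generation strategy (top-k if set, else
    /// top-p if positive, else plain sampling, with mode-specific overrides),
    /// converts it to the models crate's `GenerationConfig`, runs the model on
    /// the prompt, and cleans the raw output.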
async fn generate_response_with_model(
&self,
context: &str,
config: &ConversationalConfig,
) -> Result<String> {
        // Validate that the context tokenizes cleanly; the token ids themselves
        // are unused because `generate` consumes the raw string below.
        let _ = (*self.base.tokenizer).encode(context)?;
let mut gen_config = config.generation_config.clone();
if let Some(top_k) = config.top_k {
gen_config.strategy = GenerationStrategy::TopK {
k: top_k,
temperature: config.temperature,
};
} else if config.top_p > 0.0 {
gen_config.strategy = GenerationStrategy::TopP {
p: config.top_p,
temperature: config.temperature,
};
} else {
gen_config.strategy = GenerationStrategy::Sampling {
temperature: config.temperature,
};
}
gen_config.max_length = Some(config.max_response_tokens);
gen_config.do_sample = true;
gen_config.early_stopping = true;
match config.conversation_mode {
ConversationMode::Educational => {
gen_config.repetition_penalty = 1.1;
gen_config.length_penalty = 1.0;
},
ConversationMode::QuestionAnswering => {
gen_config.strategy = GenerationStrategy::TopP {
p: 0.8,
temperature: 0.3,
                };
            },
ConversationMode::RolePlay => {
                gen_config.strategy = GenerationStrategy::Sampling { temperature: 0.8 };
                gen_config.repetition_penalty = 1.2;
},
_ => {},
}
let models_config = ModelsGenerationConfig {
max_new_tokens: gen_config.max_length.unwrap_or(512),
temperature: match gen_config.strategy {
GenerationStrategy::Sampling { temperature } => temperature,
GenerationStrategy::TopK { temperature, .. } => temperature,
GenerationStrategy::TopP { temperature, .. } => temperature,
_ => 1.0,
},
top_p: match gen_config.strategy {
GenerationStrategy::TopP { p, .. } => p,
_ => 0.9,
},
top_k: match gen_config.strategy {
GenerationStrategy::TopK { k, .. } => Some(k),
_ => None,
},
repetition_penalty: gen_config.repetition_penalty,
length_penalty: gen_config.length_penalty,
do_sample: gen_config.do_sample,
early_stopping: gen_config.early_stopping,
..ModelsGenerationConfig::default()
};
let response = (*self.base.model).generate(context, &models_config)?;
let cleaned_response = self.clean_generated_response(&response);
Ok(cleaned_response)
}
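    /// Strips common end-of-text markers, collapses doubled newlines, ensures
    /// terminal punctuation, and caps the response at roughly 2000 bytes.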
fn clean_generated_response(&self, response: &str) -> String {
let mut cleaned = response.trim().to_string();
cleaned = cleaned.replace("<|endoftext|>", "");
cleaned = cleaned.replace("<|end|>", "");
cleaned = cleaned.replace("\n\n", "\n");
if !cleaned.ends_with(['.', '!', '?']) && !cleaned.is_empty() {
cleaned.push('.');
}
        if cleaned.len() > 2000 {
            // Truncate on a char boundary so multi-byte text cannot cause a panic.
            let cut = (0..=1997).rev().find(|&i| cleaned.is_char_boundary(i)).unwrap_or(0);
            cleaned.truncate(cut);
            cleaned.push_str("...");
        }
cleaned
}
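    /// Summarizes all but the most recent `preserve_recent_turns` turns into
    /// `state.context_summary`, drops the summarized turns, and recomputes the
    /// running token total.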
async fn summarize_conversation(
&self,
state: &mut ConversationState,
config: &ConversationalConfig,
) -> Result<()> {
if let Some(summarizer_arc) = &self.context_summarizer {
let turns_to_summarize = &state.turns[..state
.turns
.len()
.saturating_sub(config.summarization_config.preserve_recent_turns)];
if !turns_to_summarize.is_empty() {
let summary = summarizer_arc.write().await.summarize_context(turns_to_summarize)?;
state.context_summary = Some(summary);
state.turns = state.turns.split_off(turns_to_summarize.len());
state.total_tokens = state.turns.iter().map(|t| t.token_count).sum();
}
}
Ok(())
}
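    /// Returns a canned repair response chosen from the first configured repair
    /// strategy; once `max_repair_attempts` is exceeded, the conversation is
    /// cleared and its health reset.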
async fn attempt_conversation_repair(
&self,
state: &mut ConversationState,
        // The triggering input is currently unused by the repair strategies.
        _input: &ConversationalInput,
config: &ConversationalConfig,
) -> Result<ConversationalOutput> {
state.health.repair_attempts += 1;
let repair_response = if state.health.repair_attempts
<= config.repair_config.max_repair_attempts
{
match config.repair_config.repair_strategies.first() {
Some(RepairStrategy::Clarification) => {
"I want to make sure I understand you correctly. Could you help me by rephrasing or providing more context?".to_string()
},
Some(RepairStrategy::Rephrase) => {
"Let me try a different approach. What specific aspect would you like me to focus on?".to_string()
},
Some(RepairStrategy::Redirect) => {
"Perhaps we could explore this from a different angle. What's most important to you about this topic?".to_string()
},
Some(RepairStrategy::Reset) => {
"Let's start fresh. What can I help you with today?".to_string()
},
None => "I'm having trouble following our conversation. Could you help me understand what you'd like to discuss?".to_string(),
}
} else {
state.turns.clear();
state.context_summary = None;
state.health.repair_attempts = 0;
state.health.overall_score = 1.0;
"I think we should start our conversation over. How can I help you today?".to_string()
};
let metadata = ConversationMetadata {
sentiment: Some("helpful".to_string()),
intent: Some("repair".to_string()),
confidence: 0.9,
topics: vec!["conversation_repair".to_string()],
safety_flags: vec![],
entities: vec![],
quality_score: 0.8,
engagement_level: EngagementLevel::Medium,
reasoning_type: Some(ReasoningType::Logical),
};
let generation_stats = GenerationStats {
generation_time_ms: 0.0,
tokens_generated: repair_response.len() / 4,
tokens_per_second: 0.0,
confidence: 0.9,
truncated: false,
};
Ok(ConversationalOutput {
response: repair_response,
conversation_id: state.conversation_id.clone(),
conversation_state: state.clone(),
metadata,
generation_stats,
})
}
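    /// Scores coherence from topic overlap between the user and assistant
    /// turns, engagement from the higher of the two engagement levels, and
    /// safety from the absence of safety flags, then folds the three into the
    /// state's health.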
fn update_conversation_health(
&self,
state: &mut ConversationState,
user_metadata: &ConversationMetadata,
response_metadata: &ConversationMetadata,
) {
let coherence = if user_metadata.topics.iter().any(|t| response_metadata.topics.contains(t))
{
0.9
} else {
0.6
};
let engagement = match (
&user_metadata.engagement_level,
&response_metadata.engagement_level,
) {
(EngagementLevel::VeryHigh, _) | (_, EngagementLevel::VeryHigh) => 0.9,
(EngagementLevel::High, _) | (_, EngagementLevel::High) => 0.8,
(EngagementLevel::Medium, _) | (_, EngagementLevel::Medium) => 0.6,
_ => 0.4,
};
let safety =
if user_metadata.safety_flags.is_empty() && response_metadata.safety_flags.is_empty() {
1.0
} else {
0.3
};
state.update_health(coherence, engagement, safety);
}
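    /// Streams the response as word chunks separated by a configurable typing
    /// delay. Note that the full response is generated up front and then
    /// chunked, so this simulates streaming rather than emitting tokens as the
    /// model produces them.
    ///
    /// A consumption sketch (crate and module paths are assumptions):
    ///
    /// ```no_run
    /// # use futures::StreamExt;
    /// # use trustformers::pipeline::conversational::*;
    /// # async fn demo(
    /// #     pipeline: ConversationalPipeline<trustformers::AutoModel, trustformers::AutoTokenizer>,
    /// #     input: ConversationalInput,
    /// # ) -> trustformers::error::Result<()> {
    /// let mut stream = pipeline.generate_streaming_response(input).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// # Ok(())
    /// # }
    /// ```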
pub async fn generate_streaming_response(
&self,
input: ConversationalInput,
) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send + '_>>> {
let conversation_id =
input.conversation_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let conversations = self.conversations.read().await;
let state = conversations
.get(&conversation_id)
.cloned()
.unwrap_or_else(|| ConversationState::new(conversation_id.clone()));
let config = input.config_override.as_ref().unwrap_or(&self.config).clone();
let context = self.build_enhanced_context(&state, &config, &input.message)?;
let tokenizer = self.base.tokenizer.clone();
let model = self.base.model.clone();
let chunk_size = config.streaming_config.chunk_size;
let typing_delay = config.streaming_config.typing_delay_ms;
let stream = async_stream::stream! {
            // Validate that the prompt tokenizes before generating; the ids
            // themselves are unused since `generate` takes the raw string.
            if let Err(e) = (*tokenizer).encode(&context) {
                yield Err(crate::error::TrustformersError::from(e));
                return;
            }
let mut gen_config = config.generation_config.clone();
gen_config.strategy = GenerationStrategy::Sampling { temperature: config.temperature };
gen_config.max_length = Some(config.max_response_tokens);
gen_config.do_sample = true;
let models_config = ModelsGenerationConfig {
max_new_tokens: gen_config.max_length.unwrap_or(512),
temperature: match gen_config.strategy {
GenerationStrategy::Sampling { temperature } => temperature,
GenerationStrategy::TopK { temperature, .. } => temperature,
GenerationStrategy::TopP { temperature, .. } => temperature,
_ => 1.0,
},
top_p: match gen_config.strategy {
GenerationStrategy::TopP { p, .. } => p,
_ => 0.9,
},
top_k: match gen_config.strategy {
GenerationStrategy::TopK { k, .. } => Some(k),
_ => None,
},
repetition_penalty: gen_config.repetition_penalty,
length_penalty: gen_config.length_penalty,
do_sample: gen_config.do_sample,
early_stopping: gen_config.early_stopping,
..ModelsGenerationConfig::default()
};
let full_response = match (*model).generate(&context, &models_config) {
Ok(response) => response,
Err(e) => {
yield Err(crate::error::TrustformersError::from(e));
return;
}
};
let words: Vec<&str> = full_response.split_whitespace().collect();
for chunk in words.chunks(chunk_size) {
let chunk_text = chunk.join(" ") + " ";
yield Ok(chunk_text);
tokio::time::sleep(tokio::time::Duration::from_millis(typing_delay)).await;
}
};
Ok(Box::pin(stream))
}
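    /// Estimates a token count by actually tokenizing the text, falling back to
    /// a rough four-characters-per-token heuristic if encoding fails.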
fn estimate_token_count(&self, text: &str) -> usize {
match (*self.base.tokenizer).encode(text) {
Ok(tokenized) => tokenized.input_ids.len(),
            // Fallback heuristic: ~4 characters per token.
            Err(_) => text.len() / 4,
        }
}
pub async fn get_conversation(&self, conversation_id: &str) -> Option<ConversationState> {
self.conversations.read().await.get(conversation_id).cloned()
}
pub async fn delete_conversation(&self, conversation_id: &str) -> bool {
self.conversations.write().await.remove(conversation_id).is_some()
}
pub async fn list_conversations(&self) -> Vec<String> {
self.conversations.read().await.keys().cloned().collect()
}
pub async fn clear_all_conversations(&self) {
self.conversations.write().await.clear();
}
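    /// Serializes a conversation state to pretty-printed JSON, or errors if the
    /// ID is unknown.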
pub async fn export_conversation(&self, conversation_id: &str) -> Result<String> {
if let Some(state) = self.get_conversation(conversation_id).await {
serde_json::to_string_pretty(&state)
.map_err(|e| TrustformersError::runtime_error(e.to_string()))
} else {
Err(TrustformersError::invalid_input_simple(format!(
"Conversation not found: {}",
conversation_id
)))
}
}
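    /// Deserializes a conversation state from JSON and registers it under its
    /// embedded conversation ID, which is returned.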
pub async fn import_conversation(&self, json: &str) -> Result<String> {
let state: ConversationState = serde_json::from_str(json)
.map_err(|e| TrustformersError::runtime_error(e.to_string()))?;
let conversation_id = state.conversation_id.clone();
self.conversations.write().await.insert(conversation_id.clone(), state);
Ok(conversation_id)
}
pub async fn get_conversation_stats(&self, conversation_id: &str) -> Option<ConversationStats> {
self.get_conversation(conversation_id).await.map(|state| state.stats)
}
pub fn update_config(&mut self, config: ConversationalConfig) {
self.config = config;
}
pub fn get_config(&self) -> &ConversationalConfig {
&self.config
}
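    /// Like `process_conversation`, but converts any processing error into a
    /// generic apology response with a fresh fallback state instead of
    /// propagating the error.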
pub async fn process_conversation_safe(
&self,
input: ConversationalInput,
) -> Result<ConversationalOutput> {
match self.process_conversation(input.clone()).await {
Ok(output) => Ok(output),
Err(e) => {
error!("Error processing conversation: {:?}", e);
let fallback_metadata = ConversationMetadata {
sentiment: Some("neutral".to_string()),
intent: Some("error_recovery".to_string()),
confidence: 0.5,
topics: vec!["system_error".to_string()],
safety_flags: vec![],
entities: vec![],
quality_score: 0.5,
engagement_level: EngagementLevel::Low,
reasoning_type: None,
};
let fallback_stats = GenerationStats {
generation_time_ms: 0.0,
tokens_generated: 20,
tokens_per_second: 0.0,
confidence: 0.5,
truncated: false,
};
let conversation_id =
input.conversation_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
Ok(ConversationalOutput {
response: "I apologize, but I encountered an error processing your request. Please try again.".to_string(),
conversation_id: conversation_id.clone(),
conversation_state: ConversationState::new(conversation_id),
metadata: fallback_metadata,
generation_stats: fallback_stats,
})
},
}
}
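    /// Rejects empty messages and messages over 10000 characters.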
pub fn validate_input(&self, input: &ConversationalInput) -> Result<()> {
if input.message.is_empty() {
return Err(TrustformersError::invalid_input_simple(
"Message cannot be empty".to_string(),
));
}
        // Count characters rather than bytes so the check matches the stated limit.
        if input.message.chars().count() > 10000 {
return Err(TrustformersError::invalid_input_simple(
"Message too long (max 10000 characters)".to_string(),
));
}
Ok(())
}
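    /// Returns coarse health metrics: the mean per-conversation health score
    /// and the number of active conversations (reported as 1.0 health when
    /// idle).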
pub async fn get_health_status(&self) -> HashMap<String, f32> {
let mut health = HashMap::new();
let conversations = self.conversations.read().await;
if conversations.is_empty() {
health.insert("overall_health".to_string(), 1.0);
health.insert("average_conversation_health".to_string(), 1.0);
health.insert("active_conversations".to_string(), 0.0);
} else {
let total_health: f32 =
conversations.values().map(|state| state.health.overall_score).sum();
let avg_health = total_health / conversations.len() as f32;
health.insert("overall_health".to_string(), avg_health);
health.insert("average_conversation_health".to_string(), avg_health);
health.insert(
"active_conversations".to_string(),
conversations.len() as f32,
);
}
health
}
pub async fn reset_all_conversations(&self) {
info!("Resetting all conversations");
self.conversations.write().await.clear();
}
pub async fn get_conversation_count(&self) -> usize {
self.conversations.read().await.len()
}
pub async fn conversation_exists(&self, conversation_id: &str) -> bool {
self.conversations.read().await.contains_key(conversation_id)
}
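    /// Serializes every active conversation into a single pretty-printed JSON
    /// document suitable for `restore_conversations`.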
pub async fn backup_all_conversations(&self) -> Result<String> {
let conversations = self.conversations.read().await;
serde_json::to_string_pretty(&*conversations)
.map_err(|e| TrustformersError::runtime_error(e.to_string()))
}
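    /// Replaces all in-memory conversations with the contents of a backup
    /// produced by `backup_all_conversations`, returning how many were loaded.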
pub async fn restore_conversations(&self, backup_json: &str) -> Result<usize> {
let conversations: HashMap<String, ConversationState> =
serde_json::from_str(backup_json)
.map_err(|e| TrustformersError::runtime_error(e.to_string()))?;
let count = conversations.len();
*self.conversations.write().await = conversations;
info!("Restored {} conversations from backup", count);
Ok(count)
}
pub async fn assess_conversation_health(
&self,
conversation_id: &str,
) -> Result<ConversationHealth> {
let conversations = self.conversations.read().await;
let state = conversations.get(conversation_id).ok_or_else(|| {
TrustformersError::invalid_input_simple(format!(
"Conversation not found: {}",
conversation_id
))
})?;
Ok(state.health.clone())
}
}
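/// Synchronous `Pipeline` adapter: `__call__` bridges into the async
/// `process_conversation` via `block_in_place`, so it must be called from
/// within a multi-threaded Tokio runtime or it will panic.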
#[async_trait]
impl<M, T> Pipeline for ConversationalPipeline<M, T>
where
M: Model + Send + Sync + GenerativeModel,
T: Tokenizer + Send + Sync,
{
type Input = ConversationalInput;
type Output = ConversationalOutput;
fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(async { self.process_conversation(input).await })
})
}
}
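/// Convenience constructor that loads a model and tokenizer by name (defaulting
/// to "microsoft/DialoGPT-medium") and derives a `ConversationalConfig` from
/// the generic `PipelineOptions`.
///
/// A minimal sketch of intended use (the module path is an assumption):
///
/// ```no_run
/// # async fn demo() -> trustformers::error::Result<()> {
/// use trustformers::pipeline::conversational::conversational_pipeline;
/// // None, None -> default model and options.
/// let pipeline = conversational_pipeline(None, None).await?;
/// # Ok(())
/// # }
/// ```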
pub async fn conversational_pipeline(
model: Option<&str>,
options: Option<crate::pipeline::PipelineOptions>,
) -> Result<ConversationalPipeline<crate::AutoModel, crate::AutoTokenizer>> {
let opts = options.unwrap_or_default();
let model_name = model.or(opts.model.as_deref()).unwrap_or("microsoft/DialoGPT-medium");
let model = crate::AutoModel::from_pretrained(model_name)?;
let tokenizer = crate::AutoTokenizer::from_pretrained(model_name)?;
let mut pipeline = ConversationalPipeline::new(model, tokenizer)?;
let mut config = ConversationalConfig::default();
if let Some(max_length) = opts.max_length {
config.max_response_tokens = max_length;
        // Allow roughly four times the response budget for accumulated context.
        config.max_context_tokens = max_length * 4;
    }
if let Some(batch_size) = opts.batch_size {
        // batch_size has no direct analogue here; scale it into a history depth.
        config.max_history_turns = batch_size * 10;
    }
if opts.streaming {
config.streaming_config.enabled = true;
}
pipeline = pipeline.with_config(config);
Ok(pipeline)
}
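/// Like `conversational_pipeline`, but takes a fully specified
/// `ConversationalConfig` instead of deriving one from `PipelineOptions`.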
pub async fn advanced_conversational_pipeline(
model_name: &str,
config: ConversationalConfig,
) -> Result<ConversationalPipeline<crate::AutoModel, crate::AutoTokenizer>> {
let model = crate::AutoModel::from_pretrained(model_name)?;
let tokenizer = crate::AutoTokenizer::from_pretrained(model_name)?;
let pipeline = ConversationalPipeline::new(model, tokenizer)?.with_config(config);
Ok(pipeline)
}
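/// Builds a pipeline preconfigured for simulated streaming: five-word chunks
/// with a 30ms typing delay between them.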
pub async fn streaming_conversational_pipeline(
model_name: &str,
) -> Result<ConversationalPipeline<crate::AutoModel, crate::AutoTokenizer>> {
let mut config = ConversationalConfig::default();
config.streaming_config.enabled = true;
config.streaming_config.chunk_size = 5;
config.streaming_config.typing_delay_ms = 30;
advanced_conversational_pipeline(model_name, config).await
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_conversation_state_creation() {
let state = ConversationState::new("test-123".to_string());
assert_eq!(state.conversation_id, "test-123");
assert_eq!(state.turns.len(), 0);
assert_eq!(state.total_tokens, 0);
assert!(state.memories.is_empty());
assert_eq!(state.health.overall_score, 1.0);
}
#[test]
fn test_conversation_turn_addition() {
let mut state = ConversationState::new("test-123".to_string());
let turn = ConversationTurn {
role: ConversationRole::User,
content: "Hello".to_string(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 10,
};
state.add_turn(turn);
assert_eq!(state.turns.len(), 1);
assert_eq!(state.total_tokens, 10);
assert_eq!(state.stats.user_turns, 1);
}
#[test]
fn test_conversation_history_trimming() {
let mut state = ConversationState::new("test-123".to_string());
for i in 0..10 {
let turn = ConversationTurn {
role: ConversationRole::User,
content: format!("Message {}", i),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 100,
};
state.add_turn(turn);
}
assert_eq!(state.turns.len(), 10);
assert_eq!(state.total_tokens, 1000);
state.trim_history(5, 10000);
assert_eq!(state.turns.len(), 5);
}
#[test]
fn test_conversational_config_default() {
let config = ConversationalConfig::default();
assert_eq!(config.max_history_turns, 20);
assert_eq!(config.temperature, 0.7);
assert!(config.enable_safety_filter);
assert!(matches!(config.conversation_mode, ConversationMode::Chat));
}
#[test]
fn test_conversation_mode_serialization() {
let mode = ConversationMode::Educational;
let serialized = serde_json::to_string(&mode).expect("JSON serialization failed");
let deserialized: ConversationMode =
serde_json::from_str(&serialized).expect("JSON deserialization failed");
assert!(matches!(deserialized, ConversationMode::Educational));
}
#[test]
fn test_context_variables() {
let mut state = ConversationState::new("test-123".to_string());
state.set_variable("user_name".to_string(), "Alice".to_string());
state.set_variable("preference".to_string(), "technical".to_string());
assert_eq!(state.get_variable("user_name"), Some(&"Alice".to_string()));
assert_eq!(
state.get_variable("preference"),
Some(&"technical".to_string())
);
assert_eq!(state.get_variable("unknown"), None);
}
#[test]
fn test_recent_context_retrieval() {
let mut state = ConversationState::new("test-123".to_string());
for i in 0..5 {
let turn = ConversationTurn {
role: ConversationRole::User,
content: format!("Message {}", i),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 50,
};
state.add_turn(turn);
}
        // A 120-token budget at 50 tokens per turn fits two full turns, and
        // possibly a third depending on how partial turns are counted.
        let recent = state.get_recent_context(120);
        assert!(recent.len() <= 3);
assert!(recent.len() >= 2);
}
#[test]
fn test_conversation_health() {
let mut state = ConversationState::new("test-health".to_string());
assert_eq!(state.health.overall_score, 1.0);
assert!(!state.needs_repair());
        state.update_health(0.3, 0.2, 1.0);
        assert!(state.needs_repair());
assert!(state.health.overall_score < 0.6);
}
#[test]
fn test_memory_management() {
let config = MemoryConfig::default();
let manager = super::super::memory::MemoryManager::new(config);
let turn = ConversationTurn {
role: ConversationRole::User,
content: "I like pizza and my name is John".to_string(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 10,
};
let memory = manager.create_memory(&turn);
assert!(memory.is_some());
let memory = memory.expect("operation failed in test");
        assert!(memory.importance > 0.5);
        assert_eq!(
memory.memory_type,
super::super::types::MemoryType::Preference
);
}
#[tokio::test]
async fn test_conversation_analyzer() {
let analyzer = super::super::analysis::ConversationAnalyzer::new(
super::super::analysis::types::EnhancedAnalysisConfig::default(),
);
let turn = ConversationTurn {
role: ConversationRole::User,
content: "Can you help me with programming? I'm working on a Python project."
.to_string(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 15,
};
let metadata = analyzer.analyze_turn(&turn).await.expect("async operation failed");
assert!(metadata.quality_score >= 0.0 && metadata.quality_score <= 1.0);
assert!(metadata.confidence >= 0.0 && metadata.confidence <= 1.0);
assert!(
metadata.intent.is_none()
|| !metadata.intent.as_ref().expect("operation failed in test").is_empty()
);
}
#[tokio::test]
async fn test_entity_extraction() {
let analyzer = super::super::analysis::ConversationAnalyzer::new(
super::super::analysis::types::EnhancedAnalysisConfig::default(),
);
let turn = ConversationTurn {
role: ConversationRole::User,
content: "I met John Smith on 12/25/2023 and paid $100 for the service.".to_string(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 20,
};
let metadata = analyzer.analyze_turn(&turn).await.expect("async operation failed");
assert!(metadata.quality_score >= 0.0 && metadata.quality_score <= 1.0);
assert!(matches!(
metadata.engagement_level,
super::super::types::EngagementLevel::Low
| super::super::types::EngagementLevel::Medium
| super::super::types::EngagementLevel::High
| super::super::types::EngagementLevel::VeryHigh
));
}
#[test]
fn test_memory_relevance() {
let mut state = ConversationState::new("test-relevance".to_string());
let memory1 = super::super::types::ConversationMemory {
id: "1".to_string(),
content: "User likes programming in Python".to_string(),
importance: 0.8,
last_accessed: chrono::Utc::now(),
access_count: 1,
memory_type: super::super::types::MemoryType::Preference,
tags: vec!["programming".to_string(), "python".to_string()],
};
let memory2 = super::super::types::ConversationMemory {
id: "2".to_string(),
content: "User went to a restaurant yesterday".to_string(),
importance: 0.4,
last_accessed: chrono::Utc::now(),
access_count: 1,
memory_type: super::super::types::MemoryType::Experience,
tags: vec!["food".to_string()],
};
state.add_memory(memory1);
state.add_memory(memory2);
let relevant = state.get_relevant_memories("help with Python coding", 2);
assert_eq!(relevant.len(), 2);
assert!(relevant[0].content.contains("programming")); }
#[test]
    #[ignore]
    fn test_safety_filter() {
let filter = super::super::safety::SafetyFilter::new();
assert!(filter.is_safe("Hello, how are you?"));
assert!(!filter.is_safe("I hate you"));
assert!(filter.get_toxicity_score("Hello") < 0.5);
assert!(filter.get_toxicity_score("I hate you") > 0.5);
}
#[test]
    #[ignore]
    fn test_input_validation() {
let model = crate::AutoModel::from_pretrained("microsoft/DialoGPT-medium")
.expect("operation failed in test");
let tokenizer = crate::AutoTokenizer::from_pretrained("microsoft/DialoGPT-medium")
.expect("operation failed in test");
let pipeline =
ConversationalPipeline::new(model, tokenizer).expect("operation failed in test");
let empty_input = ConversationalInput {
message: "".to_string(),
conversation_id: None,
context: None,
config_override: None,
};
assert!(pipeline.validate_input(&empty_input).is_err());
let valid_input = ConversationalInput {
message: "Hello".to_string(),
conversation_id: None,
context: None,
config_override: None,
};
assert!(pipeline.validate_input(&valid_input).is_ok());
let long_input = ConversationalInput {
message: "x".repeat(10001),
conversation_id: None,
context: None,
config_override: None,
};
assert!(pipeline.validate_input(&long_input).is_err());
}
#[tokio::test]
    #[ignore]
    async fn test_conversation_backup_restore() {
let model = crate::AutoModel::from_pretrained("microsoft/DialoGPT-medium")
.expect("operation failed in test");
let tokenizer = crate::AutoTokenizer::from_pretrained("microsoft/DialoGPT-medium")
.expect("operation failed in test");
let pipeline =
ConversationalPipeline::new(model, tokenizer).expect("operation failed in test");
let mut state = ConversationState::new("test-123".to_string());
let turn = ConversationTurn {
role: ConversationRole::User,
content: "Hello".to_string(),
timestamp: chrono::Utc::now(),
metadata: None,
token_count: 10,
};
state.add_turn(turn);
pipeline.conversations.write().await.insert("test-123".to_string(), state);
let backup = pipeline.backup_all_conversations().await.expect("async operation failed");
assert!(!backup.is_empty());
pipeline.clear_all_conversations().await;
assert_eq!(pipeline.get_conversation_count().await, 0);
let count = pipeline.restore_conversations(&backup).await.expect("async operation failed");
assert_eq!(count, 1);
assert_eq!(pipeline.get_conversation_count().await, 1);
}
#[tokio::test]
    #[ignore]
    async fn test_health_status() {
let model = crate::AutoModel::from_pretrained("microsoft/DialoGPT-medium")
.expect("operation failed in test");
let tokenizer = crate::AutoTokenizer::from_pretrained("microsoft/DialoGPT-medium")
.expect("operation failed in test");
let pipeline =
ConversationalPipeline::new(model, tokenizer).expect("operation failed in test");
let health = pipeline.get_health_status().await;
assert_eq!(health.get("overall_health"), Some(&1.0));
assert_eq!(health.get("active_conversations"), Some(&0.0));
}
}