use std::sync::Arc;
use crate::config::TokenOptimizationConfig;
use crate::error::TokenOptError;
use crate::estimator::TokenEstimator;
use crate::estimator_tuning::EstimationCalibrator;
use crate::history::dedup::deduplicate_adjacent;
use crate::metrics::OptimizationMetrics;
use crate::optimizer::{OptimizationResult, TokenOptimizer};
use crate::ports::SummarizationPort;
use crate::prompt::rag_dedup::RagEntry;
use crate::prompt::structured::prose_to_yaml;
use crate::stream::repetition::RepetitionDetector;
use crate::tools::chain_collapser::collapse_tool_chains;
use crate::tools::progressive::{ToolUsageTracker, compress_progressively};
use crate::types::{Conversation, ToolDefinition};
/// Outcome of a single-shot text optimization (see `Pipeline::optimize_text`).
///
/// All fields are plain owned values, so the type is cheap to clone and can be
/// compared directly in tests and caches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextOptimizationResult {
    /// System prompt after optimization; empty when the conversation ended up
    /// without a system prompt.
    pub optimized_prompt: String,
    /// Estimated total token count of the conversation before optimization.
    pub tokens_before: u32,
    /// Estimated total token count of the conversation after optimization.
    pub tokens_after: u32,
    /// Suggested output-token cap for the request, or `None` for "no cap".
    pub recommended_max_tokens: Option<u32>,
}
/// Builder-style facade that bundles the optimizer configuration, optional
/// collaborators and per-feature toggles used by the optimization entry points.
///
/// Construct it with `Pipeline::default()` and refine it through the chained
/// setter methods.
// The struct carries five independent `enable_*` toggles, which trips
// `clippy::struct_excessive_bools`; the lint is silenced for this type only.
#[allow(clippy::struct_excessive_bools)]
pub struct Pipeline {
    /// Configuration cloned into every optimizer the pipeline builds.
    config: TokenOptimizationConfig,
    /// Optional summarization backend handed to the optimizer on each run.
    summarizer: Option<Box<dyn SummarizationPort>>,
    /// Toggle for the adjacent-message deduplication pass.
    enable_dedup: bool,
    /// Toggle for rewriting the system prompt with `prose_to_yaml`.
    enable_structured: bool,
    /// Toggle for the tool-chain collapsing pass.
    enable_chain_collapse: bool,
    /// Toggle for progressive compression of tool definitions.
    enable_progressive_tools: bool,
    /// Toggle set through `Pipeline::enable_output_budget`.
    enable_output_budget: bool,
    /// Usage tracker consulted and updated by progressive tool compression.
    tool_tracker: ToolUsageTracker,
    /// Shared metrics handle; `None` until `Pipeline::with_metrics` is called.
    metrics: Option<Arc<OptimizationMetrics>>,
    /// Estimation calibrator; `None` until `Pipeline::with_calibration` is called.
    calibrator: Option<EstimationCalibrator>,
    /// Tool definitions stored by `Pipeline::with_tools`.
    tools: Option<Vec<ToolDefinition>>,
    /// Retrieval entries stored by `Pipeline::with_rag`.
    rag_entries: Option<Vec<RagEntry>>,
}
impl std::fmt::Debug for Pipeline {
    /// Renders the pipeline as `Pipeline { config: .., .. }`.
    ///
    /// Only the `config` field is printed; every other field is elided and the
    /// output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("Pipeline");
        repr.field("config", &self.config);
        repr.finish_non_exhaustive()
    }
}
impl Default for Pipeline {
    /// Builds a pipeline with the default configuration, every `enable_*`
    /// toggle switched on, a fresh tool-usage tracker, and no optional
    /// collaborators (summarizer, metrics, calibrator, tools, RAG entries).
    fn default() -> Self {
        // Constructed values first, in the same order as the field list.
        let config = TokenOptimizationConfig::default();
        let tool_tracker = ToolUsageTracker::new();

        Self {
            config,
            tool_tracker,
            // Every optimization feature starts enabled.
            enable_dedup: true,
            enable_structured: true,
            enable_chain_collapse: true,
            enable_progressive_tools: true,
            enable_output_budget: true,
            // Optional collaborators are opt-in through the builder methods.
            summarizer: None,
            metrics: None,
            calibrator: None,
            tools: None,
            rag_entries: None,
        }
    }
}
impl Pipeline {
    /// Sets `config.context_window_tokens` and returns the pipeline for chaining.
    #[must_use]
    pub fn context_window(mut self, tokens: u32) -> Self {
        self.config.context_window_tokens = tokens;
        self
    }

    /// Sets `config.response_headroom_ratio` and returns the pipeline for chaining.
    #[must_use]
    pub fn response_headroom(mut self, ratio: f32) -> Self {
        self.config.response_headroom_ratio = ratio;
        self
    }

    /// Sets `config.compaction_trigger_ratio` and returns the pipeline for chaining.
    #[must_use]
    pub fn compaction_trigger(mut self, ratio: f32) -> Self {
        self.config.compaction_trigger_ratio = ratio;
        self
    }

    /// Sets `config.max_tools_per_request` and returns the pipeline for chaining.
    #[must_use]
    pub fn max_tools(mut self, max: usize) -> Self {
        self.config.max_tools_per_request = max;
        self
    }

    /// Installs the summarization backend that is passed to the optimizer on
    /// every run, replacing any previously installed one.
    #[must_use]
    pub fn with_summarizer(mut self, summarizer: Box<dyn SummarizationPort>) -> Self {
        self.summarizer = Some(summarizer);
        self
    }

    /// Sets `config.repetition_detection_enabled`, which controls whether
    /// [`Pipeline::create_stream_monitor`] hands out a detector.
    #[must_use]
    pub fn repetition_detection(mut self, enabled: bool) -> Self {
        self.config.repetition_detection_enabled = enabled;
        self
    }

    /// Enables or disables the adjacent-message deduplication pass.
    #[must_use]
    pub fn enable_dedup(mut self, enabled: bool) -> Self {
        self.enable_dedup = enabled;
        self
    }

    /// Enables or disables rewriting of the system prompt via `prose_to_yaml`.
    #[must_use]
    pub fn enable_structured_prompts(mut self, enabled: bool) -> Self {
        self.enable_structured = enabled;
        self
    }

    /// Enables or disables the tool-chain collapsing pass.
    #[must_use]
    pub fn enable_chain_collapse(mut self, enabled: bool) -> Self {
        self.enable_chain_collapse = enabled;
        self
    }

    /// Enables or disables progressive compression of tool definitions in
    /// [`Pipeline::optimize_conversation_with_tools`].
    #[must_use]
    pub fn enable_progressive_tools(mut self, enabled: bool) -> Self {
        self.enable_progressive_tools = enabled;
        self
    }

    /// Enables or disables the output budget.
    ///
    /// While disabled, [`Pipeline::recommended_max_tokens`] (and therefore the
    /// `recommended_max_tokens` field produced by [`Pipeline::optimize_text`])
    /// is always `None`, even if a cap was configured.
    #[must_use]
    pub fn enable_output_budget(mut self, enabled: bool) -> Self {
        self.enable_output_budget = enabled;
        self
    }

    /// Sets `config.output_max_tokens`; `None` means "no explicit cap".
    #[must_use]
    pub fn output_max_tokens(mut self, cap: Option<u32>) -> Self {
        self.config.output_max_tokens = cap;
        self
    }

    /// Stores the given tool definitions on the pipeline.
    #[must_use]
    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Stores the given retrieval entries on the pipeline.
    #[must_use]
    pub fn with_rag(mut self, entries: Vec<RagEntry>) -> Self {
        self.rag_entries = Some(entries);
        self
    }

    /// Creates a fresh metrics collector, attaches it to the pipeline
    /// (replacing any previous one) and returns a shared handle to it.
    ///
    /// Unlike the chained setters this method borrows the pipeline mutably
    /// instead of consuming it, so it is called on an already-built pipeline.
    pub fn with_metrics(&mut self) -> Arc<OptimizationMetrics> {
        let metrics = Arc::new(OptimizationMetrics::new());
        self.metrics = Some(Arc::clone(&metrics));
        metrics
    }

    /// Attaches a new, empty estimation calibrator to the pipeline.
    #[must_use]
    pub fn with_calibration(mut self) -> Self {
        self.calibrator = Some(EstimationCalibrator::new());
        self
    }

    /// Feeds an (estimated, actual) token-count observation for `model` into
    /// the calibrator. Does nothing when calibration has not been enabled.
    pub fn report_actual_tokens(&mut self, model: &str, estimated: u32, actual: u32) {
        if let Some(calibrator) = self.calibrator.as_mut() {
            calibrator.record_observation(model, estimated, actual);
        }
    }

    /// Returns a new `RepetitionDetector` when repetition detection is enabled
    /// in the configuration, otherwise `None`.
    #[must_use]
    pub fn create_stream_monitor(&self) -> Option<RepetitionDetector> {
        self.config
            .repetition_detection_enabled
            .then(|| RepetitionDetector::new(3, 0.4))
    }

    /// Borrows the attached metrics collector, if [`Pipeline::with_metrics`]
    /// has been called.
    #[must_use]
    pub fn metrics_snapshot(&self) -> Option<&OptimizationMetrics> {
        self.metrics.as_deref()
    }

    /// Builds a `TokenOptimizer` from a clone of the current configuration,
    /// wiring in the shared metrics handle and calibration when present.
    fn build_optimizer(&self) -> TokenOptimizer {
        let mut optimizer = TokenOptimizer::new(self.config.clone());
        if let Some(metrics) = &self.metrics {
            optimizer = optimizer.with_metrics(Arc::clone(metrics));
        }
        // NOTE(review): only the *presence* of a calibrator is forwarded; the
        // observations recorded on `self.calibrator` are not handed to the
        // optimizer here. Confirm whether that is intended.
        if self.calibrator.is_some() {
            optimizer = optimizer.with_calibration();
        }
        optimizer
    }

    /// Replaces the system prompt with its `prose_to_yaml` rendering when the
    /// conversion actually changes the text. No-op without a system prompt.
    fn restructure_system_prompt(conversation: &mut Conversation) {
        if let Some(prompt) = &conversation.system_prompt {
            let converted = prose_to_yaml(prompt);
            if converted != *prompt {
                conversation.system_prompt = Some(converted);
            }
        }
    }

    /// Runs the enabled pre-optimizer passes in fixed order: adjacent-message
    /// deduplication, tool-chain collapsing, then system-prompt restructuring.
    fn preprocess(&self, conversation: &mut Conversation) {
        if self.enable_dedup {
            let deduped = deduplicate_adjacent(&conversation.messages, 0.7);
            // Only swap the message list when something was actually merged.
            if deduped.merged_count > 0 {
                conversation.messages = deduped.messages;
            }
        }
        if self.enable_chain_collapse {
            let collapsed = collapse_tool_chains(&conversation.messages);
            if collapsed.collapsed_count > 0 {
                conversation.messages = collapsed.messages;
            }
        }
        if self.enable_structured {
            Self::restructure_system_prompt(conversation);
        }
    }

    /// Optimizes a one-off system prompt / user message pair.
    ///
    /// A temporary conversation is built from the two strings, run through the
    /// optimizer, and (when structured prompts are enabled) its system prompt
    /// is rewritten with `prose_to_yaml`. Token totals are estimated before
    /// and after.
    ///
    /// # Errors
    ///
    /// Propagates any `TokenOptError` returned by the underlying optimizer.
    pub async fn optimize_text(
        &self,
        system_prompt: &str,
        user_message: &str,
    ) -> Result<TextOptimizationResult, TokenOptError> {
        let mut conv = Conversation::with_system_prompt(system_prompt);
        conv.add_user_message(user_message);
        let tokens_before = TokenEstimator::estimate_conversation(&conv).total;

        let optimizer = self.build_optimizer();
        // The optimizer's own report is not surfaced by this entry point; only
        // its effect on `conv` matters here.
        let _report = optimizer
            .optimize_conversation(&mut conv, self.summarizer.as_deref())
            .await?;

        if self.enable_structured {
            Self::restructure_system_prompt(&mut conv);
        }

        let tokens_after = TokenEstimator::estimate_conversation(&conv).total;
        // Go through the public accessor so both paths honour the same
        // output-budget toggle.
        let recommended_max_tokens = self.recommended_max_tokens(user_message);

        Ok(TextOptimizationResult {
            optimized_prompt: conv.system_prompt.unwrap_or_default(),
            tokens_before,
            tokens_after,
            recommended_max_tokens,
        })
    }

    /// Optimizes `conversation` in place.
    ///
    /// The enabled pre-passes (dedup, chain collapse, structured prompt) run
    /// first, then the conversation is handed to the optimizer together with
    /// the optional summarizer.
    ///
    /// # Errors
    ///
    /// Propagates any `TokenOptError` returned by the underlying optimizer.
    pub async fn optimize_conversation(
        &mut self,
        conversation: &mut Conversation,
    ) -> Result<OptimizationResult, TokenOptError> {
        self.preprocess(conversation);
        let optimizer = self.build_optimizer();
        optimizer
            .optimize_conversation(conversation, self.summarizer.as_deref())
            .await
    }

    /// Optimizes `conversation` in place and selects the tool definitions to
    /// send along with it.
    ///
    /// After the same pre-passes as [`Pipeline::optimize_conversation`], the
    /// optimizer processes the conversation with `tools`, then picks tools
    /// based on the content of the last message (or an empty string when the
    /// conversation has no messages). With progressive tools enabled, the
    /// selection is additionally compressed against the usage tracker, and the
    /// resulting tools are marked as seen.
    ///
    /// # Errors
    ///
    /// Propagates any `TokenOptError` returned by the underlying optimizer.
    pub async fn optimize_conversation_with_tools(
        &mut self,
        conversation: &mut Conversation,
        tools: &[ToolDefinition],
    ) -> Result<(OptimizationResult, Vec<ToolDefinition>), TokenOptError> {
        self.preprocess(conversation);

        let optimizer = self.build_optimizer();
        let result = optimizer
            .optimize_conversation_with_tools(conversation, tools, self.summarizer.as_deref())
            .await?;

        let latest_content = conversation
            .messages
            .last()
            .map_or("", |message| message.content.as_str());
        let mut optimized_tools = optimizer.optimize_tools(latest_content, tools);

        if self.enable_progressive_tools {
            optimized_tools = compress_progressively(&optimized_tools, &self.tool_tracker);
            self.tool_tracker.mark_seen(&optimized_tools);
        }

        Ok((result, optimized_tools))
    }

    /// Returns the output-token cap to recommend for a request.
    ///
    /// This is the configured `output_max_tokens` while the output budget is
    /// enabled, and `None` when no cap is configured or the budget has been
    /// switched off via [`Pipeline::enable_output_budget`]. The `_user_query`
    /// argument is currently not consulted.
    #[must_use]
    pub fn recommended_max_tokens(&self, _user_query: &str) -> Option<u32> {
        if self.enable_output_budget {
            self.config.output_max_tokens
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the `Pipeline` builder and its optimization entry points.

    use super::*;

    /// A trivial prompt/message pair optimizes successfully and yields a
    /// non-empty prompt with no output cap under the default configuration.
    #[tokio::test]
    async fn text_optimization_basic() {
        let pipeline = Pipeline::default().context_window(8192);
        let result = pipeline
            .optimize_text("You are helpful.", "Hello!")
            .await
            .expect("should succeed");
        assert!(result.tokens_before > 0);
        assert!(!result.optimized_prompt.is_empty());
        assert!(result.recommended_max_tokens.is_none());
    }

    #[tokio::test]
    async fn conversation_optimization_basic() {
        let mut pipeline = Pipeline::default().context_window(8192);
        let mut conv = Conversation::with_system_prompt("You are helpful.");
        conv.add_user_message("Hello!");
        conv.add_assistant_message("Hi!");
        let result = pipeline
            .optimize_conversation(&mut conv)
            .await
            .expect("should succeed");
        assert!(result.estimate_before.total > 0);
    }

    /// Every config-level builder writes through to the matching config field.
    #[test]
    fn pipeline_builder_chaining() {
        let pipeline = Pipeline::default()
            .context_window(4096)
            .response_headroom(0.3)
            .compaction_trigger(0.6)
            .max_tools(5)
            .repetition_detection(false);
        assert_eq!(pipeline.config.context_window_tokens, 4096);
        assert!((pipeline.config.response_headroom_ratio - 0.3).abs() < f32::EPSILON);
        assert!((pipeline.config.compaction_trigger_ratio - 0.6).abs() < f32::EPSILON);
        assert_eq!(pipeline.config.max_tools_per_request, 5);
        assert!(!pipeline.config.repetition_detection_enabled);
    }

    /// Every feature toggle can be switched off through its builder.
    #[test]
    fn v2_features_configurable() {
        let pipeline = Pipeline::default()
            .enable_dedup(false)
            .enable_structured_prompts(false)
            .enable_chain_collapse(false)
            .enable_progressive_tools(false)
            .enable_output_budget(false);
        assert!(!pipeline.enable_dedup);
        assert!(!pipeline.enable_structured);
        assert!(!pipeline.enable_chain_collapse);
        assert!(!pipeline.enable_progressive_tools);
        assert!(!pipeline.enable_output_budget);
    }

    #[test]
    fn recommended_max_tokens_none_by_default() {
        let pipeline = Pipeline::default().context_window(8192);
        let max = pipeline.recommended_max_tokens("Hello!");
        assert!(max.is_none(), "no output cap by default");
    }

    /// The cap is configured through the public `output_max_tokens` builder so
    /// that the builder itself is exercised, not just the private field.
    #[test]
    fn recommended_max_tokens_returns_explicit_cap() {
        let pipeline = Pipeline::default()
            .context_window(8192)
            .output_max_tokens(Some(512));
        assert_eq!(pipeline.config.output_max_tokens, Some(512));
        let max = pipeline.recommended_max_tokens("Hello!");
        assert_eq!(max, Some(512));
    }

    // No cap is configured here, so this covers the "budget disabled, nothing
    // configured" path only.
    #[test]
    fn recommended_max_tokens_disabled() {
        let pipeline = Pipeline::default().enable_output_budget(false);
        assert!(pipeline.recommended_max_tokens("Hello!").is_none());
    }

    /// Two identical adjacent user messages are merged by the dedup pass.
    #[tokio::test]
    async fn dedup_runs_in_pipeline() {
        let mut pipeline = Pipeline::default().context_window(8192).enable_dedup(true);
        let mut conv = Conversation::new();
        conv.add_user_message("What is the weather today?");
        conv.add_user_message("What is the weather today?");
        conv.add_assistant_message("It's sunny!");
        let original_count = conv.messages.len();
        let _result = pipeline
            .optimize_conversation(&mut conv)
            .await
            .expect("should succeed");
        assert!(conv.messages.len() < original_count);
    }

    #[test]
    fn with_tools_stores_definitions() {
        use crate::types::ToolParameters;
        use std::collections::HashMap;

        let tools = vec![ToolDefinition {
            name: "get_weather".to_string(),
            description: "Get weather forecast".to_string(),
            parameters: ToolParameters {
                schema_type: "object".to_string(),
                properties: HashMap::new(),
                required: Vec::new(),
            },
            icon: None,
        }];
        let pipeline = Pipeline::default().with_tools(tools);
        assert_eq!(pipeline.tools.as_ref().expect("tools set").len(), 1);
    }

    #[test]
    fn with_rag_stores_entries() {
        let entries = vec![RagEntry {
            content: "Rust is a systems language.".to_string(),
            relevance: 0.95,
            embedding: None,
        }];
        let pipeline = Pipeline::default().with_rag(entries);
        assert_eq!(pipeline.rag_entries.as_ref().expect("rag set").len(), 1);
    }

    #[test]
    fn with_metrics_creates_shared_handle() {
        let mut pipeline = Pipeline::default();
        let metrics = pipeline.with_metrics();
        assert_eq!(metrics.total_optimizations(), 0);
        assert!(pipeline.metrics_snapshot().is_some());
    }

    #[test]
    fn with_calibration_enables_calibrator() {
        let pipeline = Pipeline::default().with_calibration();
        assert!(pipeline.calibrator.is_some());
    }

    /// Reporting without a calibrator must neither panic nor create one.
    #[test]
    fn report_actual_tokens_noop_without_calibration() {
        let mut pipeline = Pipeline::default();
        pipeline.report_actual_tokens("llama3", 100, 110);
        assert!(pipeline.calibrator.is_none());
    }

    /// An under-estimate (100 estimated vs 110 actual) pushes the correction
    /// factor above 1.0.
    #[test]
    fn report_actual_tokens_with_calibration() {
        let mut pipeline = Pipeline::default().with_calibration();
        pipeline.report_actual_tokens("llama3", 100, 110);
        let cal = pipeline.calibrator.as_ref().expect("calibrator set");
        let factor = cal.correction_factor("llama3");
        assert!(factor > 1.0);
    }

    #[test]
    fn create_stream_monitor_when_enabled() {
        let pipeline = Pipeline::default().repetition_detection(true);
        assert!(pipeline.create_stream_monitor().is_some());
    }

    #[test]
    fn create_stream_monitor_when_disabled() {
        let pipeline = Pipeline::default().repetition_detection(false);
        assert!(pipeline.create_stream_monitor().is_none());
    }

    #[test]
    fn metrics_snapshot_none_without_metrics() {
        let pipeline = Pipeline::default();
        assert!(pipeline.metrics_snapshot().is_none());
    }

    /// The handle returned by `with_metrics` observes optimizations performed
    /// through the pipeline afterwards.
    #[tokio::test]
    async fn metrics_tracked_through_pipeline() {
        let mut pipeline = Pipeline::default().context_window(8192);
        let metrics = pipeline.with_metrics();
        let mut conv = Conversation::with_system_prompt("You are helpful.");
        conv.add_user_message("Hello!");
        let _result = pipeline
            .optimize_conversation(&mut conv)
            .await
            .expect("should succeed");
        assert!(metrics.total_optimizations() > 0);
    }
}