use crate::planner::*;
use crate::Result;
use serde::{Deserialize, Serialize};
/// Kinds of optimization that can be recorded against an execution plan.
///
/// Values serialize in `SCREAMING_SNAKE_CASE` (for example `MODEL_DOWNGRADE`)
/// because of the `serde(rename_all)` attribute on this enum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Optimization {
/// The plan was switched to a cheaper model (recorded by `PlanOptimizer::optimize`).
ModelDowngrade,
/// Presumably: the amount of context sent was reduced. Not produced in this file — TODO confirm.
ReducedContext,
/// The plan's maximum output tokens were lowered (recorded by `PlanOptimizer::optimize`).
ReducedOutput,
/// Presumably: the prompt was simplified. Not produced in this file — TODO confirm.
SimplifiedPrompt,
/// Presumably: a cached response was reused. Not produced in this file — TODO confirm.
CachedResponse,
/// Presumably: the request was batched with others. Not produced in this file — TODO confirm.
BatchedRequest,
}
/// Stateless helper that adjusts an execution plan to fit a budget decision and
/// produces cost/token warnings and suggestions.
///
/// The type carries no data, so values are trivially printable and cloneable.
#[derive(Debug, Clone)]
pub struct PlanOptimizer;
impl PlanOptimizer {
pub fn new() -> Self {
Self
}
pub async fn optimize(
&self,
mut plan: ExecutionPlan,
budget_check: &budget::BudgetCheckResult,
cost_calculator: &cost::CostCalculator,
) -> Result<ExecutionPlan> {
if budget_check.approved {
return Ok(plan);
}
let mut optimizations = Vec::new();
match budget_check.action {
budget::BudgetAction::Warn => {
plan.warnings.push(format!(
"Budget warning: {} (${:.4} available, ${:.4} required)",
budget_check.reason,
budget_check.available_budget,
budget_check.required_budget
));
}
budget::BudgetAction::SwitchToSmaller => {
if let Some(cheaper_model) = cost_calculator.find_cheaper_model(
&plan.cost_estimate.model_used,
plan.token_estimate.input_tokens,
) {
let new_cost = cost_calculator.calculate_cost(
&cheaper_model,
plan.cost_estimate.input_tokens,
plan.cost_estimate.output_tokens,
)?;
if new_cost.estimated_cost_usd <= budget_check.available_budget {
plan.cost_estimate = new_cost;
plan.response_strategy.model_selection = cheaper_model;
optimizations.push(Optimization::ModelDowngrade);
}
}
if plan.cost_estimate.estimated_cost_usd > budget_check.available_budget {
let reduction_factor =
budget_check.available_budget / plan.cost_estimate.estimated_cost_usd;
let new_output_tokens =
(plan.cost_estimate.output_tokens as f64 * reduction_factor * 0.9) as usize;
if new_output_tokens > 50 {
let new_cost = cost_calculator.calculate_cost(
&plan.cost_estimate.model_used,
plan.cost_estimate.input_tokens,
new_output_tokens,
)?;
plan.cost_estimate = new_cost;
plan.response_strategy.max_tokens = new_output_tokens;
optimizations.push(Optimization::ReducedOutput);
}
}
}
budget::BudgetAction::Block => {
return Err(crate::ZoeyError::Other(format!(
"Budget exceeded and action is BLOCK: {}",
budget_check.reason
)));
}
budget::BudgetAction::RequireApproval => {
plan.warnings
.push(format!("User approval required: {}", budget_check.reason));
plan.requires_approval = true;
}
}
plan.optimizations_applied.extend(optimizations);
Ok(plan)
}
pub fn optimize_tokens(&self, plan: &mut ExecutionPlan) -> Vec<Optimization> {
let optimizations = Vec::new();
if plan.token_estimate.total_tokens > 100000 {
plan.warnings.push(
"High token usage detected. Consider reducing context or output length."
.to_string(),
);
}
let expected_output = match plan.complexity.level {
complexity::ComplexityLevel::Trivial => 100,
complexity::ComplexityLevel::Simple => 300,
complexity::ComplexityLevel::Moderate => 600,
complexity::ComplexityLevel::Complex => 1000,
complexity::ComplexityLevel::VeryComplex => 2000,
};
if plan.token_estimate.output_tokens > expected_output * 2 {
plan.warnings.push(format!(
"Output tokens ({}) seem high for {} complexity. Expected ~{}.",
plan.token_estimate.output_tokens, plan.complexity.level, expected_output
));
}
optimizations
}
pub fn suggest_optimizations(&self, plan: &ExecutionPlan) -> Vec<String> {
let mut suggestions = Vec::new();
if plan.cost_estimate.estimated_cost_usd > 0.10 {
suggestions.push(
"Consider using a smaller model for cost savings (e.g., GPT-3.5 instead of GPT-4)"
.to_string(),
);
}
if plan.token_estimate.input_tokens > 10000 {
suggestions.push(
"High input tokens detected. Consider summarizing context or using RAG."
.to_string(),
);
}
if matches!(
plan.complexity.level,
complexity::ComplexityLevel::Trivial | complexity::ComplexityLevel::Simple
) && plan.cost_estimate.model_used.contains("gpt-4")
{
suggestions.push(
"Simple task detected. A smaller model like GPT-3.5 may be sufficient.".to_string(),
);
}
if !plan.knowledge.unknown_gaps.is_empty() {
let critical_gaps = plan
.knowledge
.unknown_gaps
.iter()
.filter(|g| g.priority == knowledge::Priority::Critical)
.count();
if critical_gaps > 0 {
suggestions.push(format!(
"{} critical knowledge gaps detected. Consider gathering more context first.",
critical_gaps
));
}
}
suggestions
}
}
impl Default for PlanOptimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // TODO: these are construction smoke tests only. Exercising `optimize` and
    // `suggest_optimizations` needs `ExecutionPlan`, `BudgetCheckResult` and
    // `CostCalculator` fixtures, which are not built here.

    /// The optimizer can be constructed inside an async runtime.
    #[tokio::test]
    async fn test_model_downgrade() {
        let _optimizer = PlanOptimizer::new();
    }

    /// The optimizer can be constructed in a plain synchronous test.
    #[test]
    fn test_suggestions() {
        let _optimizer = PlanOptimizer::new();
    }
}