use crate::{AgentContext, AgentResult};
use car_inference::{GenerateParams, GenerateRequest};
/// Generation settings for a single planning request.
#[derive(Debug, Clone)]
pub struct PlanConfig {
/// Upper bound on the number of tokens the model may generate for the plan.
pub max_tokens: usize,
/// Sampling temperature; the default keeps it low for structured, repeatable plans.
pub temperature: f64,
/// Explicit model override; `None` defers to the inference layer's default.
pub model: Option<String>,
/// Whether tool definitions should accompany the request.
/// NOTE(review): not currently consulted by `PlannerAgent::plan`, which
/// hardcodes `tools: None` in the request — confirm this is intended.
pub include_tools: bool,
}
impl Default for PlanConfig {
fn default() -> Self {
Self {
max_tokens: 2048,
temperature: 0.2,
model: None,
include_tools: true,
}
}
}
/// Agent that asks the inference backend to decompose a goal into ordered steps.
pub struct PlannerAgent {
// Shared handle to agent services; `plan` uses its `inference` client.
ctx: AgentContext,
// Per-agent generation settings (token budget, temperature, model override).
config: PlanConfig,
}
impl PlannerAgent {
    /// Builds a planner using the default [`PlanConfig`].
    pub fn new(ctx: AgentContext) -> Self {
        Self::with_config(ctx, PlanConfig::default())
    }

    /// Builds a planner with an explicit configuration.
    pub fn with_config(ctx: AgentContext, config: PlanConfig) -> Self {
        Self { ctx, config }
    }

    /// Asks the model to break `goal` into concrete, ordered, verifiable steps.
    ///
    /// `context`, when present, is forwarded on the request's `context` field;
    /// it is not interpolated into the prompt text itself. Confidence is a
    /// rough heuristic based on how many numbered lines the response contains.
    /// Inference failures are reported as a zero-confidence [`AgentResult`]
    /// rather than propagated as an error.
    pub async fn plan(&self, goal: &str, context: Option<&str>) -> AgentResult {
        let prompt = format!(
            "You are a planning agent. Break down the following goal into concrete, ordered steps.\n\n\
            Goal: {goal}\n\n\
            For each step, specify:\n\
            1. What to do (action)\n\
            2. What it depends on (which previous steps must complete)\n\
            3. What it produces (output/state change)\n\
            4. How to verify it worked\n\n\
            Format as a numbered list. Be specific and actionable — no vague steps like 'analyze the situation.'"
        );

        let timer = std::time::Instant::now();
        let request = GenerateRequest {
            prompt,
            model: self.config.model.clone(),
            params: GenerateParams {
                temperature: self.config.temperature,
                max_tokens: self.config.max_tokens,
                ..Default::default()
            },
            context: context.map(String::from),
            // NOTE(review): `config.include_tools` is never consulted; tools
            // are always omitted from the request — confirm this is intended.
            tools: None,
            images: None,
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };

        let outcome = self.ctx.inference.generate_tracked(request).await;
        let latency_ms = timer.elapsed().as_millis() as u64;

        match outcome {
            Ok(generation) => {
                // Heuristic quality signal: a usable plan should contain
                // several lines that begin with a digit (the numbered steps).
                let numbered = generation
                    .text
                    .lines()
                    .filter(|line| {
                        line.trim_start()
                            .chars()
                            .next()
                            .map_or(false, |c| c.is_ascii_digit())
                    })
                    .count();
                AgentResult {
                    agent: "planner".into(),
                    output: generation.text,
                    confidence: if numbered < 3 { 0.5 } else { 0.8 },
                    model_used: generation.model_used,
                    latency_ms,
                }
            }
            Err(e) => AgentResult {
                agent: "planner".into(),
                output: format!("Planning failed: {}", e),
                confidence: 0.0,
                model_used: String::new(),
                latency_ms,
            },
        }
    }
}