use std::fmt;
use serde::{Deserialize, Serialize};
use super::cost::CostEstimator;
/// Tracks how many tokens have been consumed against a fixed upper limit.
///
/// Usage is accumulated with saturating arithmetic, so the running total
/// clamps at `usize::MAX` instead of overflowing. Nothing here prevents
/// recording usage beyond `max_tokens`; callers are expected to consult
/// [`TokenBudget::can_afford`] first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenBudget {
    /// Upper limit fixed at construction.
    max_tokens: usize,
    /// Running total of everything passed to [`TokenBudget::record_usage`].
    used_tokens: usize,
}

impl TokenBudget {
    /// Creates a budget allowing up to `max_tokens` tokens, with nothing used yet.
    #[must_use]
    pub const fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            used_tokens: 0,
        }
    }

    /// Returns the limit this budget was created with.
    #[must_use]
    pub fn max_tokens(&self) -> usize {
        self.max_tokens
    }

    /// Returns the total number of tokens recorded so far.
    #[must_use]
    pub fn used_tokens(&self) -> usize {
        self.used_tokens
    }

    /// Returns how many tokens are left before the limit is reached.
    ///
    /// Returns `0` once usage has met or exceeded the limit (never underflows).
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.max_tokens.saturating_sub(self.used_tokens)
    }

    /// Returns `true` if spending `estimated_tokens` more would keep total
    /// usage at or below the limit.
    ///
    /// This is deliberately stricter than `estimated_tokens <= self.remaining()`
    /// when the budget is already overdrawn: an over-limit budget cannot
    /// afford even `0` more tokens.
    #[must_use]
    pub fn can_afford(&self, estimated_tokens: usize) -> bool {
        self.used_tokens.saturating_add(estimated_tokens) <= self.max_tokens
    }

    /// Adds `tokens` to the running total, saturating at `usize::MAX`.
    pub fn record_usage(&mut self, tokens: usize) {
        self.used_tokens = self.used_tokens.saturating_add(tokens);
    }
}
/// A single completed model response together with its reported token counts.
///
/// Values are stored as given; in particular `total_tokens` is not checked
/// against `prompt_tokens + completion_tokens`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmResponse {
    /// Text returned by the model.
    pub content: String,
    /// Token count attributed to the prompt.
    pub prompt_tokens: usize,
    /// Token count attributed to the generated completion.
    pub completion_tokens: usize,
    /// Total token count as supplied by the constructor of this value.
    pub total_tokens: usize,
    /// Model identifier, e.g. `"gpt-4"`; used when estimating cost.
    pub model: String,
    /// Free-form stop reason string, e.g. `"stop"`.
    pub finish_reason: String,
}
impl LlmResponse {
    /// Builds an [`LlmUsage`] record from this response, pricing it with `estimator`.
    ///
    /// The prompt and completion counts plus the model name are handed to
    /// [`CostEstimator::estimate`]; all three token counts (including
    /// `total_tokens`) are copied through unchanged.
    pub fn to_usage(&self, estimator: &CostEstimator) -> LlmUsage {
        let Self {
            prompt_tokens,
            completion_tokens,
            total_tokens,
            model,
            ..
        } = self;

        let estimated_usd = estimator.estimate(*prompt_tokens, *completion_tokens, model);

        LlmUsage {
            prompt_tokens: *prompt_tokens,
            completion_tokens: *completion_tokens,
            total_tokens: *total_tokens,
            estimated_usd,
        }
    }
}
/// Token counts and estimated monetary cost for one model call.
///
/// `Default` yields an all-zero record.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmUsage {
    /// Token count attributed to the prompt.
    pub prompt_tokens: usize,
    /// Token count attributed to the completion.
    pub completion_tokens: usize,
    /// Total token count.
    pub total_tokens: usize,
    /// Estimated cost in US dollars.
    pub estimated_usd: f64,
}
/// Summary facts about a repository.
///
/// `Default` describes an empty repository: no detected language, zero files,
/// no tests and no CI.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RepoContext {
    /// Dominant language, if one was identified.
    pub primary_language: Option<String>,
    /// Number of files counted.
    pub file_count: usize,
    /// Names of files at the repository root (presumably; confirm with the producer).
    pub top_level_files: Vec<String>,
    /// Whether tests were detected.
    pub has_tests: bool,
    /// Whether CI configuration was detected.
    pub has_ci: bool,
}
/// High-level category assigned to a user goal.
///
/// Serialized with serde's `snake_case` renaming, which for these single-word
/// variants is the lowercase variant name (e.g. `greenfield`), matching the
/// `Display` output. `Eq` and `Hash` allow use as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GoalKind {
    Greenfield,
    Rewrite,
    Repair,
    Audit,
    Migration,
    Vague,
}
impl fmt::Display for GoalKind {
    /// Writes the lowercase variant name, e.g. `greenfield`.
    ///
    /// Formatting flags are honoured just as for a `&str`: width, fill and
    /// alignment pad the name, and precision truncates it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            GoalKind::Greenfield => "greenfield",
            GoalKind::Rewrite => "rewrite",
            GoalKind::Repair => "repair",
            GoalKind::Audit => "audit",
            GoalKind::Migration => "migration",
            GoalKind::Vague => "vague",
        };
        // `pad` (unlike `write!`) applies width/fill/alignment/precision, so
        // e.g. `format!("{:<12}", kind)` lines up in tables.
        f.pad(name)
    }
}
/// Result of classifying a user goal.
#[derive(Debug, Clone, PartialEq)]
pub struct GoalClassification {
    /// Category assigned to the goal.
    pub kind: GoalKind,
    /// Confidence in `kind`. Presumably in `0.0..=1.0`, but this is not
    /// enforced here (NOTE(review): confirm with the classifier).
    pub confidence: f32,
    /// Explanation accompanying the classification.
    pub reasoning: String,
    /// Whether the goal was judged testable.
    pub is_testable: bool,
    /// Optional rewording of the goal offered by the classifier.
    pub suggested_refinement: Option<String>,
}
/// Estimated difficulty label for a unit of work.
///
/// Serialized with serde's `snake_case` renaming, which for these single-word
/// variants is the lowercase variant name (e.g. `medium`), matching the
/// `Display` output. `Eq` and `Hash` allow use as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Difficulty {
    Trivial,
    Easy,
    Medium,
    Hard,
    Complex,
}
impl fmt::Display for Difficulty {
    /// Writes the lowercase variant name, e.g. `medium`.
    ///
    /// Formatting flags are honoured just as for a `&str`: width, fill and
    /// alignment pad the name, and precision truncates it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Difficulty::Trivial => "trivial",
            Difficulty::Easy => "easy",
            Difficulty::Medium => "medium",
            Difficulty::Hard => "hard",
            Difficulty::Complex => "complex",
        };
        // `pad` (unlike `write!`) applies width/fill/alignment/precision.
        f.pad(name)
    }
}
/// One unit of work within a [`Plan`].
#[derive(Debug, Clone, PartialEq)]
pub struct Slice {
    /// Identifier for this slice.
    pub id: String,
    /// Human-readable description of the work.
    pub description: String,
    /// Presumably the paths this slice is expected to modify
    /// (NOTE(review): confirm with the planner).
    pub write_set: Vec<String>,
    /// Estimated difficulty of the slice.
    pub estimated_difficulty: Difficulty,
}
/// Complexity assessment for a goal.
#[derive(Debug, Clone, PartialEq)]
pub struct Complexity {
    /// Numeric complexity score; the scale is not defined here
    /// (NOTE(review): confirm the expected range with the producer).
    pub score: u8,
    /// Explanation of the score.
    pub reasoning: String,
    /// Estimated effort in hours, or `None` when no estimate was given.
    pub estimated_hours: Option<f32>,
}
/// A complete plan for accomplishing a goal: its classification, complexity,
/// the slices of work, and how they relate.
#[derive(Clone, Debug)]
pub struct Plan {
    /// The goal text the plan was built for.
    pub goal_text: String,
    /// Category of the goal.
    pub kind: GoalKind,
    /// Complexity assessment of the goal.
    pub complexity: Complexity,
    /// Units of work making up the plan.
    pub slices: Vec<Slice>,
    /// Presumably pairs of indices into `slices`; which side depends on which
    /// is not established here (NOTE(review): confirm with the planner).
    pub dependencies: Vec<(usize, usize)>,
    /// Conditions the finished work is expected to satisfy.
    pub acceptance_criteria: Vec<String>,
    /// Estimated token count for executing the plan.
    pub estimated_tokens: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_budget_remaining() {
        let mut budget = TokenBudget::new(1000);
        assert_eq!(budget.remaining(), 1000);
        budget.record_usage(300);
        assert_eq!(budget.remaining(), 700);
    }

    #[test]
    fn test_token_budget_can_afford() {
        let mut budget = TokenBudget::new(100);
        budget.record_usage(80);
        assert!(budget.can_afford(20));
        assert!(!budget.can_afford(21));
    }

    /// Usage past `usize::MAX` must clamp rather than wrap around (which would
    /// make an exhausted budget look almost empty again).
    #[test]
    fn test_token_budget_saturates_instead_of_overflowing() {
        let mut budget = TokenBudget::new(100);
        budget.record_usage(usize::MAX);
        budget.record_usage(1);
        assert_eq!(budget.used_tokens(), usize::MAX);
        assert_eq!(budget.remaining(), 0);
        assert!(!budget.can_afford(1));
        assert_eq!(budget.max_tokens(), 100);
    }

    #[test]
    fn test_llm_response_to_usage() {
        let resp = LlmResponse {
            content: "hello".to_string(),
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            model: "gpt-4".to_string(),
            finish_reason: "stop".to_string(),
        };
        let estimator = CostEstimator::new();
        let usage = resp.to_usage(&estimator);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
        assert!(usage.estimated_usd > 0.0);
    }

    #[test]
    fn test_goal_kind_display() {
        let cases = [
            (GoalKind::Greenfield, "greenfield"),
            (GoalKind::Rewrite, "rewrite"),
            (GoalKind::Repair, "repair"),
            (GoalKind::Audit, "audit"),
            (GoalKind::Migration, "migration"),
            (GoalKind::Vague, "vague"),
        ];
        for (kind, expected) in &cases {
            assert_eq!(kind.to_string(), *expected);
        }
    }

    #[test]
    fn test_difficulty_display() {
        let cases = [
            (Difficulty::Trivial, "trivial"),
            (Difficulty::Easy, "easy"),
            (Difficulty::Medium, "medium"),
            (Difficulty::Hard, "hard"),
            (Difficulty::Complex, "complex"),
        ];
        for (difficulty, expected) in &cases {
            assert_eq!(difficulty.to_string(), *expected);
        }
    }
}