use crate::models::Usage;
use crate::snapshot::SnapshotRepo;
use std::path::Path;
use std::time::{Duration, Instant};
/// Per-turn bookkeeping: identity, step budget, recorded tool calls,
/// cancellation flag and accumulated token usage.
#[derive(Debug)]
pub struct TurnContext {
    /// Turn identifier; `TurnContext::new` fills it with a random v4 UUID string.
    pub id: String,
    /// Set to `Instant::now()` by `TurnContext::new`; only read by `elapsed()`.
    // dead_code allowed because `elapsed()` itself is marked unused in this file.
    // NOTE(review): presumably kept for future timing/telemetry — confirm or remove.
    #[allow(dead_code)]
    pub started_at: Instant,
    /// Current step number; starts at 0 and is incremented by `next_step()`.
    pub step: u32,
    /// Step budget compared against `step` by `next_step()` / `at_max_steps()`.
    pub max_steps: u32,
    /// Tool calls appended via `record_tool_call()`, in insertion order.
    pub tool_calls: Vec<TurnToolCall>,
    /// Set to `true` by `cancel()`; not read anywhere in this file.
    // NOTE(review): dead_code allowed because no reader is visible here — confirm external use.
    #[allow(dead_code)]
    pub cancelled: bool,
    /// Token usage accumulated across the turn via `add_usage()`.
    pub usage: Usage,
}
/// Record of a single tool invocation made during a turn.
#[derive(Debug, Clone)]
pub struct TurnToolCall {
    /// Call identifier. NOTE(review): presumably the provider-assigned tool-call id — confirm.
    pub id: String,
    /// Tool name.
    pub name: String,
    /// Arguments the tool was invoked with, as a JSON value.
    pub input: serde_json::Value,
    /// Output text; `None` until `set_result()` is called.
    pub result: Option<String>,
    /// Error text; `None` until `set_error()` is called.
    pub error: Option<String>,
    /// Duration passed to `set_result()` / `set_error()`; `None` until either is called.
    pub duration: Option<Duration>,
}
impl TurnContext {
    /// Creates a fresh context for one turn.
    ///
    /// The turn gets a random v4 UUID string as its `id`. It starts at step 0
    /// with no recorded tool calls, is not cancelled, and has zeroed input and
    /// output token counters. `max_steps` is the step budget checked by
    /// [`next_step`](Self::next_step) and [`at_max_steps`](Self::at_max_steps).
    pub fn new(max_steps: u32) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            started_at: Instant::now(),
            step: 0,
            max_steps,
            tool_calls: Vec::new(),
            cancelled: false,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                ..Usage::default()
            },
        }
    }

    /// Advances to the next step and reports whether it is within budget.
    ///
    /// Returns `true` while the new step number is `<= max_steps`. The
    /// counter saturates at `u32::MAX` instead of overflowing, which would
    /// otherwise panic in debug builds.
    pub fn next_step(&mut self) -> bool {
        self.step = self.step.saturating_add(1);
        self.step <= self.max_steps
    }

    /// Returns `true` once the current step has reached or passed `max_steps`.
    pub fn at_max_steps(&self) -> bool {
        self.step >= self.max_steps
    }

    /// Appends `call` to this turn's tool-call history.
    pub fn record_tool_call(&mut self, call: TurnToolCall) {
        self.tool_calls.push(call);
    }

    /// Marks the turn as cancelled by setting `cancelled` to `true`.
    // dead_code allowed: no caller is visible in this file.
    // NOTE(review): confirm external use or remove.
    #[allow(dead_code)]
    pub fn cancel(&mut self) {
        self.cancelled = true;
    }

    /// Time elapsed since `started_at`.
    // dead_code allowed: no caller is visible in this file.
    #[allow(dead_code)]
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Accumulates one usage report into the running total for this turn.
    ///
    /// Every counter saturates instead of overflowing, matching the
    /// `saturating_add` already used by `add_optional_usage`. The plain `+=`
    /// used previously panicked in debug builds and silently wrapped in
    /// release builds. An optional counter stays `None` only when both the
    /// running total and the incoming value are `None`.
    pub fn add_usage(&mut self, usage: &Usage) {
        self.usage.input_tokens = self.usage.input_tokens.saturating_add(usage.input_tokens);
        self.usage.output_tokens = self.usage.output_tokens.saturating_add(usage.output_tokens);
        self.usage.prompt_cache_hit_tokens = add_optional_usage(
            self.usage.prompt_cache_hit_tokens,
            usage.prompt_cache_hit_tokens,
        );
        self.usage.prompt_cache_miss_tokens = add_optional_usage(
            self.usage.prompt_cache_miss_tokens,
            usage.prompt_cache_miss_tokens,
        );
        self.usage.reasoning_tokens =
            add_optional_usage(self.usage.reasoning_tokens, usage.reasoning_tokens);
    }
}
/// Combines an optional running total with an optional increment.
///
/// When both values are present they are added with saturation. Otherwise
/// whichever side is present is returned, and `None` is returned only when
/// neither is.
fn add_optional_usage(total: Option<u32>, delta: Option<u32>) -> Option<u32> {
    if let (Some(running), Some(extra)) = (total, delta) {
        return Some(running.saturating_add(extra));
    }
    // At most one side is `Some` here, so `or` picks it (or yields `None`).
    total.or(delta)
}
/// Snapshots `workspace` under the label `pre-turn:<turn_seq>`.
///
/// Returns the snapshot id, or `None` if the snapshot could not be taken
/// (the failure is logged).
pub fn pre_turn_snapshot(workspace: &Path, turn_seq: u64) -> Option<String> {
    let label = format!("pre-turn:{turn_seq}");
    snapshot_with_label(workspace, &label)
}
/// Snapshots `workspace` under the label `tool:<call_id>`.
///
/// Returns the snapshot id, or `None` if the snapshot could not be taken
/// (the failure is logged). NOTE(review): `call_id` is presumably a
/// `TurnToolCall::id` — confirm at call sites.
pub fn pre_tool_snapshot(workspace: &Path, call_id: &str) -> Option<String> {
    let label = format!("tool:{call_id}");
    snapshot_with_label(workspace, &label)
}
/// Snapshots `workspace` under the label `post-turn:<turn_seq>`.
///
/// Returns the snapshot id, or `None` if the snapshot could not be taken
/// (the failure is logged).
pub fn post_turn_snapshot(workspace: &Path, turn_seq: u64) -> Option<String> {
    let label = format!("post-turn:{turn_seq}");
    snapshot_with_label(workspace, &label)
}
/// Opens (or initialises) the snapshot repo for `workspace` and records a
/// snapshot tagged with `label`.
///
/// This is best-effort: both repo-open and snapshot failures are logged at
/// `warn` level under the `snapshot` target and turned into `None` rather
/// than propagated.
fn snapshot_with_label(workspace: &Path, label: &str) -> Option<String> {
    let repo = match SnapshotRepo::open_or_init(workspace) {
        Ok(opened) => opened,
        Err(e) => {
            tracing::warn!(target: "snapshot", "snapshot repo init failed: {e}");
            return None;
        }
    };

    match repo.snapshot(label) {
        Ok(snapshot_id) => Some(snapshot_id.0),
        Err(e) => {
            tracing::warn!(target: "snapshot", "snapshot '{label}' failed: {e}");
            None
        }
    }
}
impl TurnToolCall {
    /// Creates a record for a tool call that has not finished yet.
    ///
    /// No result, error or duration is set.
    pub fn new(id: String, name: String, input: serde_json::Value) -> Self {
        Self {
            result: None,
            error: None,
            duration: None,
            id,
            name,
            input,
        }
    }

    /// Stores the call's output text along with how long it took.
    ///
    /// `error` is left untouched.
    pub fn set_result(&mut self, result: String, duration: Duration) {
        self.duration = Some(duration);
        self.result = Some(result);
    }

    /// Stores the call's error text along with how long it took.
    ///
    /// `result` is left untouched.
    pub fn set_error(&mut self, error: String, duration: Duration) {
        self.duration = Some(duration);
        self.error = Some(error);
    }
}