use crate::models::Usage;
use crate::snapshot::SnapshotRepo;
use std::path::Path;
use std::time::{Duration, Instant};
/// Per-turn bookkeeping: a unique id, a step budget, a tool-call counter,
/// a cancellation flag, and the token usage accumulated during the turn.
#[derive(Debug)]
pub struct TurnContext {
    /// Random v4 UUID (as a string) identifying this turn.
    pub id: String,
    /// Creation time of the turn; read by [`TurnContext::elapsed`].
    // NOTE(review): allowance presumably because `elapsed` has no callers yet — confirm.
    #[allow(dead_code)]
    pub started_at: Instant,
    /// Number of steps started so far; advanced by [`TurnContext::next_step`].
    pub step: u32,
    /// Maximum number of steps allowed for this turn.
    pub max_steps: u32,
    /// Number of tool calls recorded via [`TurnContext::record_tool_call`].
    tool_call_count: usize,
    /// Set to `true` by [`TurnContext::cancel`].
    // NOTE(review): no reader of this flag is visible in this module; the allowance keeps
    // it warning-free until one exists — confirm it is still needed.
    #[allow(dead_code)]
    pub cancelled: bool,
    /// Token usage summed across the turn via [`TurnContext::add_usage`].
    pub usage: Usage,
}

impl TurnContext {
    /// Creates a new turn with a fresh random id, `step` and tool-call count at 0,
    /// not cancelled, and input/output token counts at 0 (all other usage fields come
    /// from `Usage::default()`).
    pub fn new(max_steps: u32) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            started_at: Instant::now(),
            step: 0,
            max_steps,
            tool_call_count: 0,
            cancelled: false,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
                ..Usage::default()
            },
        }
    }

    /// Advances to the next step and returns whether it is still within the budget
    /// (`step <= max_steps` after the increment).
    ///
    /// The counter saturates at `u32::MAX`. A plain `+= 1` would panic in debug builds,
    /// or wrap to 0 in release builds and make an exhausted budget look fresh again.
    pub fn next_step(&mut self) -> bool {
        self.step = self.step.saturating_add(1);
        self.step <= self.max_steps
    }

    /// Returns `true` once `step` has reached (or passed) `max_steps`.
    pub fn at_max_steps(&self) -> bool {
        self.step >= self.max_steps
    }

    /// Records that one tool call was made during this turn.
    pub fn record_tool_call(&mut self) {
        self.tool_call_count += 1;
    }

    /// Returns `true` if at least one tool call has been recorded.
    pub fn has_tool_calls(&self) -> bool {
        self.tool_call_count > 0
    }

    /// Marks the turn as cancelled.
    // NOTE(review): allowance presumably because nothing calls this yet — confirm.
    #[allow(dead_code)]
    pub fn cancel(&mut self) {
        self.cancelled = true;
    }

    /// Time elapsed since the turn was created.
    // NOTE(review): allowance presumably because nothing calls this yet — confirm.
    #[allow(dead_code)]
    pub fn elapsed(&self) -> Duration {
        self.started_at.elapsed()
    }

    /// Adds `usage` into the turn's running total.
    ///
    /// The required counters saturate instead of overflowing, which matches how the
    /// optional counters are combined by `add_optional_usage`. An optional counter stays
    /// `None` only while no added usage has reported it.
    pub fn add_usage(&mut self, usage: &Usage) {
        self.usage.input_tokens = self.usage.input_tokens.saturating_add(usage.input_tokens);
        self.usage.output_tokens = self.usage.output_tokens.saturating_add(usage.output_tokens);
        self.usage.prompt_cache_hit_tokens = add_optional_usage(
            self.usage.prompt_cache_hit_tokens,
            usage.prompt_cache_hit_tokens,
        );
        self.usage.prompt_cache_miss_tokens = add_optional_usage(
            self.usage.prompt_cache_miss_tokens,
            usage.prompt_cache_miss_tokens,
        );
        self.usage.reasoning_tokens =
            add_optional_usage(self.usage.reasoning_tokens, usage.reasoning_tokens);
    }
}
/// Combines an optional running total with an optional increment.
///
/// If either side is present, the result is present. When both are present they are
/// summed with saturation at `u32::MAX`. When both are absent the result is `None`.
fn add_optional_usage(total: Option<u32>, delta: Option<u32>) -> Option<u32> {
    match delta {
        // An increment exists: start from it, or fold it into the existing total.
        Some(increment) => Some(total.map_or(increment, |sum| sum.saturating_add(increment))),
        // Nothing to add: keep whatever total we had (possibly None).
        None => total,
    }
}
/// Maximum number of characters of the user's prompt embedded in a snapshot label.
const USER_PROMPT_LABEL_MAX: usize = 100;

/// Builds a snapshot label of the form `"{prefix}:{turn_seq}"`.
///
/// When `user_prompt` has a non-blank line, the first such line is appended as
/// `": {excerpt}"`. The excerpt is trimmed and cut to at most
/// [`USER_PROMPT_LABEL_MAX`] characters; `…` is appended when it was cut.
/// An absent, empty, or whitespace-only prompt yields just `"{prefix}:{turn_seq}"`,
/// so the label never ends with a dangling `": "`.
fn format_snapshot_label(prefix: &str, turn_seq: u64, user_prompt: Option<&str>) -> String {
    let base = format!("{prefix}:{turn_seq}");
    // Skip leading blank lines so a prompt like "\nfix the bug" still yields an excerpt.
    let excerpt = user_prompt
        .and_then(|prompt| prompt.lines().map(str::trim).find(|line| !line.is_empty()));
    match excerpt {
        None => base,
        // `char_indices().nth(MAX)` gives the byte offset of the first character past the
        // limit (if any), so slicing there always lands on a UTF-8 char boundary.
        Some(line) => match line.char_indices().nth(USER_PROMPT_LABEL_MAX) {
            Some((cut, _)) => format!("{base}: {}…", &line[..cut]),
            None => format!("{base}: {line}"),
        },
    }
}
/// Takes a workspace snapshot before a turn runs.
///
/// The label is produced by `format_snapshot_label` with the `"pre-turn"` prefix,
/// `turn_seq`, and `user_prompt`. `cap_bytes` is passed through to the snapshot repo.
/// Returns the snapshot id, or `None` if snapshotting failed (the failure is logged).
pub fn pre_turn_snapshot(
    workspace: &Path,
    turn_seq: u64,
    cap_bytes: u64,
    user_prompt: Option<&str>,
) -> Option<String> {
    let label = format_snapshot_label("pre-turn", turn_seq, user_prompt);
    snapshot_with_label(workspace, &label, cap_bytes)
}
/// Takes a workspace snapshot labelled `tool:{call_id}` before a tool call executes.
///
/// Returns the snapshot id, or `None` if snapshotting failed (the failure is logged).
pub fn pre_tool_snapshot(workspace: &Path, call_id: &str, cap_bytes: u64) -> Option<String> {
    let label = format!("tool:{call_id}");
    snapshot_with_label(workspace, &label, cap_bytes)
}
/// Takes a workspace snapshot after a turn finishes.
///
/// The label is produced by `format_snapshot_label` with the `"post-turn"` prefix,
/// `turn_seq`, and `user_prompt`. `cap_bytes` is passed through to the snapshot repo.
/// Returns the snapshot id, or `None` if snapshotting failed (the failure is logged).
pub fn post_turn_snapshot(
    workspace: &Path,
    turn_seq: u64,
    cap_bytes: u64,
    user_prompt: Option<&str>,
) -> Option<String> {
    let label = format_snapshot_label("post-turn", turn_seq, user_prompt);
    snapshot_with_label(workspace, &label, cap_bytes)
}
/// Records a snapshot named `label` for `workspace`, then prunes old snapshots.
///
/// The repo is obtained via `SnapshotRepo::open_or_init_with_cap(workspace, cap_bytes)`.
/// After a successful snapshot, the repo is pruned to keep the last
/// `crate::snapshot::DEFAULT_MAX_SNAPSHOTS` entries.
///
/// Returns the new snapshot's id string. Returns `None` if the repo could not be opened
/// or the snapshot failed; both cases are logged at `warn` under the `snapshot` target.
/// A prune failure is logged too, but it does not change the returned id.
fn snapshot_with_label(workspace: &Path, label: &str, cap_bytes: u64) -> Option<String> {
    let repo = match SnapshotRepo::open_or_init_with_cap(workspace, cap_bytes) {
        Ok(opened) => opened,
        Err(e) => {
            tracing::warn!(target: "snapshot", "snapshot repo init failed: {e}");
            return None;
        }
    };

    let snapshot_id = match repo.snapshot(label) {
        Ok(id) => id.0,
        Err(e) => {
            tracing::warn!(target: "snapshot", "snapshot '{label}' failed: {e}");
            return None;
        }
    };

    // Pruning is best-effort: a failure is reported but the snapshot already exists.
    if let Err(e) = repo.prune_keep_last_n(crate::snapshot::DEFAULT_MAX_SNAPSHOTS) {
        tracing::warn!(target: "snapshot", "snapshot prune failed: {e}");
    }

    Some(snapshot_id)
}