use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::{from_reader, to_writer_pretty};
use crate::Error;
use crate::chat::config::ChatConfig;
use crate::error::Result;
use crate::types::{
CacheControlEphemeral, ContentBlock, MessageCreateTemplate, MessageParam, MessageParamContent,
MessageRole, Model, SystemPrompt, TextBlock, Usage,
};
use crate::{Agent, Anthropic, Budget, Renderer, ThinkingConfig, TurnOutcome};
/// Upper bound on cache-control markers for a single request. Message marking
/// in `apply_cache_control_to_messages` uses one fewer than this value.
// NOTE(review): presumably mirrors the API's cache-breakpoint limit, with the
// spare slot left for the system prompt marker — confirm.
const MAX_CACHE_BREAKPOINTS: usize = 4;
/// Margin, in micro-cents, added to the previous turn's cost when
/// `budget_allows_next_turn` decides whether another turn fits the budget.
const BUDGET_BUFFER_MICRO_CENTS: u64 = 1;
/// An [`Agent`] that owns a [`ChatConfig`] and exposes it to the chat session.
pub trait ChatAgent: Agent {
    /// Borrows the agent's chat configuration.
    fn config(&self) -> &ChatConfig;

    /// Mutably borrows the agent's chat configuration.
    fn config_mut(&mut self) -> &mut ChatConfig;
}
/// Agent whose request parameters are read from a stored [`ChatConfig`].
pub struct ConfigAgent {
    /// Configuration consulted by every [`Agent`] accessor.
    config: ChatConfig,
}
impl ConfigAgent {
pub fn new(config: ChatConfig) -> Self {
Self { config }
}
}
#[async_trait::async_trait]
impl Agent for ConfigAgent {
    /// Maximum number of tokens to request, as reported by the configuration.
    async fn max_tokens(&self) -> u32 {
        self.config.max_tokens()
    }

    /// Model to use, as reported by the configuration.
    async fn model(&self) -> Model {
        self.config.model()
    }

    /// Configured stop sequences, or `None` when the list is empty.
    async fn stop_sequences(&self) -> Option<Vec<String>> {
        let configured = self.config.stop_sequences();
        (!configured.is_empty()).then(|| configured.to_vec())
    }

    /// System prompt taken from the template, or `None` when none is set.
    ///
    /// With caching disabled the prompt is returned as a plain clone. With
    /// caching enabled it is rebuilt as a list of text blocks and the last
    /// block (if any) receives an ephemeral cache-control marker.
    async fn system(&self) -> Option<SystemPrompt> {
        let configured = self.config.template.system.as_ref()?;
        if !self.config.caching_enabled {
            return Some(configured.clone());
        }

        let mut text_blocks = match configured {
            SystemPrompt::String(text) => vec![TextBlock::new(text.clone())],
            SystemPrompt::Blocks(existing) => {
                existing.iter().map(|entry| entry.block.clone()).collect()
            }
        };
        // Only the final block is marked; earlier blocks keep whatever
        // cache-control value they already carried.
        if let Some(tail) = text_blocks.last_mut() {
            tail.cache_control = Some(CacheControlEphemeral::new());
        }
        Some(SystemPrompt::from_blocks(text_blocks))
    }

    /// Sampling temperature from the template, if set.
    async fn temperature(&self) -> Option<f32> {
        self.config.template.temperature
    }

    /// Thinking configuration from the template, if set.
    async fn thinking(&self) -> Option<ThinkingConfig> {
        self.config.template.thinking
    }

    /// Top-k sampling parameter from the template, if set.
    async fn top_k(&self) -> Option<u32> {
        self.config.template.top_k
    }

    /// Top-p sampling parameter from the template, if set.
    async fn top_p(&self) -> Option<f32> {
        self.config.template.top_p
    }
}
impl ChatAgent for ConfigAgent {
    /// Borrows the wrapped configuration.
    fn config(&self) -> &ChatConfig {
        let Self { config } = self;
        config
    }

    /// Mutably borrows the wrapped configuration.
    fn config_mut(&mut self) -> &mut ChatConfig {
        let Self { config } = self;
        config
    }
}
/// Conversation state for one chat: the client, the agent that drives each
/// turn, the message history, and running usage counters.
pub struct ChatSession<A: ChatAgent> {
    /// API client handed to the agent on every turn.
    client: Anthropic,
    /// Agent that executes turns and owns the chat configuration.
    agent: A,
    /// Conversation history sent with each request.
    messages: Vec<MessageParam>,
    /// Usage accumulated over all successful turns.
    usage_totals: Usage,
    /// Usage of the most recent successful turn, if there has been one.
    last_turn_usage: Option<Usage>,
    /// Number of requests accumulated over all successful turns.
    request_count: u64,
    /// Budget passed to the agent for each turn; separate from the optional
    /// session budget stored in the configuration.
    budget: Arc<Budget>,
}
/// Snapshot of a session's configuration and usage counters, as produced by
/// `ChatSession::stats`.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Model reported by the configuration.
    pub model: Model,
    /// Number of messages currently in the history.
    pub message_count: usize,
    /// Maximum tokens per request reported by the configuration.
    pub max_tokens: u32,
    /// System prompt text, if one is configured.
    pub system_prompt: Option<String>,
    /// Sampling temperature from the template, if set.
    pub temperature: Option<f32>,
    /// Top-p sampling parameter from the template, if set.
    pub top_p: Option<f32>,
    /// Top-k sampling parameter from the template, if set.
    pub top_k: Option<u32>,
    /// Stop sequences from the template; empty when none are set.
    pub stop_sequences: Vec<String>,
    /// Thinking budget reported by the configuration, if any.
    pub thinking_budget: Option<u32>,
    /// Total session budget, if one is configured. Despite the field name,
    /// `ChatSession::stats` fills this from the budget's micro-cent total.
    pub session_budget_tokens: Option<u64>,
    /// Portion of the session budget already spent (total minus remaining,
    /// both in micro-cents); zero when no session budget is configured.
    pub budget_spent_tokens: u64,
    /// Path the transcript is auto-saved to, if configured.
    pub transcript_path: Option<PathBuf>,
    /// Input tokens accumulated over the session.
    pub total_input_tokens: u64,
    /// Output tokens accumulated over the session.
    pub total_output_tokens: u64,
    /// Requests accumulated over the session.
    pub total_requests: u64,
    /// Input tokens of the most recent successful turn, if any.
    pub last_turn_input_tokens: Option<u64>,
    /// Output tokens of the most recent successful turn, if any.
    pub last_turn_output_tokens: Option<u64>,
    /// Whether prompt caching is enabled in the configuration.
    pub caching_enabled: bool,
    /// Cache-creation input tokens accumulated over the session.
    pub total_cache_creation_tokens: u64,
    /// Cache-read input tokens accumulated over the session.
    pub total_cache_read_tokens: u64,
}
impl ChatSession<ConfigAgent> {
    /// Creates a session whose agent is a [`ConfigAgent`] built from `config`.
    pub fn new(client: Anthropic, config: ChatConfig) -> Self {
        let agent = ConfigAgent::new(config);
        Self::with_agent(client, agent)
    }
}
impl<A: ChatAgent> ChatSession<A> {
pub fn with_agent(client: Anthropic, agent: A) -> Self {
let budget = Arc::new(Budget::new_flat_rate(u64::MAX, 1));
Self {
client,
agent,
messages: Vec::new(),
usage_totals: Usage::new(0, 0),
last_turn_usage: None,
request_count: 0,
budget,
}
}
pub async fn send_message(
&mut self,
message: MessageParam,
renderer: &mut dyn Renderer,
) -> Result<()> {
let context = ();
if let Some(budget) = self.agent.config().session_budget.as_ref()
&& !budget_allows_next_turn(budget, self.last_turn_usage.as_ref())
{
renderer.print_error(
&context,
"Session budget exhausted. Use /budget to increase or clear the limit.",
);
return Err(Error::bad_request(
"session budget exhausted",
Some("budget".to_string()),
));
}
let previous_len = self.messages.len();
self.messages.push(message);
if self.agent.config().caching_enabled {
apply_cache_control_to_messages(&mut self.messages);
}
let outcome = self
.agent
.take_turn_streaming_root(&self.client, &mut self.messages, &self.budget, renderer)
.await;
match outcome {
Ok(outcome) => {
self.record_usage(outcome);
self.auto_save_transcript()?;
Ok(())
}
Err(err) => {
self.messages.truncate(previous_len);
Err(err)
}
}
}
pub fn clear(&mut self) {
self.messages.clear();
}
pub fn message_count(&self) -> usize {
self.messages.len()
}
pub fn config(&self) -> &ChatConfig {
self.agent.config()
}
pub fn config_mut(&mut self) -> &mut ChatConfig {
self.agent.config_mut()
}
pub fn template(&self) -> &MessageCreateTemplate {
&self.agent.config().template
}
pub fn template_mut(&mut self) -> &mut MessageCreateTemplate {
&mut self.agent.config_mut().template
}
pub fn save_transcript_to<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let transcript = TranscriptFile::new(&self.messages);
let file = File::create(path.as_ref())
.map_err(|err| Error::io("failed to create transcript file", err))?;
let writer = BufWriter::new(file);
to_writer_pretty(writer, &transcript).map_err(|err| {
Error::serialization("failed to serialize transcript", Some(Box::new(err)))
})
}
pub fn load_transcript_from<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
let file = File::open(path.as_ref())
.map_err(|err| Error::io("failed to open transcript file", err))?;
let reader = BufReader::new(file);
let transcript: TranscriptFile = from_reader(reader).map_err(|err| {
Error::serialization("failed to parse transcript", Some(Box::new(err)))
})?;
self.messages = transcript.messages;
Ok(())
}
pub fn stats(&self) -> SessionStats {
let config = self.agent.config();
let (session_budget_tokens, budget_spent_tokens) = match config.session_budget.as_ref() {
Some(budget) => {
let total = budget.total_micro_cents();
let remaining = budget.remaining_micro_cents();
(Some(total), total.saturating_sub(remaining))
}
None => (None, 0),
};
SessionStats {
model: config.model(),
message_count: self.message_count(),
max_tokens: config.max_tokens(),
system_prompt: config.system_prompt_text().map(str::to_string),
temperature: config.template.temperature,
top_p: config.template.top_p,
top_k: config.template.top_k,
stop_sequences: config.template.stop_sequences.clone().unwrap_or_default(),
thinking_budget: config.thinking_budget(),
session_budget_tokens,
budget_spent_tokens,
transcript_path: config.transcript_path.clone(),
total_input_tokens: tokens_to_u64(self.usage_totals.input_tokens),
total_output_tokens: tokens_to_u64(self.usage_totals.output_tokens),
total_requests: self.request_count,
last_turn_input_tokens: self
.last_turn_usage
.map(|usage| tokens_to_u64(usage.input_tokens)),
last_turn_output_tokens: self
.last_turn_usage
.map(|usage| tokens_to_u64(usage.output_tokens)),
caching_enabled: config.caching_enabled,
total_cache_creation_tokens: self
.usage_totals
.cache_creation_input_tokens
.map(|t| t.max(0) as u64)
.unwrap_or(0),
total_cache_read_tokens: self
.usage_totals
.cache_read_input_tokens
.map(|t| t.max(0) as u64)
.unwrap_or(0),
}
}
fn record_usage(&mut self, outcome: TurnOutcome) {
self.last_turn_usage = Some(outcome.usage);
self.usage_totals = self.usage_totals + outcome.usage;
self.request_count = self.request_count.saturating_add(outcome.request_count);
if let Some(budget) = self.agent.config().session_budget.as_ref() {
budget.consume_usage_saturating(&outcome.usage);
}
}
fn auto_save_transcript(&self) -> Result<()> {
if let Some(path) = &self.agent.config().transcript_path {
self.save_transcript_to(path)
} else {
Ok(())
}
}
}
/// On-disk representation of a saved conversation.
#[derive(Serialize, Deserialize)]
struct TranscriptFile {
    /// Format version written alongside the messages.
    version: u8,
    /// Conversation history in order.
    messages: Vec<MessageParam>,
}
impl TranscriptFile {
fn new(messages: &[MessageParam]) -> Self {
Self {
version: 1,
messages: messages.to_vec(),
}
}
}
/// Converts a signed token count to `u64`, mapping negative values to zero.
fn tokens_to_u64(value: i32) -> u64 {
    // The conversion only fails for negative inputs, which count as zero.
    u64::try_from(value).unwrap_or(0)
}
/// Decides whether `budget` leaves room for another turn.
///
/// Returns `false` when nothing remains. With no previous turn to estimate
/// from, any non-zero remainder is accepted. Otherwise the previous turn's
/// cost plus `BUDGET_BUFFER_MICRO_CENTS` must be strictly below the remainder.
fn budget_allows_next_turn(budget: &Budget, last_turn_usage: Option<&Usage>) -> bool {
    let remaining = budget.remaining_micro_cents();
    match last_turn_usage {
        _ if remaining == 0 => false,
        None => true,
        Some(usage) => {
            // Use the last turn as the estimate for the next one.
            let projected = budget
                .calculate_cost(usage)
                .saturating_add(BUDGET_BUFFER_MICRO_CENTS);
            projected < remaining
        }
    }
}
fn apply_cache_control_to_messages(messages: &mut [MessageParam]) {
for msg in messages.iter_mut() {
clear_cache_control_from_message(msg);
}
let user_indices: Vec<usize> = messages
.iter()
.enumerate()
.filter(|(_, msg)| msg.role == MessageRole::User)
.map(|(idx, _)| idx)
.rev()
.take(MAX_CACHE_BREAKPOINTS - 1) .collect();
for idx in user_indices {
apply_cache_control_to_message(&mut messages[idx]);
}
}
fn clear_cache_control_from_message(message: &mut MessageParam) {
if let MessageParamContent::Array(blocks) = &mut message.content {
for block in blocks.iter_mut() {
clear_cache_control_on_block(block);
}
}
}
/// Resets the cache-control marker of `block` to `None`.
/// Thinking and redacted-thinking blocks are left untouched.
fn clear_cache_control_on_block(block: &mut ContentBlock) {
    match block {
        ContentBlock::Text(inner) => inner.cache_control = None,
        ContentBlock::ToolResult(inner) => inner.cache_control = None,
        ContentBlock::ToolUse(inner) => inner.cache_control = None,
        ContentBlock::Image(inner) => inner.cache_control = None,
        ContentBlock::Document(inner) => inner.cache_control = None,
        ContentBlock::ServerToolUse(inner) => inner.cache_control = None,
        ContentBlock::WebSearchToolResult(inner) => inner.cache_control = None,
        ContentBlock::Thinking(_) | ContentBlock::RedactedThinking(_) => {}
    }
}
/// Marks the end of `message` with an ephemeral cache-control marker.
///
/// Block content has its last block handed to `set_cache_control_on_block`
/// (an empty block list is left unchanged). Plain string content is replaced
/// by a single marked text block holding the same text.
fn apply_cache_control_to_message(message: &mut MessageParam) {
    let marked_text = match &mut message.content {
        MessageParamContent::Array(blocks) => {
            if let Some(tail) = blocks.last_mut() {
                set_cache_control_on_block(tail);
            }
            return;
        }
        MessageParamContent::String(text) => {
            TextBlock::new(text.clone()).with_cache_control(CacheControlEphemeral::new())
        }
    };
    message.content = MessageParamContent::Array(vec![ContentBlock::Text(marked_text)]);
}
/// Sets an ephemeral cache-control marker on `block` when it is a text,
/// tool-result, or tool-use block; every other block kind is left unchanged.
// NOTE(review): image, document, server-tool-use and web-search-result blocks
// are skipped here although `clear_cache_control_on_block` resets a marker on
// them — confirm this asymmetry is intended.
fn set_cache_control_on_block(block: &mut ContentBlock) {
    match block {
        ContentBlock::Text(inner) => {
            inner.cache_control = Some(CacheControlEphemeral::new());
        }
        ContentBlock::ToolResult(inner) => {
            inner.cache_control = Some(CacheControlEphemeral::new());
        }
        ContentBlock::ToolUse(inner) => {
            inner.cache_control = Some(CacheControlEphemeral::new());
        }
        ContentBlock::Image(_)
        | ContentBlock::Document(_)
        | ContentBlock::ServerToolUse(_)
        | ContentBlock::WebSearchToolResult(_)
        | ContentBlock::Thinking(_)
        | ContentBlock::RedactedThinking(_) => {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{KnownModel, SystemPrompt};

    /// Builds a message with plain string content.
    fn text_message(role: MessageRole, text: &str) -> MessageParam {
        MessageParam {
            role,
            content: MessageParamContent::String(text.to_string()),
        }
    }

    /// True when the message has block content whose last block is a text
    /// block carrying a cache-control marker.
    fn ends_with_cached_text(message: &MessageParam) -> bool {
        matches!(
            &message.content,
            MessageParamContent::Array(blocks)
                if blocks.last().is_some_and(|block| {
                    matches!(block, ContentBlock::Text(text) if text.cache_control.is_some())
                })
        )
    }

    /// Per-message result of `ends_with_cached_text`, in history order.
    fn cached_flags(messages: &[MessageParam]) -> Vec<bool> {
        messages.iter().map(ends_with_cached_text).collect()
    }

    #[test]
    fn new_session_empty() {
        let client = Anthropic::new(None).unwrap();
        let session = ChatSession::new(client, ChatConfig::default());
        assert_eq!(session.message_count(), 0);
    }

    #[test]
    fn clear_session() {
        let client = Anthropic::new(None).unwrap();
        let mut session = ChatSession::new(client, ChatConfig::default());
        session
            .messages
            .push(text_message(MessageRole::User, "test"));
        assert_eq!(session.message_count(), 1);
        session.clear();
        assert_eq!(session.message_count(), 0);
    }

    #[test]
    fn template_updates_model() {
        let client = Anthropic::new(None).unwrap();
        let mut session = ChatSession::new(client, ChatConfig::default());
        assert_eq!(
            session.template().model,
            Some(Model::Known(KnownModel::ClaudeHaiku45))
        );
        session.template_mut().model = Some(Model::Known(KnownModel::ClaudeSonnet40));
        assert_eq!(
            session.template().model,
            Some(Model::Known(KnownModel::ClaudeSonnet40))
        );
    }

    #[test]
    fn template_updates_system_prompt() {
        let client = Anthropic::new(None).unwrap();
        let mut session = ChatSession::new(client, ChatConfig::default());
        assert!(session.template().system.is_none());
        session.template_mut().system = Some(SystemPrompt::from("Be helpful"));
        assert!(matches!(
            session.template().system,
            Some(SystemPrompt::String(ref text)) if text == "Be helpful"
        ));
        session.template_mut().system = None;
        assert!(session.template().system.is_none());
    }

    #[test]
    fn budget_allows_next_turn_without_usage() {
        let budget = Budget::new_with_rates(1000, 1, 1, 0, 0);
        assert!(budget_allows_next_turn(&budget, None));
    }

    #[test]
    fn budget_allows_next_turn_with_usage() {
        let budget = Budget::new_with_rates(1000, 1, 1, 0, 0);
        let usage = Usage::new(400, 0);
        assert!(budget_allows_next_turn(&budget, Some(&usage)));
    }

    #[test]
    fn budget_blocks_next_turn_when_over_grace() {
        let budget = Budget::new_with_rates(100, 1, 1, 0, 0);
        let usage = Usage::new(100, 0);
        assert!(!budget_allows_next_turn(&budget, Some(&usage)));
    }

    #[test]
    fn apply_cache_control_to_string_content() {
        let mut message = text_message(MessageRole::User, "hello");
        apply_cache_control_to_message(&mut message);

        let MessageParamContent::Array(blocks) = &message.content else {
            panic!("Expected Array content");
        };
        assert_eq!(blocks.len(), 1);
        let ContentBlock::Text(text_block) = &blocks[0] else {
            panic!("Expected Text block");
        };
        assert_eq!(text_block.text, "hello");
        assert!(text_block.cache_control.is_some());
    }

    #[test]
    fn apply_cache_control_to_array_content() {
        let mut message = MessageParam {
            role: MessageRole::User,
            content: MessageParamContent::Array(vec![
                ContentBlock::Text(TextBlock::new("first")),
                ContentBlock::Text(TextBlock::new("second")),
            ]),
        };
        apply_cache_control_to_message(&mut message);

        let MessageParamContent::Array(blocks) = &message.content else {
            panic!("Expected Array content");
        };
        assert_eq!(blocks.len(), 2);
        // Fail loudly if a block changed kind instead of skipping the check.
        let ContentBlock::Text(first) = &blocks[0] else {
            panic!("Expected first block to be Text");
        };
        let ContentBlock::Text(second) = &blocks[1] else {
            panic!("Expected second block to be Text");
        };
        assert!(first.cache_control.is_none());
        assert!(second.cache_control.is_some());
    }

    #[test]
    fn apply_cache_control_to_messages_selects_user_messages() {
        let mut messages = vec![
            text_message(MessageRole::User, "user1"),
            text_message(MessageRole::Assistant, "assistant1"),
            text_message(MessageRole::User, "user2"),
            text_message(MessageRole::Assistant, "assistant2"),
            text_message(MessageRole::User, "user3"),
        ];
        apply_cache_control_to_messages(&mut messages);

        assert_eq!(
            cached_flags(&messages),
            [true, false, true, false, true],
            "user messages should have cache_control and assistant messages should not"
        );
    }

    #[test]
    fn apply_cache_control_respects_max_breakpoints() {
        let mut messages: Vec<MessageParam> = (0..5)
            .map(|i| text_message(MessageRole::User, &format!("user{i}")))
            .collect();
        apply_cache_control_to_messages(&mut messages);

        assert_eq!(
            cached_flags(&messages),
            [false, false, true, true, true],
            "only the three newest user messages should have cache_control"
        );
    }

    #[test]
    fn apply_cache_control_clears_old_markers() {
        // The three oldest messages start out marked; the two newest do not.
        let mut messages: Vec<MessageParam> = (0..3)
            .map(|i| MessageParam {
                role: MessageRole::User,
                content: MessageParamContent::Array(vec![ContentBlock::Text(
                    TextBlock::new(format!("user{i}"))
                        .with_cache_control(CacheControlEphemeral::new()),
                )]),
            })
            .collect();
        messages.extend((3..5).map(|i| text_message(MessageRole::User, &format!("user{i}"))));
        apply_cache_control_to_messages(&mut messages);

        assert_eq!(
            cached_flags(&messages),
            [false, false, true, true, true],
            "markers on the two oldest messages should be cleared; only 3 messages keep one"
        );
    }
}