use serde::{Deserialize, Serialize};
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
#[serde(rename = "stop")]
Stop,
#[serde(rename = "length")]
Length,
#[serde(rename = "tool_use")]
ToolUse,
#[serde(rename = "error")]
Error,
#[serde(rename = "aborted")]
Aborted,
}
impl StopReason {
pub fn as_str(self) -> &'static str {
match self {
Self::Stop => "stop",
Self::Length => "length",
Self::ToolUse => "tool_use",
Self::Error => "error",
Self::Aborted => "aborted",
}
}
}
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct Usage {
pub input_tokens: u32,
pub output_tokens: u32,
#[serde(default)]
pub cache_read_tokens: u32,
#[serde(default)]
pub cache_write_tokens: u32,
}
impl Usage {
pub fn total_tokens(&self) -> u64 {
self.input_tokens as u64
+ self.output_tokens as u64
+ self.cache_read_tokens as u64
+ self.cache_write_tokens as u64
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CumulativeUsage {
input_tokens: u64,
output_tokens: u64,
cache_read_tokens: u64,
cache_write_tokens: u64,
turns: u32,
}
impl CumulativeUsage {
pub fn from_totals(
input_tokens: u64,
output_tokens: u64,
cache_read_tokens: u64,
cache_write_tokens: u64,
turns: u32,
) -> Self {
Self {
input_tokens,
output_tokens,
cache_read_tokens,
cache_write_tokens,
turns,
}
}
pub fn total_input_tokens(&self) -> u64 {
self.input_tokens
}
pub fn total_output_tokens(&self) -> u64 {
self.output_tokens
}
pub fn total_cache_read_tokens(&self) -> u64 {
self.cache_read_tokens
}
pub fn total_cache_write_tokens(&self) -> u64 {
self.cache_write_tokens
}
pub fn turn_count(&self) -> u32 {
self.turns
}
pub fn accumulate(&mut self, turn: &Usage) {
self.input_tokens += turn.input_tokens as u64;
self.output_tokens += turn.output_tokens as u64;
self.cache_read_tokens += turn.cache_read_tokens as u64;
self.cache_write_tokens += turn.cache_write_tokens as u64;
self.turns += 1;
}
pub fn as_usage(&self) -> Usage {
Usage {
input_tokens: self.input_tokens as u32,
output_tokens: self.output_tokens as u32,
cache_read_tokens: self.cache_read_tokens as u32,
cache_write_tokens: self.cache_write_tokens as u32,
}
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Pricing {
pub input_cost_per_mtok: f64,
pub output_cost_per_mtok: f64,
pub cache_read_cost_per_mtok: f64,
pub cache_write_cost_per_mtok: f64,
}
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct CostBreakdown {
pub input_cost: f64,
pub output_cost: f64,
pub cache_read_cost: f64,
pub cache_write_cost: f64,
}
impl CostBreakdown {
pub fn total_cost(&self) -> f64 {
self.input_cost + self.output_cost + self.cache_read_cost + self.cache_write_cost
}
}
pub fn calculate_cost(usage: &Usage, pricing: &Pricing) -> CostBreakdown {
let per_tok = |cost_per_mtok: f64| cost_per_mtok / 1_000_000.0;
CostBreakdown {
input_cost: usage.input_tokens as f64 * per_tok(pricing.input_cost_per_mtok),
output_cost: usage.output_tokens as f64 * per_tok(pricing.output_cost_per_mtok),
cache_read_cost: usage.cache_read_tokens as f64 * per_tok(pricing.cache_read_cost_per_mtok),
cache_write_cost: usage.cache_write_tokens as f64
* per_tok(pricing.cache_write_cost_per_mtok),
}
}
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum AssistantStreamEvent {
#[serde(rename = "start")]
Start {
partial: crate::message::AssistantMessage,
},
#[serde(rename = "text_start")]
TextStart {
content_index: usize,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "text_delta")]
TextDelta {
content_index: usize,
delta: String,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "text_end")]
TextEnd {
content_index: usize,
content: String,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "thinking_start")]
ThinkingStart {
content_index: usize,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "thinking_delta")]
ThinkingDelta {
content_index: usize,
delta: String,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "thinking_end")]
ThinkingEnd {
content_index: usize,
content: String,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "tool_call_start")]
ToolCallStart {
content_index: usize,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "tool_call_delta")]
ToolCallDelta {
content_index: usize,
delta: String,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "tool_call_end")]
ToolCallEnd {
content_index: usize,
tool_call: crate::message::ToolCall,
partial: crate::message::AssistantMessage,
},
#[serde(rename = "done")]
Done {
reason: StopReason,
message: crate::message::AssistantMessage,
},
#[serde(rename = "error")]
Error {
reason: StopReason,
message: crate::message::AssistantMessage,
},
}
impl AssistantStreamEvent {
pub fn is_terminal(&self) -> bool {
matches!(self, Self::Done { .. } | Self::Error { .. })
}
}