1use serde::{Deserialize, Serialize};
4
5#[non_exhaustive]
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum StopReason {
8 #[serde(rename = "stop")]
9 Stop,
10 #[serde(rename = "length")]
11 Length,
12 #[serde(rename = "tool_use")]
13 ToolUse,
14 #[serde(rename = "error")]
15 Error,
16 #[serde(rename = "aborted")]
17 Aborted,
18}
19
20impl StopReason {
21 pub fn as_str(self) -> &'static str {
22 match self {
23 Self::Stop => "stop",
24 Self::Length => "length",
25 Self::ToolUse => "tool_use",
26 Self::Error => "error",
27 Self::Aborted => "aborted",
28 }
29 }
30}
31
32#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
33pub struct Usage {
34 pub input_tokens: u32,
35 pub output_tokens: u32,
36 #[serde(default)]
37 pub cache_read_tokens: u32,
38 #[serde(default)]
39 pub cache_write_tokens: u32,
40}
41
42impl Usage {
43 pub fn total_tokens(&self) -> u64 {
44 self.input_tokens as u64
45 + self.output_tokens as u64
46 + self.cache_read_tokens as u64
47 + self.cache_write_tokens as u64
48 }
49}
50
51#[derive(Debug, Clone, Default, PartialEq)]
53pub struct CumulativeUsage {
54 input_tokens: u64,
55 output_tokens: u64,
56 cache_read_tokens: u64,
57 cache_write_tokens: u64,
58 turns: u32,
59}
60
61impl CumulativeUsage {
62 pub fn from_totals(
64 input_tokens: u64,
65 output_tokens: u64,
66 cache_read_tokens: u64,
67 cache_write_tokens: u64,
68 turns: u32,
69 ) -> Self {
70 Self {
71 input_tokens,
72 output_tokens,
73 cache_read_tokens,
74 cache_write_tokens,
75 turns,
76 }
77 }
78
79 pub fn total_input_tokens(&self) -> u64 {
80 self.input_tokens
81 }
82
83 pub fn total_output_tokens(&self) -> u64 {
84 self.output_tokens
85 }
86
87 pub fn total_cache_read_tokens(&self) -> u64 {
88 self.cache_read_tokens
89 }
90
91 pub fn total_cache_write_tokens(&self) -> u64 {
92 self.cache_write_tokens
93 }
94
95 pub fn turn_count(&self) -> u32 {
96 self.turns
97 }
98
99 pub fn accumulate(&mut self, turn: &Usage) {
100 self.input_tokens += turn.input_tokens as u64;
101 self.output_tokens += turn.output_tokens as u64;
102 self.cache_read_tokens += turn.cache_read_tokens as u64;
103 self.cache_write_tokens += turn.cache_write_tokens as u64;
104 self.turns += 1;
105 }
106
107 pub fn as_usage(&self) -> Usage {
108 Usage {
109 input_tokens: self.input_tokens as u32,
110 output_tokens: self.output_tokens as u32,
111 cache_read_tokens: self.cache_read_tokens as u32,
112 cache_write_tokens: self.cache_write_tokens as u32,
113 }
114 }
115}
116
/// Per-category token prices, each expressed as a cost per one million
/// tokens ("mtok").
///
/// The currency is whatever the caller uses; this type only stores the
/// rates. `Default` yields all-zero rates.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Pricing {
    /// Cost of one million input tokens.
    pub input_cost_per_mtok: f64,
    /// Cost of one million output tokens.
    pub output_cost_per_mtok: f64,
    /// Cost of one million cache-read tokens.
    pub cache_read_cost_per_mtok: f64,
    /// Cost of one million cache-write tokens.
    pub cache_write_cost_per_mtok: f64,
}
125
/// Cost of a usage record split by token category.
///
/// Values are in the same currency as the `Pricing` rates they were
/// computed from. `Default` yields an all-zero breakdown.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct CostBreakdown {
    /// Cost attributed to input tokens.
    pub input_cost: f64,
    /// Cost attributed to output tokens.
    pub output_cost: f64,
    /// Cost attributed to cache-read tokens.
    pub cache_read_cost: f64,
    /// Cost attributed to cache-write tokens.
    pub cache_write_cost: f64,
}

impl CostBreakdown {
    /// Sum of the four cost components.
    pub fn total_cost(&self) -> f64 {
        let Self {
            input_cost,
            output_cost,
            cache_read_cost,
            cache_write_cost,
        } = *self;
        // Added left to right so the floating-point result is stable.
        input_cost + output_cost + cache_read_cost + cache_write_cost
    }
}
140
141pub fn calculate_cost(usage: &Usage, pricing: &Pricing) -> CostBreakdown {
143 let per_tok = |cost_per_mtok: f64| cost_per_mtok / 1_000_000.0;
144 CostBreakdown {
145 input_cost: usage.input_tokens as f64 * per_tok(pricing.input_cost_per_mtok),
146 output_cost: usage.output_tokens as f64 * per_tok(pricing.output_cost_per_mtok),
147 cache_read_cost: usage.cache_read_tokens as f64 * per_tok(pricing.cache_read_cost_per_mtok),
148 cache_write_cost: usage.cache_write_tokens as f64
149 * per_tok(pricing.cache_write_cost_per_mtok),
150 }
151}
152
153#[non_exhaustive]
154#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
155#[serde(tag = "type")]
156pub enum AssistantStreamEvent {
157 #[serde(rename = "start")]
158 Start {
159 partial: crate::message::AssistantMessage,
160 },
161 #[serde(rename = "text_start")]
162 TextStart {
163 content_index: usize,
164 partial: crate::message::AssistantMessage,
165 },
166 #[serde(rename = "text_delta")]
167 TextDelta {
168 content_index: usize,
169 delta: String,
170 partial: crate::message::AssistantMessage,
171 },
172 #[serde(rename = "text_end")]
173 TextEnd {
174 content_index: usize,
175 content: String,
176 partial: crate::message::AssistantMessage,
177 },
178 #[serde(rename = "thinking_start")]
179 ThinkingStart {
180 content_index: usize,
181 partial: crate::message::AssistantMessage,
182 },
183 #[serde(rename = "thinking_delta")]
184 ThinkingDelta {
185 content_index: usize,
186 delta: String,
187 partial: crate::message::AssistantMessage,
188 },
189 #[serde(rename = "thinking_end")]
190 ThinkingEnd {
191 content_index: usize,
192 content: String,
193 partial: crate::message::AssistantMessage,
194 },
195 #[serde(rename = "tool_call_start")]
196 ToolCallStart {
197 content_index: usize,
198 partial: crate::message::AssistantMessage,
199 },
200 #[serde(rename = "tool_call_delta")]
201 ToolCallDelta {
202 content_index: usize,
203 delta: String,
204 partial: crate::message::AssistantMessage,
205 },
206 #[serde(rename = "tool_call_end")]
207 ToolCallEnd {
208 content_index: usize,
209 tool_call: crate::message::ToolCall,
210 partial: crate::message::AssistantMessage,
211 },
212 #[serde(rename = "done")]
213 Done {
214 reason: StopReason,
215 message: crate::message::AssistantMessage,
216 },
217 #[serde(rename = "error")]
218 Error {
219 reason: StopReason,
220 message: crate::message::AssistantMessage,
221 },
222}
223
224impl AssistantStreamEvent {
225 pub fn is_terminal(&self) -> bool {
226 matches!(self, Self::Done { .. } | Self::Error { .. })
227 }
228}