1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use super::ContentBlock;
8use super::citations::{Citation, SearchResultLocationCitation};
9
10#[derive(Clone, Debug, Default, Serialize, Deserialize)]
11pub struct TokenUsage {
12 pub input_tokens: u64,
13 pub output_tokens: u64,
14 #[serde(default)]
15 pub cache_read_input_tokens: u64,
16 #[serde(default)]
17 pub cache_creation_input_tokens: u64,
18}
19
20impl TokenUsage {
21 pub fn total(&self) -> u64 {
22 self.input_tokens + self.output_tokens
23 }
24
25 pub fn context_usage(&self) -> u64 {
26 self.input_tokens + self.cache_read_input_tokens + self.cache_creation_input_tokens
27 }
28
29 pub fn add(&mut self, other: &TokenUsage) {
30 self.input_tokens += other.input_tokens;
31 self.output_tokens += other.output_tokens;
32 self.cache_read_input_tokens += other.cache_read_input_tokens;
33 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
34 }
35
36 pub fn add_usage(&mut self, usage: &Usage) {
37 self.input_tokens += usage.input_tokens as u64;
38 self.output_tokens += usage.output_tokens as u64;
39 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0) as u64;
40 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0) as u64;
41 }
42
43 pub fn cache_hit_rate(&self) -> f64 {
44 if self.input_tokens == 0 {
45 return 0.0;
46 }
47 self.cache_read_input_tokens as f64 / self.input_tokens as f64
48 }
49}
50
51impl From<&Usage> for TokenUsage {
52 fn from(usage: &Usage) -> Self {
53 Self {
54 input_tokens: usage.input_tokens as u64,
55 output_tokens: usage.output_tokens as u64,
56 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0) as u64,
57 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0) as u64,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ApiResponse {
64 pub id: String,
65 #[serde(rename = "type")]
66 pub response_type: String,
67 pub role: String,
68 pub content: Vec<ContentBlock>,
69 pub model: String,
70 pub stop_reason: Option<StopReason>,
71 pub stop_sequence: Option<String>,
72 pub usage: Usage,
73 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub context_management: Option<ContextManagementResponse>,
75}
76
77#[derive(Debug, Clone, Default, Serialize, Deserialize)]
78pub struct ContextManagementResponse {
79 #[serde(default)]
80 pub applied_edits: Vec<AppliedEdit>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct AppliedEdit {
85 #[serde(rename = "type")]
86 pub edit_type: String,
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub cleared_tool_uses: Option<u32>,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
90 pub cleared_thinking_turns: Option<u32>,
91 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub cleared_input_tokens: Option<u64>,
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
96#[serde(rename_all = "snake_case")]
97pub enum StopReason {
98 EndTurn,
99 MaxTokens,
100 StopSequence,
101 ToolUse,
102}
103
104#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
109pub struct ServerToolUseUsage {
110 #[serde(default)]
112 pub web_search_requests: u32,
113 #[serde(default)]
115 pub web_fetch_requests: u32,
116}
117
118#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
119pub struct Usage {
120 pub input_tokens: u32,
121 pub output_tokens: u32,
122 #[serde(default)]
123 pub cache_read_input_tokens: Option<u32>,
124 #[serde(default)]
125 pub cache_creation_input_tokens: Option<u32>,
126 #[serde(default)]
128 pub server_tool_use: Option<ServerToolUseUsage>,
129}
130
131impl Usage {
132 pub fn total(&self) -> u32 {
133 self.input_tokens + self.output_tokens
134 }
135
136 pub fn context_usage(&self) -> u32 {
137 self.input_tokens
138 + self.cache_read_input_tokens.unwrap_or(0)
139 + self.cache_creation_input_tokens.unwrap_or(0)
140 }
141
142 pub fn estimated_cost(&self, model: &str) -> f64 {
143 crate::budget::pricing::global_pricing_table().calculate(model, self)
144 }
145
146 pub fn server_web_search_requests(&self) -> u32 {
148 self.server_tool_use
149 .as_ref()
150 .map(|s| s.web_search_requests)
151 .unwrap_or(0)
152 }
153
154 pub fn server_web_fetch_requests(&self) -> u32 {
156 self.server_tool_use
157 .as_ref()
158 .map(|s| s.web_fetch_requests)
159 .unwrap_or(0)
160 }
161
162 pub fn has_server_tool_use(&self) -> bool {
164 self.server_tool_use
165 .as_ref()
166 .map(|s| s.web_search_requests > 0 || s.web_fetch_requests > 0)
167 .unwrap_or(false)
168 }
169}
170
171impl ApiResponse {
172 pub fn text(&self) -> String {
173 self.content
174 .iter()
175 .filter_map(|block| block.as_text())
176 .collect::<Vec<_>>()
177 .join("")
178 }
179
180 pub fn wants_tool_use(&self) -> bool {
181 self.stop_reason == Some(StopReason::ToolUse)
182 }
183
184 pub fn tool_uses(&self) -> Vec<&super::ToolUseBlock> {
185 self.content
186 .iter()
187 .filter_map(|block| match block {
188 ContentBlock::ToolUse(tool_use) => Some(tool_use),
189 _ => None,
190 })
191 .collect()
192 }
193
194 pub fn thinking_blocks(&self) -> Vec<&super::ThinkingBlock> {
195 self.content
196 .iter()
197 .filter_map(|block| block.as_thinking())
198 .collect()
199 }
200
201 pub fn has_thinking(&self) -> bool {
202 self.content.iter().any(|block| block.is_thinking())
203 }
204
205 pub fn all_citations(&self) -> Vec<&Citation> {
206 self.content
207 .iter()
208 .filter_map(|block| block.citations())
209 .flatten()
210 .collect()
211 }
212
213 pub fn has_citations(&self) -> bool {
214 self.content.iter().any(|block| block.has_citations())
215 }
216
217 pub fn citations_by_document(&self) -> HashMap<usize, Vec<&Citation>> {
218 let mut map: HashMap<usize, Vec<&Citation>> = HashMap::new();
219 for citation in self.all_citations() {
220 if let Some(doc_idx) = citation.document_index() {
221 map.entry(doc_idx).or_default().push(citation);
222 }
223 }
224 map
225 }
226
227 pub fn search_citations(&self) -> Vec<&SearchResultLocationCitation> {
228 self.all_citations()
229 .into_iter()
230 .filter_map(|c| match c {
231 Citation::SearchResultLocation(src) => Some(src),
232 _ => None,
233 })
234 .collect()
235 }
236
237 pub fn applied_edits(&self) -> &[AppliedEdit] {
238 self.context_management
239 .as_ref()
240 .map(|cm| cm.applied_edits.as_slice())
241 .unwrap_or_default()
242 }
243
244 pub fn cleared_tokens(&self) -> u64 {
245 self.applied_edits()
246 .iter()
247 .filter_map(|e| e.cleared_input_tokens)
248 .sum()
249 }
250
251 pub fn server_tool_uses(&self) -> Vec<&super::ServerToolUseBlock> {
253 self.content
254 .iter()
255 .filter_map(|block| block.as_server_tool_use())
256 .collect()
257 }
258
259 pub fn has_server_tool_use(&self) -> bool {
261 self.content.iter().any(|block| block.is_server_tool_use())
262 }
263
264 pub fn server_web_search_requests(&self) -> u32 {
266 self.usage.server_web_search_requests()
267 }
268
269 pub fn server_web_fetch_requests(&self) -> u32 {
271 self.usage.server_web_fetch_requests()
272 }
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
276#[serde(tag = "type", rename_all = "snake_case")]
277pub enum StreamEvent {
278 MessageStart {
279 message: MessageStartData,
280 },
281 ContentBlockStart {
282 index: usize,
283 content_block: ContentBlock,
284 },
285 ContentBlockDelta {
286 index: usize,
287 delta: ContentDelta,
288 },
289 ContentBlockStop {
290 index: usize,
291 },
292 MessageDelta {
293 delta: MessageDeltaData,
294 usage: Usage,
295 },
296 MessageStop,
297 Ping,
298 Error {
299 error: StreamError,
300 },
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct MessageStartData {
305 pub id: String,
306 #[serde(rename = "type")]
307 pub message_type: String,
308 pub role: String,
309 pub model: String,
310 pub usage: Usage,
311}
312
313#[derive(Debug, Clone, Serialize, Deserialize)]
314#[serde(tag = "type", rename_all = "snake_case")]
315pub enum ContentDelta {
316 TextDelta { text: String },
317 InputJsonDelta { partial_json: String },
318 ThinkingDelta { thinking: String },
319 SignatureDelta { signature: String },
320 CitationsDelta { citation: Citation },
321}
322
323impl ContentDelta {
324 pub fn is_citation(&self) -> bool {
325 matches!(self, Self::CitationsDelta { .. })
326 }
327
328 pub fn as_citation(&self) -> Option<&Citation> {
329 match self {
330 Self::CitationsDelta { citation } => Some(citation),
331 _ => None,
332 }
333 }
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize)]
337pub struct MessageDeltaData {
338 pub stop_reason: Option<StopReason>,
339 pub stop_sequence: Option<String>,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct StreamError {
344 #[serde(rename = "type")]
345 pub error_type: String,
346 pub message: String,
347}
348
/// Outcome of an attempt to compact conversation history.
#[derive(Clone, Debug)]
pub enum CompactResult {
    /// Compaction was not required.
    NotNeeded,
    /// Compaction ran and replaced part of the history with a summary.
    Compacted {
        /// Item count before compaction (presumably messages — confirm with caller).
        original_count: usize,
        /// Item count after compaction.
        new_count: usize,
        /// Estimated tokens saved by compaction.
        saved_tokens: usize,
        /// Summary text that stands in for the compacted history.
        summary: String,
    },
    /// Compaction was deliberately not performed.
    Skipped {
        /// Why compaction was skipped.
        reason: String,
    },
}
362
363#[derive(Debug, Clone, Default, Serialize, Deserialize)]
369pub struct ModelUsage {
370 pub input_tokens: u32,
372 pub output_tokens: u32,
374 #[serde(default)]
376 pub cache_read_input_tokens: u32,
377 #[serde(default)]
379 pub cache_creation_input_tokens: u32,
380 #[serde(default)]
382 pub web_search_requests: u32,
383 #[serde(default)]
385 pub web_fetch_requests: u32,
386 #[serde(default)]
388 pub cost_usd: f64,
389 #[serde(default)]
391 pub context_window: u32,
392}
393
394impl ModelUsage {
395 pub fn from_usage(usage: &Usage, model: &str) -> Self {
396 let cost = usage.estimated_cost(model);
397 let context_window = crate::models::context_window::for_model(model);
398 let (web_search, web_fetch) = usage
399 .server_tool_use
400 .as_ref()
401 .map(|s| (s.web_search_requests, s.web_fetch_requests))
402 .unwrap_or((0, 0));
403 Self {
404 input_tokens: usage.input_tokens,
405 output_tokens: usage.output_tokens,
406 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
407 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
408 web_search_requests: web_search,
409 web_fetch_requests: web_fetch,
410 cost_usd: cost,
411 context_window: context_window as u32,
412 }
413 }
414
415 pub fn add(&mut self, other: &ModelUsage) {
416 self.input_tokens += other.input_tokens;
417 self.output_tokens += other.output_tokens;
418 self.cache_read_input_tokens += other.cache_read_input_tokens;
419 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
420 self.web_search_requests += other.web_search_requests;
421 self.web_fetch_requests += other.web_fetch_requests;
422 self.cost_usd += other.cost_usd;
423 }
424
425 pub fn add_usage(&mut self, usage: &Usage, model: &str) {
426 self.input_tokens += usage.input_tokens;
427 self.output_tokens += usage.output_tokens;
428 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0);
429 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
430 self.cost_usd += usage.estimated_cost(model);
431 if let Some(ref server_usage) = usage.server_tool_use {
432 self.web_search_requests += server_usage.web_search_requests;
433 self.web_fetch_requests += server_usage.web_fetch_requests;
434 }
435 }
436
437 pub fn total_tokens(&self) -> u32 {
438 self.input_tokens + self.output_tokens
439 }
440}
441
442#[derive(Debug, Clone, Default, Serialize, Deserialize)]
457pub struct ServerToolUse {
458 pub web_search_requests: u32,
460 pub web_fetch_requests: u32,
462}
463
464impl ServerToolUse {
465 pub fn record_web_search(&mut self) {
467 self.web_search_requests += 1;
468 }
469
470 pub fn record_web_fetch(&mut self) {
472 self.web_fetch_requests += 1;
473 }
474
475 pub fn has_usage(&self) -> bool {
477 self.web_search_requests > 0 || self.web_fetch_requests > 0
478 }
479
480 pub fn add_from_usage(&mut self, usage: &ServerToolUseUsage) {
482 self.web_search_requests += usage.web_search_requests;
483 self.web_fetch_requests += usage.web_fetch_requests;
484 }
485}
486
487#[derive(Debug, Clone, Serialize, Deserialize)]
492pub struct PermissionDenial {
493 pub tool_name: String,
495 pub tool_use_id: String,
497 pub tool_input: serde_json::Value,
499 #[serde(default, skip_serializing_if = "Option::is_none")]
501 pub reason: Option<String>,
502 #[serde(default, skip_serializing_if = "Option::is_none")]
504 pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
505}
506
507impl PermissionDenial {
508 pub fn new(
510 tool_name: impl Into<String>,
511 tool_use_id: impl Into<String>,
512 tool_input: serde_json::Value,
513 ) -> Self {
514 Self {
515 tool_name: tool_name.into(),
516 tool_use_id: tool_use_id.into(),
517 tool_input,
518 reason: None,
519 timestamp: Some(chrono::Utc::now()),
520 }
521 }
522
523 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
525 self.reason = Some(reason.into());
526 self
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Usage` that only carries input/output token counts.
    fn io_usage(input_tokens: u32, output_tokens: u32) -> Usage {
        Usage {
            input_tokens,
            output_tokens,
            ..Default::default()
        }
    }

    #[test]
    fn test_usage_total() {
        let usage = io_usage(100, 50);
        assert_eq!(usage.total(), 150);
    }

    #[test]
    fn test_usage_cost() {
        let usage = io_usage(1_000_000, 1_000_000);
        let cost = usage.estimated_cost("claude-sonnet-4-5");
        // One million tokens in each direction is expected to price at about $18.
        assert!((cost - 18.0).abs() < 0.01, "unexpected cost: {}", cost);
    }

    #[test]
    fn test_model_usage_from_usage() {
        let usage = Usage {
            cache_read_input_tokens: Some(100),
            cache_creation_input_tokens: Some(50),
            ..io_usage(1000, 500)
        };
        let per_model = ModelUsage::from_usage(&usage, "claude-sonnet-4-5");

        assert_eq!(per_model.input_tokens, 1000);
        assert_eq!(per_model.output_tokens, 500);
        assert_eq!(per_model.cache_read_input_tokens, 100);
        assert!(per_model.cost_usd > 0.0);
    }

    #[test]
    fn test_model_usage_add() {
        let record = |input_tokens: u32, output_tokens: u32, cost_usd: f64| ModelUsage {
            input_tokens,
            output_tokens,
            cost_usd,
            ..Default::default()
        };

        let mut running = record(100, 50, 0.01);
        running.add(&record(200, 100, 0.02));

        assert_eq!(running.input_tokens, 300);
        assert_eq!(running.output_tokens, 150);
        assert!((running.cost_usd - 0.03).abs() < 0.0001);
    }

    #[test]
    fn test_server_tool_use() {
        let mut tally = ServerToolUse::default();
        assert!(!tally.has_usage());

        tally.record_web_search();
        assert!(tally.has_usage());
        assert_eq!(tally.web_search_requests, 1);

        tally.record_web_fetch();
        assert_eq!(tally.web_fetch_requests, 1);
    }

    #[test]
    fn test_permission_denial() {
        let input = serde_json::json!({"query": "test"});
        let denial =
            PermissionDenial::new("WebSearch", "tool_123", input).with_reason("User denied");

        assert_eq!(denial.tool_name, "WebSearch");
        assert_eq!(denial.reason.as_deref(), Some("User denied"));
        assert!(denial.timestamp.is_some());
    }

    #[test]
    fn test_server_tool_use_usage_parsing() {
        let payload = r#"{
            "input_tokens": 1000,
            "output_tokens": 500,
            "server_tool_use": {
                "web_search_requests": 3,
                "web_fetch_requests": 2
            }
        }"#;
        let usage: Usage = serde_json::from_str(payload).expect("usage JSON should parse");

        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert!(usage.has_server_tool_use());
        assert_eq!(usage.server_web_search_requests(), 3);
        assert_eq!(usage.server_web_fetch_requests(), 2);
    }

    #[test]
    fn test_server_tool_use_usage_empty() {
        let payload = r#"{
            "input_tokens": 100,
            "output_tokens": 50
        }"#;
        let usage: Usage = serde_json::from_str(payload).expect("usage JSON should parse");

        assert!(!usage.has_server_tool_use());
        assert_eq!(usage.server_web_search_requests(), 0);
        assert_eq!(usage.server_web_fetch_requests(), 0);
    }

    #[test]
    fn test_server_tool_use_add_from_usage() {
        let batch = ServerToolUseUsage {
            web_search_requests: 2,
            web_fetch_requests: 1,
        };
        let mut tally = ServerToolUse::default();

        tally.add_from_usage(&batch);
        assert_eq!(tally.web_search_requests, 2);
        assert_eq!(tally.web_fetch_requests, 1);

        // A second identical batch must accumulate rather than overwrite.
        tally.add_from_usage(&batch);
        assert_eq!(tally.web_search_requests, 4);
        assert_eq!(tally.web_fetch_requests, 2);
    }
}