1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use super::ContentBlock;
8use super::citations::{Citation, SearchResultLocationCitation};
9
10#[derive(Clone, Debug, Default, Serialize, Deserialize)]
11pub struct TokenUsage {
12 pub input_tokens: u64,
13 pub output_tokens: u64,
14 #[serde(default)]
15 pub cache_read_input_tokens: u64,
16 #[serde(default)]
17 pub cache_creation_input_tokens: u64,
18}
19
20impl TokenUsage {
21 pub fn total(&self) -> u64 {
22 self.input_tokens + self.output_tokens
23 }
24
25 pub fn add(&mut self, other: &TokenUsage) {
26 self.input_tokens += other.input_tokens;
27 self.output_tokens += other.output_tokens;
28 self.cache_read_input_tokens += other.cache_read_input_tokens;
29 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
30 }
31
32 pub fn add_usage(&mut self, usage: &Usage) {
33 self.input_tokens += usage.input_tokens as u64;
34 self.output_tokens += usage.output_tokens as u64;
35 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0) as u64;
36 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0) as u64;
37 }
38
39 pub fn cache_hit_rate(&self) -> f64 {
40 if self.input_tokens == 0 {
41 return 0.0;
42 }
43 self.cache_read_input_tokens as f64 / self.input_tokens as f64
44 }
45}
46
47impl From<&Usage> for TokenUsage {
48 fn from(usage: &Usage) -> Self {
49 Self {
50 input_tokens: usage.input_tokens as u64,
51 output_tokens: usage.output_tokens as u64,
52 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0) as u64,
53 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0) as u64,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct ApiResponse {
60 pub id: String,
61 #[serde(rename = "type")]
62 pub response_type: String,
63 pub role: String,
64 pub content: Vec<ContentBlock>,
65 pub model: String,
66 pub stop_reason: Option<StopReason>,
67 pub stop_sequence: Option<String>,
68 pub usage: Usage,
69 #[serde(default, skip_serializing_if = "Option::is_none")]
70 pub context_management: Option<ContextManagementResponse>,
71}
72
73#[derive(Debug, Clone, Default, Serialize, Deserialize)]
74pub struct ContextManagementResponse {
75 #[serde(default)]
76 pub applied_edits: Vec<AppliedEdit>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct AppliedEdit {
81 #[serde(rename = "type")]
82 pub edit_type: String,
83 #[serde(default, skip_serializing_if = "Option::is_none")]
84 pub cleared_tool_uses: Option<u32>,
85 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub cleared_thinking_turns: Option<u32>,
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub cleared_input_tokens: Option<u64>,
89}
90
91#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
92#[serde(rename_all = "snake_case")]
93pub enum StopReason {
94 EndTurn,
95 MaxTokens,
96 StopSequence,
97 ToolUse,
98}
99
100#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
105pub struct ServerToolUseUsage {
106 #[serde(default)]
108 pub web_search_requests: u32,
109 #[serde(default)]
111 pub web_fetch_requests: u32,
112}
113
114#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
115pub struct Usage {
116 pub input_tokens: u32,
117 pub output_tokens: u32,
118 #[serde(default)]
119 pub cache_read_input_tokens: Option<u32>,
120 #[serde(default)]
121 pub cache_creation_input_tokens: Option<u32>,
122 #[serde(default)]
124 pub server_tool_use: Option<ServerToolUseUsage>,
125}
126
127impl Usage {
128 pub fn total(&self) -> u32 {
129 self.input_tokens + self.output_tokens
130 }
131
132 pub fn estimated_cost(&self, model: &str) -> f64 {
133 crate::budget::pricing::global_pricing_table().calculate(model, self)
134 }
135
136 pub fn server_web_search_requests(&self) -> u32 {
138 self.server_tool_use
139 .as_ref()
140 .map(|s| s.web_search_requests)
141 .unwrap_or(0)
142 }
143
144 pub fn server_web_fetch_requests(&self) -> u32 {
146 self.server_tool_use
147 .as_ref()
148 .map(|s| s.web_fetch_requests)
149 .unwrap_or(0)
150 }
151
152 pub fn has_server_tool_use(&self) -> bool {
154 self.server_tool_use
155 .as_ref()
156 .map(|s| s.web_search_requests > 0 || s.web_fetch_requests > 0)
157 .unwrap_or(false)
158 }
159}
160
161impl ApiResponse {
162 pub fn text(&self) -> String {
163 self.content
164 .iter()
165 .filter_map(|block| block.as_text())
166 .collect::<Vec<_>>()
167 .join("")
168 }
169
170 pub fn wants_tool_use(&self) -> bool {
171 self.stop_reason == Some(StopReason::ToolUse)
172 }
173
174 pub fn tool_uses(&self) -> Vec<&super::ToolUseBlock> {
175 self.content
176 .iter()
177 .filter_map(|block| match block {
178 ContentBlock::ToolUse(tool_use) => Some(tool_use),
179 _ => None,
180 })
181 .collect()
182 }
183
184 pub fn thinking_blocks(&self) -> Vec<&super::ThinkingBlock> {
185 self.content
186 .iter()
187 .filter_map(|block| block.as_thinking())
188 .collect()
189 }
190
191 pub fn has_thinking(&self) -> bool {
192 self.content.iter().any(|block| block.is_thinking())
193 }
194
195 pub fn all_citations(&self) -> Vec<&Citation> {
196 self.content
197 .iter()
198 .filter_map(|block| block.citations())
199 .flatten()
200 .collect()
201 }
202
203 pub fn has_citations(&self) -> bool {
204 self.content.iter().any(|block| block.has_citations())
205 }
206
207 pub fn citations_by_document(&self) -> HashMap<usize, Vec<&Citation>> {
208 let mut map: HashMap<usize, Vec<&Citation>> = HashMap::new();
209 for citation in self.all_citations() {
210 if let Some(doc_idx) = citation.document_index() {
211 map.entry(doc_idx).or_default().push(citation);
212 }
213 }
214 map
215 }
216
217 pub fn search_citations(&self) -> Vec<&SearchResultLocationCitation> {
218 self.all_citations()
219 .into_iter()
220 .filter_map(|c| match c {
221 Citation::SearchResultLocation(src) => Some(src),
222 _ => None,
223 })
224 .collect()
225 }
226
227 pub fn applied_edits(&self) -> &[AppliedEdit] {
228 self.context_management
229 .as_ref()
230 .map(|cm| cm.applied_edits.as_slice())
231 .unwrap_or_default()
232 }
233
234 pub fn cleared_tokens(&self) -> u64 {
235 self.applied_edits()
236 .iter()
237 .filter_map(|e| e.cleared_input_tokens)
238 .sum()
239 }
240
241 pub fn server_tool_uses(&self) -> Vec<&super::ServerToolUseBlock> {
243 self.content
244 .iter()
245 .filter_map(|block| block.as_server_tool_use())
246 .collect()
247 }
248
249 pub fn has_server_tool_use(&self) -> bool {
251 self.content.iter().any(|block| block.is_server_tool_use())
252 }
253
254 pub fn server_web_search_requests(&self) -> u32 {
256 self.usage.server_web_search_requests()
257 }
258
259 pub fn server_web_fetch_requests(&self) -> u32 {
261 self.usage.server_web_fetch_requests()
262 }
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266#[serde(tag = "type", rename_all = "snake_case")]
267pub enum StreamEvent {
268 MessageStart {
269 message: MessageStartData,
270 },
271 ContentBlockStart {
272 index: usize,
273 content_block: ContentBlock,
274 },
275 ContentBlockDelta {
276 index: usize,
277 delta: ContentDelta,
278 },
279 ContentBlockStop {
280 index: usize,
281 },
282 MessageDelta {
283 delta: MessageDeltaData,
284 usage: Usage,
285 },
286 MessageStop,
287 Ping,
288 Error {
289 error: StreamError,
290 },
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct MessageStartData {
295 pub id: String,
296 #[serde(rename = "type")]
297 pub message_type: String,
298 pub role: String,
299 pub model: String,
300 pub usage: Usage,
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
304#[serde(tag = "type", rename_all = "snake_case")]
305pub enum ContentDelta {
306 TextDelta { text: String },
307 InputJsonDelta { partial_json: String },
308 ThinkingDelta { thinking: String },
309 SignatureDelta { signature: String },
310 CitationsDelta { citation: Citation },
311}
312
313impl ContentDelta {
314 pub fn is_citation(&self) -> bool {
315 matches!(self, Self::CitationsDelta { .. })
316 }
317
318 pub fn as_citation(&self) -> Option<&Citation> {
319 match self {
320 Self::CitationsDelta { citation } => Some(citation),
321 _ => None,
322 }
323 }
324}
325
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct MessageDeltaData {
328 pub stop_reason: Option<StopReason>,
329 pub stop_sequence: Option<String>,
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct StreamError {
334 #[serde(rename = "type")]
335 pub error_type: String,
336 pub message: String,
337}
338
/// Outcome of a compaction attempt.
#[derive(Clone, Debug)]
pub enum CompactResult {
    /// No compaction was required.
    NotNeeded,
    /// Compaction ran.
    ///
    /// `original_count` and `new_count` are the counts before and after
    /// (presumably message counts; confirm with the producer).
    /// `saved_tokens` is the reported saving, and `summary` is the generated
    /// summary text.
    Compacted {
        original_count: usize,
        new_count: usize,
        saved_tokens: usize,
        summary: String,
    },
    /// Compaction was not performed, with the reason why.
    Skipped { reason: String },
}
352
353#[derive(Debug, Clone, Default, Serialize, Deserialize)]
359pub struct ModelUsage {
360 pub input_tokens: u32,
362 pub output_tokens: u32,
364 #[serde(default)]
366 pub cache_read_input_tokens: u32,
367 #[serde(default)]
369 pub cache_creation_input_tokens: u32,
370 #[serde(default)]
372 pub web_search_requests: u32,
373 #[serde(default)]
375 pub web_fetch_requests: u32,
376 #[serde(default)]
378 pub cost_usd: f64,
379 #[serde(default)]
381 pub context_window: u32,
382}
383
384impl ModelUsage {
385 pub fn from_usage(usage: &Usage, model: &str) -> Self {
386 let cost = usage.estimated_cost(model);
387 let context_window = crate::types::models::context_window::for_model(model);
388 let (web_search, web_fetch) = usage
389 .server_tool_use
390 .as_ref()
391 .map(|s| (s.web_search_requests, s.web_fetch_requests))
392 .unwrap_or((0, 0));
393 Self {
394 input_tokens: usage.input_tokens,
395 output_tokens: usage.output_tokens,
396 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
397 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
398 web_search_requests: web_search,
399 web_fetch_requests: web_fetch,
400 cost_usd: cost,
401 context_window: context_window as u32,
402 }
403 }
404
405 pub fn add(&mut self, other: &ModelUsage) {
406 self.input_tokens += other.input_tokens;
407 self.output_tokens += other.output_tokens;
408 self.cache_read_input_tokens += other.cache_read_input_tokens;
409 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
410 self.web_search_requests += other.web_search_requests;
411 self.web_fetch_requests += other.web_fetch_requests;
412 self.cost_usd += other.cost_usd;
413 }
414
415 pub fn add_usage(&mut self, usage: &Usage, model: &str) {
416 self.input_tokens += usage.input_tokens;
417 self.output_tokens += usage.output_tokens;
418 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0);
419 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
420 self.cost_usd += usage.estimated_cost(model);
421 if let Some(ref server_usage) = usage.server_tool_use {
422 self.web_search_requests += server_usage.web_search_requests;
423 self.web_fetch_requests += server_usage.web_fetch_requests;
424 }
425 }
426
427 pub fn total_tokens(&self) -> u32 {
428 self.input_tokens + self.output_tokens
429 }
430}
431
432#[derive(Debug, Clone, Default, Serialize, Deserialize)]
447pub struct ServerToolUse {
448 pub web_search_requests: u32,
450 pub web_fetch_requests: u32,
452}
453
454impl ServerToolUse {
455 pub fn record_web_search(&mut self) {
457 self.web_search_requests += 1;
458 }
459
460 pub fn record_web_fetch(&mut self) {
462 self.web_fetch_requests += 1;
463 }
464
465 pub fn has_usage(&self) -> bool {
467 self.web_search_requests > 0 || self.web_fetch_requests > 0
468 }
469
470 pub fn add_from_usage(&mut self, usage: &ServerToolUseUsage) {
472 self.web_search_requests += usage.web_search_requests;
473 self.web_fetch_requests += usage.web_fetch_requests;
474 }
475}
476
477#[derive(Debug, Clone, Serialize, Deserialize)]
482pub struct PermissionDenial {
483 pub tool_name: String,
485 pub tool_use_id: String,
487 pub tool_input: serde_json::Value,
489 #[serde(default, skip_serializing_if = "Option::is_none")]
491 pub reason: Option<String>,
492 #[serde(default, skip_serializing_if = "Option::is_none")]
494 pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
495}
496
497impl PermissionDenial {
498 pub fn new(
500 tool_name: impl Into<String>,
501 tool_use_id: impl Into<String>,
502 tool_input: serde_json::Value,
503 ) -> Self {
504 Self {
505 tool_name: tool_name.into(),
506 tool_use_id: tool_use_id.into(),
507 tool_input,
508 reason: None,
509 timestamp: Some(chrono::Utc::now()),
510 }
511 }
512
513 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
515 self.reason = Some(reason.into());
516 self
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_usage_total() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Usage::default()
        };
        assert_eq!(150, usage.total());
    }

    #[test]
    fn test_usage_cost() {
        let one_million = 1_000_000;
        let usage = Usage {
            input_tokens: one_million,
            output_tokens: one_million,
            ..Usage::default()
        };
        let cost = usage.estimated_cost("claude-sonnet-4-5");
        // Expected total for one million input plus one million output tokens.
        assert!((cost - 18.0).abs() < 0.01, "unexpected cost: {cost}");
    }

    #[test]
    fn test_model_usage_from_usage() {
        let usage = Usage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_read_input_tokens: Some(100),
            cache_creation_input_tokens: Some(50),
            ..Usage::default()
        };
        let per_model = ModelUsage::from_usage(&usage, "claude-sonnet-4-5");
        assert_eq!((per_model.input_tokens, per_model.output_tokens), (1000, 500));
        assert_eq!(per_model.cache_read_input_tokens, 100);
        assert!(per_model.cost_usd > 0.0);
    }

    #[test]
    fn test_model_usage_add() {
        let mut running = ModelUsage {
            input_tokens: 100,
            output_tokens: 50,
            cost_usd: 0.01,
            ..ModelUsage::default()
        };
        let increment = ModelUsage {
            input_tokens: 200,
            output_tokens: 100,
            cost_usd: 0.02,
            ..ModelUsage::default()
        };
        running.add(&increment);
        assert_eq!((running.input_tokens, running.output_tokens), (300, 150));
        assert!((running.cost_usd - 0.03).abs() < 0.0001);
    }

    #[test]
    fn test_server_tool_use() {
        let mut counters = ServerToolUse::default();
        assert!(!counters.has_usage());

        counters.record_web_search();
        assert!(counters.has_usage());
        assert_eq!(counters.web_search_requests, 1);

        counters.record_web_fetch();
        assert_eq!(counters.web_fetch_requests, 1);
    }

    #[test]
    fn test_permission_denial() {
        let input = serde_json::json!({"query": "test"});
        let denial = PermissionDenial::new("WebSearch", "tool_123", input).with_reason("User denied");

        assert_eq!(denial.tool_name, "WebSearch");
        assert_eq!(denial.reason.as_deref(), Some("User denied"));
        assert!(denial.timestamp.is_some());
    }

    #[test]
    fn test_server_tool_use_usage_parsing() {
        let json = r#"{
            "input_tokens": 1000,
            "output_tokens": 500,
            "server_tool_use": {
                "web_search_requests": 3,
                "web_fetch_requests": 2
            }
        }"#;
        let parsed: Usage = serde_json::from_str(json).unwrap();
        assert_eq!((parsed.input_tokens, parsed.output_tokens), (1000, 500));
        assert!(parsed.has_server_tool_use());
        assert_eq!(parsed.server_web_search_requests(), 3);
        assert_eq!(parsed.server_web_fetch_requests(), 2);
    }

    #[test]
    fn test_server_tool_use_usage_empty() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50
        }"#;
        let parsed: Usage = serde_json::from_str(json).unwrap();
        assert!(!parsed.has_server_tool_use());
        assert_eq!(parsed.server_web_search_requests(), 0);
        assert_eq!(parsed.server_web_fetch_requests(), 0);
    }

    #[test]
    fn test_server_tool_use_add_from_usage() {
        let mut counters = ServerToolUse::default();
        let report = ServerToolUseUsage {
            web_search_requests: 2,
            web_fetch_requests: 1,
        };
        // Adding the same report repeatedly should scale the totals linearly.
        for round in 1..=2 {
            counters.add_from_usage(&report);
            assert_eq!(counters.web_search_requests, 2 * round);
            assert_eq!(counters.web_fetch_requests, round);
        }
    }
}