1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7use super::ContentBlock;
8use super::citations::{Citation, SearchResultLocationCitation};
9
10#[derive(Clone, Debug, Default, Serialize, Deserialize)]
11pub struct TokenUsage {
12 pub input_tokens: u64,
13 pub output_tokens: u64,
14 #[serde(default)]
15 pub cache_read_input_tokens: u64,
16 #[serde(default)]
17 pub cache_creation_input_tokens: u64,
18}
19
20impl TokenUsage {
21 pub fn total(&self) -> u64 {
22 self.input_tokens + self.output_tokens
23 }
24
25 pub fn context_usage(&self) -> u64 {
26 self.input_tokens + self.cache_read_input_tokens + self.cache_creation_input_tokens
27 }
28
29 pub fn add(&mut self, other: &TokenUsage) {
30 self.input_tokens += other.input_tokens;
31 self.output_tokens += other.output_tokens;
32 self.cache_read_input_tokens += other.cache_read_input_tokens;
33 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
34 }
35
36 pub fn add_usage(&mut self, usage: &Usage) {
37 self.input_tokens += usage.input_tokens as u64;
38 self.output_tokens += usage.output_tokens as u64;
39 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0) as u64;
40 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0) as u64;
41 }
42
43 pub fn cache_hit_rate(&self) -> f64 {
44 if self.input_tokens == 0 {
45 return 0.0;
46 }
47 self.cache_read_input_tokens as f64 / self.input_tokens as f64
48 }
49}
50
51impl From<&Usage> for TokenUsage {
52 fn from(usage: &Usage) -> Self {
53 Self {
54 input_tokens: usage.input_tokens as u64,
55 output_tokens: usage.output_tokens as u64,
56 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0) as u64,
57 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0) as u64,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ApiResponse {
64 pub id: String,
65 #[serde(rename = "type")]
66 pub response_type: String,
67 pub role: String,
68 pub content: Vec<ContentBlock>,
69 pub model: String,
70 pub stop_reason: Option<StopReason>,
71 pub stop_sequence: Option<String>,
72 pub usage: Usage,
73 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub context_management: Option<ContextManagementResponse>,
75}
76
77#[derive(Debug, Clone, Default, Serialize, Deserialize)]
78pub struct ContextManagementResponse {
79 #[serde(default)]
80 pub applied_edits: Vec<AppliedEdit>,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct AppliedEdit {
85 #[serde(rename = "type")]
86 pub edit_type: String,
87 #[serde(default, skip_serializing_if = "Option::is_none")]
88 pub cleared_tool_uses: Option<u32>,
89 #[serde(default, skip_serializing_if = "Option::is_none")]
90 pub cleared_thinking_turns: Option<u32>,
91 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub cleared_input_tokens: Option<u64>,
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
96#[serde(rename_all = "snake_case")]
97pub enum StopReason {
98 EndTurn,
99 MaxTokens,
100 StopSequence,
101 ToolUse,
102 Refusal,
105}
106
107#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
112pub struct ServerToolUseUsage {
113 #[serde(default)]
115 pub web_search_requests: u32,
116 #[serde(default)]
118 pub web_fetch_requests: u32,
119}
120
121#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
122pub struct Usage {
123 pub input_tokens: u32,
124 pub output_tokens: u32,
125 #[serde(default)]
126 pub cache_read_input_tokens: Option<u32>,
127 #[serde(default)]
128 pub cache_creation_input_tokens: Option<u32>,
129 #[serde(default)]
131 pub server_tool_use: Option<ServerToolUseUsage>,
132}
133
134impl Usage {
135 pub fn total(&self) -> u32 {
136 self.input_tokens + self.output_tokens
137 }
138
139 pub fn context_usage(&self) -> u32 {
140 self.input_tokens
141 + self.cache_read_input_tokens.unwrap_or(0)
142 + self.cache_creation_input_tokens.unwrap_or(0)
143 }
144
145 pub fn estimated_cost(&self, model: &str) -> f64 {
146 crate::budget::pricing::global_pricing_table().calculate(model, self)
147 }
148
149 pub fn server_web_search_requests(&self) -> u32 {
151 self.server_tool_use
152 .as_ref()
153 .map(|s| s.web_search_requests)
154 .unwrap_or(0)
155 }
156
157 pub fn server_web_fetch_requests(&self) -> u32 {
159 self.server_tool_use
160 .as_ref()
161 .map(|s| s.web_fetch_requests)
162 .unwrap_or(0)
163 }
164
165 pub fn has_server_tool_use(&self) -> bool {
167 self.server_tool_use
168 .as_ref()
169 .map(|s| s.web_search_requests > 0 || s.web_fetch_requests > 0)
170 .unwrap_or(false)
171 }
172}
173
174impl ApiResponse {
175 pub fn text(&self) -> String {
176 self.content
177 .iter()
178 .filter_map(|block| block.as_text())
179 .collect::<Vec<_>>()
180 .join("")
181 }
182
183 pub fn wants_tool_use(&self) -> bool {
184 self.stop_reason == Some(StopReason::ToolUse)
185 }
186
187 pub fn tool_uses(&self) -> Vec<&super::ToolUseBlock> {
188 self.content
189 .iter()
190 .filter_map(|block| match block {
191 ContentBlock::ToolUse(tool_use) => Some(tool_use),
192 _ => None,
193 })
194 .collect()
195 }
196
197 pub fn thinking_blocks(&self) -> Vec<&super::ThinkingBlock> {
198 self.content
199 .iter()
200 .filter_map(|block| block.as_thinking())
201 .collect()
202 }
203
204 pub fn has_thinking(&self) -> bool {
205 self.content.iter().any(|block| block.is_thinking())
206 }
207
208 pub fn all_citations(&self) -> Vec<&Citation> {
209 self.content
210 .iter()
211 .filter_map(|block| block.citations())
212 .flatten()
213 .collect()
214 }
215
216 pub fn has_citations(&self) -> bool {
217 self.content.iter().any(|block| block.has_citations())
218 }
219
220 pub fn citations_by_document(&self) -> HashMap<usize, Vec<&Citation>> {
221 let mut map: HashMap<usize, Vec<&Citation>> = HashMap::new();
222 for citation in self.all_citations() {
223 if let Some(doc_idx) = citation.document_index() {
224 map.entry(doc_idx).or_default().push(citation);
225 }
226 }
227 map
228 }
229
230 pub fn search_citations(&self) -> Vec<&SearchResultLocationCitation> {
231 self.all_citations()
232 .into_iter()
233 .filter_map(|c| match c {
234 Citation::SearchResultLocation(src) => Some(src),
235 _ => None,
236 })
237 .collect()
238 }
239
240 pub fn applied_edits(&self) -> &[AppliedEdit] {
241 self.context_management
242 .as_ref()
243 .map(|cm| cm.applied_edits.as_slice())
244 .unwrap_or_default()
245 }
246
247 pub fn cleared_tokens(&self) -> u64 {
248 self.applied_edits()
249 .iter()
250 .filter_map(|e| e.cleared_input_tokens)
251 .sum()
252 }
253
254 pub fn server_tool_uses(&self) -> Vec<&super::ServerToolUseBlock> {
256 self.content
257 .iter()
258 .filter_map(|block| block.as_server_tool_use())
259 .collect()
260 }
261
262 pub fn has_server_tool_use(&self) -> bool {
264 self.content.iter().any(|block| block.is_server_tool_use())
265 }
266
267 pub fn server_web_search_requests(&self) -> u32 {
269 self.usage.server_web_search_requests()
270 }
271
272 pub fn server_web_fetch_requests(&self) -> u32 {
274 self.usage.server_web_fetch_requests()
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
279#[serde(tag = "type", rename_all = "snake_case")]
280pub enum StreamEvent {
281 MessageStart {
282 message: MessageStartData,
283 },
284 ContentBlockStart {
285 index: usize,
286 content_block: ContentBlock,
287 },
288 ContentBlockDelta {
289 index: usize,
290 delta: ContentDelta,
291 },
292 ContentBlockStop {
293 index: usize,
294 },
295 MessageDelta {
296 delta: MessageDeltaData,
297 usage: Usage,
298 },
299 MessageStop,
300 Ping,
301 Error {
302 error: StreamError,
303 },
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct MessageStartData {
308 pub id: String,
309 #[serde(rename = "type")]
310 pub message_type: String,
311 pub role: String,
312 pub model: String,
313 pub usage: Usage,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
317#[serde(tag = "type", rename_all = "snake_case")]
318pub enum ContentDelta {
319 TextDelta { text: String },
320 InputJsonDelta { partial_json: String },
321 ThinkingDelta { thinking: String },
322 SignatureDelta { signature: String },
323 CitationsDelta { citation: Citation },
324}
325
326impl ContentDelta {
327 pub fn is_citation(&self) -> bool {
328 matches!(self, Self::CitationsDelta { .. })
329 }
330
331 pub fn as_citation(&self) -> Option<&Citation> {
332 match self {
333 Self::CitationsDelta { citation } => Some(citation),
334 _ => None,
335 }
336 }
337}
338
339#[derive(Debug, Clone, Serialize, Deserialize)]
340pub struct MessageDeltaData {
341 pub stop_reason: Option<StopReason>,
342 pub stop_sequence: Option<String>,
343}
344
345#[derive(Debug, Clone, Serialize, Deserialize)]
346pub struct StreamError {
347 #[serde(rename = "type")]
348 pub error_type: String,
349 pub message: String,
350}
351
/// Outcome of a compaction attempt.
///
/// Unlike most types in this file, this one has no serde derives.
#[derive(Clone, Debug)]
pub enum CompactResult {
    /// No compaction was needed.
    NotNeeded,
    /// Compaction happened.
    Compacted {
        /// Count before compaction.
        original_count: usize,
        /// Count after compaction.
        new_count: usize,
        /// Number of tokens saved.
        saved_tokens: usize,
        /// Summary text.
        summary: String,
    },
    /// Compaction was skipped.
    Skipped {
        /// Why it was skipped.
        reason: String,
    },
}
365
366#[derive(Debug, Clone, Default, Serialize, Deserialize)]
372pub struct ModelUsage {
373 pub input_tokens: u32,
375 pub output_tokens: u32,
377 #[serde(default)]
379 pub cache_read_input_tokens: u32,
380 #[serde(default)]
382 pub cache_creation_input_tokens: u32,
383 #[serde(default)]
385 pub web_search_requests: u32,
386 #[serde(default)]
388 pub web_fetch_requests: u32,
389 #[serde(default)]
391 pub cost_usd: f64,
392 #[serde(default)]
394 pub context_window: u32,
395}
396
397impl ModelUsage {
398 pub fn from_usage(usage: &Usage, model: &str) -> Self {
399 let cost = usage.estimated_cost(model);
400 let context_window = crate::models::context_window::for_model(model);
401 let (web_search, web_fetch) = usage
402 .server_tool_use
403 .as_ref()
404 .map(|s| (s.web_search_requests, s.web_fetch_requests))
405 .unwrap_or((0, 0));
406 Self {
407 input_tokens: usage.input_tokens,
408 output_tokens: usage.output_tokens,
409 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
410 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
411 web_search_requests: web_search,
412 web_fetch_requests: web_fetch,
413 cost_usd: cost,
414 context_window: context_window as u32,
415 }
416 }
417
418 pub fn add(&mut self, other: &ModelUsage) {
419 self.input_tokens += other.input_tokens;
420 self.output_tokens += other.output_tokens;
421 self.cache_read_input_tokens += other.cache_read_input_tokens;
422 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
423 self.web_search_requests += other.web_search_requests;
424 self.web_fetch_requests += other.web_fetch_requests;
425 self.cost_usd += other.cost_usd;
426 }
427
428 pub fn add_usage(&mut self, usage: &Usage, model: &str) {
429 self.input_tokens += usage.input_tokens;
430 self.output_tokens += usage.output_tokens;
431 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0);
432 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
433 self.cost_usd += usage.estimated_cost(model);
434 if let Some(ref server_usage) = usage.server_tool_use {
435 self.web_search_requests += server_usage.web_search_requests;
436 self.web_fetch_requests += server_usage.web_fetch_requests;
437 }
438 }
439
440 pub fn total_tokens(&self) -> u32 {
441 self.input_tokens + self.output_tokens
442 }
443}
444
445#[derive(Debug, Clone, Default, Serialize, Deserialize)]
460pub struct ServerToolUse {
461 pub web_search_requests: u32,
463 pub web_fetch_requests: u32,
465}
466
467impl ServerToolUse {
468 pub fn record_web_search(&mut self) {
470 self.web_search_requests += 1;
471 }
472
473 pub fn record_web_fetch(&mut self) {
475 self.web_fetch_requests += 1;
476 }
477
478 pub fn has_usage(&self) -> bool {
480 self.web_search_requests > 0 || self.web_fetch_requests > 0
481 }
482
483 pub fn add_from_usage(&mut self, usage: &ServerToolUseUsage) {
485 self.web_search_requests += usage.web_search_requests;
486 self.web_fetch_requests += usage.web_fetch_requests;
487 }
488}
489
490#[derive(Debug, Clone, Serialize, Deserialize)]
495pub struct PermissionDenial {
496 pub tool_name: String,
498 pub tool_use_id: String,
500 pub tool_input: serde_json::Value,
502 #[serde(default, skip_serializing_if = "Option::is_none")]
504 pub reason: Option<String>,
505 #[serde(default, skip_serializing_if = "Option::is_none")]
507 pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
508}
509
510impl PermissionDenial {
511 pub fn new(
513 tool_name: impl Into<String>,
514 tool_use_id: impl Into<String>,
515 tool_input: serde_json::Value,
516 ) -> Self {
517 Self {
518 tool_name: tool_name.into(),
519 tool_use_id: tool_use_id.into(),
520 tool_input,
521 reason: None,
522 timestamp: Some(chrono::Utc::now()),
523 }
524 }
525
526 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
528 self.reason = Some(reason.into());
529 self
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// `Usage::total` is the sum of input and output tokens.
    #[test]
    fn test_usage_total() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        assert_eq!(usage.total(), 150);
    }

    /// One million input plus one million output tokens are expected to cost
    /// 18.0 for this model. The figure comes from the pricing table, which
    /// lives outside this module.
    #[test]
    fn test_usage_cost() {
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };
        let cost = usage.estimated_cost("claude-sonnet-4-5");
        assert!((cost - 18.0).abs() < 0.01);
    }

    /// `ModelUsage::from_usage` copies the token counters and unwraps the
    /// optional cache counters.
    #[test]
    fn test_model_usage_from_usage() {
        let usage = Usage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_read_input_tokens: Some(100),
            cache_creation_input_tokens: Some(50),
            ..Default::default()
        };
        let model_usage = ModelUsage::from_usage(&usage, "claude-sonnet-4-5");
        assert_eq!(model_usage.input_tokens, 1000);
        assert_eq!(model_usage.output_tokens, 500);
        assert_eq!(model_usage.cache_read_input_tokens, 100);
        // The fixture sets this counter, so check that it is carried over.
        assert_eq!(model_usage.cache_creation_input_tokens, 50);
        assert!(model_usage.cost_usd > 0.0);
    }

    /// `ModelUsage::add` sums token counters and cost.
    #[test]
    fn test_model_usage_add() {
        let mut usage1 = ModelUsage {
            input_tokens: 100,
            output_tokens: 50,
            cost_usd: 0.01,
            ..Default::default()
        };
        let usage2 = ModelUsage {
            input_tokens: 200,
            output_tokens: 100,
            cost_usd: 0.02,
            ..Default::default()
        };
        usage1.add(&usage2);
        assert_eq!(usage1.input_tokens, 300);
        assert_eq!(usage1.output_tokens, 150);
        // Costs are floats, so compare with a tolerance.
        assert!((usage1.cost_usd - 0.03).abs() < 0.0001);
    }

    /// The `record_*` helpers bump their counters and flip `has_usage`.
    #[test]
    fn test_server_tool_use() {
        let mut stu = ServerToolUse::default();
        assert!(!stu.has_usage());

        stu.record_web_search();
        assert!(stu.has_usage());
        assert_eq!(stu.web_search_requests, 1);

        stu.record_web_fetch();
        assert_eq!(stu.web_fetch_requests, 1);
    }

    /// `new` stamps a timestamp and `with_reason` stores the reason.
    #[test]
    fn test_permission_denial() {
        let denial = PermissionDenial::new(
            "WebSearch",
            "tool_123",
            serde_json::json!({"query": "test"}),
        )
        .with_reason("User denied");

        assert_eq!(denial.tool_name, "WebSearch");
        assert_eq!(denial.reason, Some("User denied".to_string()));
        assert!(denial.timestamp.is_some());
    }

    /// A usage payload with a `server_tool_use` object exposes its counters.
    #[test]
    fn test_server_tool_use_usage_parsing() {
        let json = r#"{
            "input_tokens": 1000,
            "output_tokens": 500,
            "server_tool_use": {
                "web_search_requests": 3,
                "web_fetch_requests": 2
            }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert!(usage.has_server_tool_use());
        assert_eq!(usage.server_web_search_requests(), 3);
        assert_eq!(usage.server_web_fetch_requests(), 2);
    }

    /// A usage payload without `server_tool_use` reports zero requests.
    #[test]
    fn test_server_tool_use_usage_empty() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert!(!usage.has_server_tool_use());
        assert_eq!(usage.server_web_search_requests(), 0);
        assert_eq!(usage.server_web_fetch_requests(), 0);
    }

    /// `add_from_usage` accumulates across repeated calls.
    #[test]
    fn test_server_tool_use_add_from_usage() {
        let mut stu = ServerToolUse::default();
        let usage = ServerToolUseUsage {
            web_search_requests: 2,
            web_fetch_requests: 1,
        };
        stu.add_from_usage(&usage);
        assert_eq!(stu.web_search_requests, 2);
        assert_eq!(stu.web_fetch_requests, 1);

        // A second call adds on top of the first.
        stu.add_from_usage(&usage);
        assert_eq!(stu.web_search_requests, 4);
        assert_eq!(stu.web_fetch_requests, 2);
    }
}