1use std::collections::HashMap;
4
5use rust_decimal::Decimal;
6use serde::{Deserialize, Serialize};
7
8use super::ContentBlock;
9use super::citations::{Citation, SearchResultLocationCitation};
10
/// Token counters stored as 64-bit values.
///
/// The inherent `add` / `add_usage` methods accumulate into these fields,
/// and `From<&Usage>` widens a single [`Usage`] into this type.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Input tokens; kept separate from the two cache counters below
    /// (`context_usage` adds all three together).
    pub input_tokens: u64,
    /// Output tokens.
    pub output_tokens: u64,
    /// Cache-read input tokens; `0` when the field is missing on deserialize.
    #[serde(default)]
    pub cache_read_input_tokens: u64,
    /// Cache-creation input tokens; `0` when the field is missing on deserialize.
    #[serde(default)]
    pub cache_creation_input_tokens: u64,
}
20
21impl TokenUsage {
22 pub fn total(&self) -> u64 {
23 self.input_tokens + self.output_tokens
24 }
25
26 pub fn context_usage(&self) -> u64 {
27 self.input_tokens + self.cache_read_input_tokens + self.cache_creation_input_tokens
28 }
29
30 pub fn add(&mut self, other: &TokenUsage) {
31 self.input_tokens += other.input_tokens;
32 self.output_tokens += other.output_tokens;
33 self.cache_read_input_tokens += other.cache_read_input_tokens;
34 self.cache_creation_input_tokens += other.cache_creation_input_tokens;
35 }
36
37 pub fn add_usage(&mut self, usage: &Usage) {
38 self.input_tokens += usage.input_tokens as u64;
39 self.output_tokens += usage.output_tokens as u64;
40 self.cache_read_input_tokens += usage.cache_read_input_tokens.unwrap_or(0) as u64;
41 self.cache_creation_input_tokens += usage.cache_creation_input_tokens.unwrap_or(0) as u64;
42 }
43
44 pub fn cache_hit_rate(&self) -> f64 {
45 if self.input_tokens == 0 {
46 return 0.0;
47 }
48 (self.cache_read_input_tokens as f64 / self.input_tokens as f64).clamp(0.0, 1.0)
49 }
50}
51
52impl From<&Usage> for TokenUsage {
53 fn from(usage: &Usage) -> Self {
54 Self {
55 input_tokens: usage.input_tokens as u64,
56 output_tokens: usage.output_tokens as u64,
57 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0) as u64,
58 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0) as u64,
59 }
60 }
61}
62
/// A response message: content blocks plus stop information and usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiResponse {
    /// Identifier of the response.
    pub id: String,
    /// Object type string; (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    /// Role string carried by the message.
    pub role: String,
    /// Content blocks of the message.
    pub content: Vec<ContentBlock>,
    /// Model name string.
    pub model: String,
    /// Why generation stopped, when reported.
    pub stop_reason: Option<StopReason>,
    /// Stop sequence string, when present.
    /// NOTE(review): presumably only set for `StopReason::StopSequence` — confirm.
    pub stop_sequence: Option<String>,
    /// Token usage reported with this response.
    pub usage: Usage,
    /// Context-management details; `None` when absent on deserialize and
    /// omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_management: Option<ContextManagementResponse>,
}
77
/// Context-management section of an [`ApiResponse`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextManagementResponse {
    /// Edits that were applied; an empty list when the field is missing on
    /// deserialize.
    #[serde(default)]
    pub applied_edits: Vec<AppliedEdit>,
}
83
/// One entry of [`ContextManagementResponse::applied_edits`].
///
/// All counters are optional: they deserialize to `None` when missing and are
/// skipped on serialize when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppliedEdit {
    /// Edit kind string; (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub edit_type: String,
    /// Number of cleared tool uses, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cleared_tool_uses: Option<u32>,
    /// Number of cleared thinking turns, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cleared_thinking_turns: Option<u32>,
    /// Number of cleared input tokens, when reported
    /// (summed by `ApiResponse::cleared_tokens`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cleared_input_tokens: Option<u64>,
}
95
/// Reason a response stopped generating.
///
/// Variants are (de)serialized in `snake_case`, e.g. `EndTurn` <-> `"end_turn"`.
/// NOTE(review): there is no catch-all variant, so a value not listed here
/// fails to deserialize — confirm that is acceptable for new upstream values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// `"end_turn"`.
    EndTurn,
    /// `"max_tokens"`.
    MaxTokens,
    /// `"stop_sequence"`.
    StopSequence,
    /// `"tool_use"`; checked by `ApiResponse::wants_tool_use`.
    ToolUse,
    /// `"refusal"`.
    Refusal,
}
107
/// Server-tool request counters as carried inside [`Usage::server_tool_use`].
///
/// Both counters default to `0` when missing on deserialize.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct ServerToolUseUsage {
    /// Number of web search requests.
    #[serde(default)]
    pub web_search_requests: u32,
    /// Number of web fetch requests.
    #[serde(default)]
    pub web_fetch_requests: u32,
}
121
/// Token usage as reported with a response or stream event.
///
/// Optional fields deserialize to `None` when missing.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Input tokens; kept separate from the cache counters below
    /// (`context_usage` adds all three together).
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Cache-read input tokens, when reported.
    #[serde(default)]
    pub cache_read_input_tokens: Option<u32>,
    /// Cache-creation input tokens, when reported.
    #[serde(default)]
    pub cache_creation_input_tokens: Option<u32>,
    /// Server-tool request counters, when reported.
    #[serde(default)]
    pub server_tool_use: Option<ServerToolUseUsage>,
}
134
135impl Usage {
136 pub fn total(&self) -> u32 {
137 self.input_tokens + self.output_tokens
138 }
139
140 pub fn context_usage(&self) -> u32 {
141 self.input_tokens
142 + self.cache_read_input_tokens.unwrap_or(0)
143 + self.cache_creation_input_tokens.unwrap_or(0)
144 }
145
146 pub fn estimated_cost(&self, model: &str) -> Decimal {
147 crate::budget::pricing::global_pricing_table().calculate(model, self)
148 }
149
150 pub fn server_web_search_requests(&self) -> u32 {
152 self.server_tool_use
153 .as_ref()
154 .map(|s| s.web_search_requests)
155 .unwrap_or(0)
156 }
157
158 pub fn server_web_fetch_requests(&self) -> u32 {
160 self.server_tool_use
161 .as_ref()
162 .map(|s| s.web_fetch_requests)
163 .unwrap_or(0)
164 }
165
166 pub fn has_server_tool_use(&self) -> bool {
168 self.server_tool_use
169 .as_ref()
170 .map(|s| s.web_search_requests > 0 || s.web_fetch_requests > 0)
171 .unwrap_or(false)
172 }
173}
174
175impl ApiResponse {
176 pub fn text(&self) -> String {
177 self.content
178 .iter()
179 .filter_map(|block| block.as_text())
180 .collect::<Vec<_>>()
181 .join("")
182 }
183
184 pub fn wants_tool_use(&self) -> bool {
185 self.stop_reason == Some(StopReason::ToolUse)
186 }
187
188 pub fn tool_uses(&self) -> Vec<&super::ToolUseBlock> {
189 self.content
190 .iter()
191 .filter_map(|block| match block {
192 ContentBlock::ToolUse(tool_use) => Some(tool_use),
193 _ => None,
194 })
195 .collect()
196 }
197
198 pub fn thinking_blocks(&self) -> Vec<&super::ThinkingBlock> {
199 self.content
200 .iter()
201 .filter_map(|block| block.as_thinking())
202 .collect()
203 }
204
205 pub fn has_thinking(&self) -> bool {
206 self.content.iter().any(|block| block.is_thinking())
207 }
208
209 pub fn all_citations(&self) -> Vec<&Citation> {
210 self.content
211 .iter()
212 .filter_map(|block| block.citations())
213 .flatten()
214 .collect()
215 }
216
217 pub fn has_citations(&self) -> bool {
218 self.content.iter().any(|block| block.has_citations())
219 }
220
221 pub fn citations_by_document(&self) -> HashMap<usize, Vec<&Citation>> {
222 let mut map: HashMap<usize, Vec<&Citation>> = HashMap::new();
223 for citation in self.all_citations() {
224 if let Some(doc_idx) = citation.document_index() {
225 map.entry(doc_idx).or_default().push(citation);
226 }
227 }
228 map
229 }
230
231 pub fn search_citations(&self) -> Vec<&SearchResultLocationCitation> {
232 self.all_citations()
233 .into_iter()
234 .filter_map(|c| match c {
235 Citation::SearchResultLocation(src) => Some(src),
236 _ => None,
237 })
238 .collect()
239 }
240
241 pub fn applied_edits(&self) -> &[AppliedEdit] {
242 self.context_management
243 .as_ref()
244 .map(|cm| cm.applied_edits.as_slice())
245 .unwrap_or_default()
246 }
247
248 pub fn cleared_tokens(&self) -> u64 {
249 self.applied_edits()
250 .iter()
251 .filter_map(|e| e.cleared_input_tokens)
252 .sum()
253 }
254
255 pub fn server_tool_uses(&self) -> Vec<&super::ServerToolUseBlock> {
257 self.content
258 .iter()
259 .filter_map(|block| block.as_server_tool_use())
260 .collect()
261 }
262
263 pub fn has_server_tool_use(&self) -> bool {
265 self.content.iter().any(|block| block.is_server_tool_use())
266 }
267
268 pub fn server_web_search_requests(&self) -> u32 {
270 self.usage.server_web_search_requests()
271 }
272
273 pub fn server_web_fetch_requests(&self) -> u32 {
275 self.usage.server_web_fetch_requests()
276 }
277}
278
/// One event of a streamed response.
///
/// Serde-tagged internally by a `type` field whose value is the variant name
/// in `snake_case` (e.g. `MessageStart` <-> `"message_start"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
    /// `"message_start"`: carries the message header data.
    MessageStart {
        message: MessageStartData,
    },
    /// `"content_block_start"`: a content block begins at `index`.
    ContentBlockStart {
        index: usize,
        content_block: ContentBlock,
    },
    /// `"content_block_delta"`: an incremental update for the block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// `"content_block_stop"`: the block at `index` is finished.
    ContentBlockStop {
        index: usize,
    },
    /// `"message_delta"`: stop information plus a usage report.
    MessageDelta {
        delta: MessageDeltaData,
        usage: Usage,
    },
    /// `"message_stop"`.
    MessageStop,
    /// `"ping"`.
    Ping,
    /// `"error"`: carries the error details.
    Error {
        error: StreamError,
    },
}
306
/// Payload of [`StreamEvent::MessageStart`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageStartData {
    /// Identifier of the message.
    pub id: String,
    /// Object type string; (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub message_type: String,
    /// Role string carried by the message.
    pub role: String,
    /// Model name string.
    pub model: String,
    /// Usage reported with the start event.
    pub usage: Usage,
}
316
/// Incremental update carried by [`StreamEvent::ContentBlockDelta`].
///
/// Serde-tagged internally by a `type` field whose value is the variant name
/// in `snake_case` (e.g. `TextDelta` <-> `"text_delta"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// `"text_delta"`: a text fragment.
    TextDelta { text: String },
    /// `"input_json_delta"`: a fragment of JSON text.
    InputJsonDelta { partial_json: String },
    /// `"thinking_delta"`: a thinking-text fragment.
    ThinkingDelta { thinking: String },
    /// `"signature_delta"`: a signature string.
    SignatureDelta { signature: String },
    /// `"citations_delta"`: a single citation.
    CitationsDelta { citation: Citation },
}
326
327impl ContentDelta {
328 pub fn is_citation(&self) -> bool {
329 matches!(self, Self::CitationsDelta { .. })
330 }
331
332 pub fn as_citation(&self) -> Option<&Citation> {
333 match self {
334 Self::CitationsDelta { citation } => Some(citation),
335 _ => None,
336 }
337 }
338}
339
/// Payload of the `delta` field of [`StreamEvent::MessageDelta`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageDeltaData {
    /// Why generation stopped, when reported.
    pub stop_reason: Option<StopReason>,
    /// Stop sequence string, when present.
    pub stop_sequence: Option<String>,
}
345
/// Error details carried by [`StreamEvent::Error`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamError {
    /// Error kind string; (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
352
/// Outcome of a compaction attempt.
///
/// Not (de)serializable: only `Debug` and `Clone` are derived.
#[derive(Debug, Clone)]
pub enum CompactResult {
    /// Compaction was not needed.
    NotNeeded,
    /// Compaction happened.
    /// NOTE(review): the counts presumably refer to messages before/after —
    /// confirm against the code that produces this value.
    Compacted {
        original_count: usize,
        new_count: usize,
        saved_tokens: usize,
        summary: String,
    },
    /// Compaction was skipped for the given reason.
    Skipped {
        reason: String,
    },
}
366
/// Usage counters and cost tracked for one model.
///
/// Every field except `input_tokens` and `output_tokens` falls back to its
/// `Default` value when missing on deserialize.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModelUsage {
    /// Input tokens.
    pub input_tokens: u32,
    /// Output tokens.
    pub output_tokens: u32,
    /// Cache-read input tokens.
    #[serde(default)]
    pub cache_read_input_tokens: u32,
    /// Cache-creation input tokens.
    #[serde(default)]
    pub cache_creation_input_tokens: u32,
    /// Number of web search requests.
    #[serde(default)]
    pub web_search_requests: u32,
    /// Number of web fetch requests.
    #[serde(default)]
    pub web_fetch_requests: u32,
    /// Cost; the name indicates US dollars.
    #[serde(default)]
    pub cost_usd: Decimal,
    /// Context window of the model, as filled in by `ModelUsage::from_usage`.
    #[serde(default)]
    pub context_window: u64,
}
397
398impl ModelUsage {
399 pub fn from_usage(usage: &Usage, model: &str) -> Self {
400 let cost = usage.estimated_cost(model);
401 let context_window = crate::models::context_window::for_model(model);
402 let (web_search, web_fetch) = usage
403 .server_tool_use
404 .as_ref()
405 .map(|s| (s.web_search_requests, s.web_fetch_requests))
406 .unwrap_or((0, 0));
407 Self {
408 input_tokens: usage.input_tokens,
409 output_tokens: usage.output_tokens,
410 cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
411 cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
412 web_search_requests: web_search,
413 web_fetch_requests: web_fetch,
414 cost_usd: cost,
415 context_window,
416 }
417 }
418
419 pub fn add(&mut self, other: &ModelUsage) {
420 self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
421 self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
422 self.cache_read_input_tokens = self
423 .cache_read_input_tokens
424 .saturating_add(other.cache_read_input_tokens);
425 self.cache_creation_input_tokens = self
426 .cache_creation_input_tokens
427 .saturating_add(other.cache_creation_input_tokens);
428 self.web_search_requests = self
429 .web_search_requests
430 .saturating_add(other.web_search_requests);
431 self.web_fetch_requests = self
432 .web_fetch_requests
433 .saturating_add(other.web_fetch_requests);
434 self.cost_usd += other.cost_usd;
435 }
436
437 pub fn add_usage(&mut self, usage: &Usage, model: &str) {
438 self.input_tokens = self.input_tokens.saturating_add(usage.input_tokens);
439 self.output_tokens = self.output_tokens.saturating_add(usage.output_tokens);
440 self.cache_read_input_tokens = self
441 .cache_read_input_tokens
442 .saturating_add(usage.cache_read_input_tokens.unwrap_or(0));
443 self.cache_creation_input_tokens = self
444 .cache_creation_input_tokens
445 .saturating_add(usage.cache_creation_input_tokens.unwrap_or(0));
446 self.cost_usd += usage.estimated_cost(model);
447 if let Some(ref server_usage) = usage.server_tool_use {
448 self.web_search_requests = self
449 .web_search_requests
450 .saturating_add(server_usage.web_search_requests);
451 self.web_fetch_requests = self
452 .web_fetch_requests
453 .saturating_add(server_usage.web_fetch_requests);
454 }
455 }
456
457 pub fn total_tokens(&self) -> u32 {
458 self.input_tokens + self.output_tokens
459 }
460}
461
/// Running counters of server-tool requests.
///
/// Updated through the `record_*` methods and `add_from_usage`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ServerToolUse {
    /// Number of web search requests recorded so far.
    pub web_search_requests: u32,
    /// Number of web fetch requests recorded so far.
    pub web_fetch_requests: u32,
}
483
484impl ServerToolUse {
485 pub fn record_web_search(&mut self) {
487 self.web_search_requests += 1;
488 }
489
490 pub fn record_web_fetch(&mut self) {
492 self.web_fetch_requests += 1;
493 }
494
495 pub fn has_usage(&self) -> bool {
497 self.web_search_requests > 0 || self.web_fetch_requests > 0
498 }
499
500 pub fn add_from_usage(&mut self, usage: &ServerToolUseUsage) {
502 self.web_search_requests += usage.web_search_requests;
503 self.web_fetch_requests += usage.web_fetch_requests;
504 }
505}
506
/// Record of a tool invocation that was denied permission.
///
/// `reason` and `timestamp` deserialize to `None` when missing and are
/// skipped on serialize when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionDenial {
    /// Name of the tool.
    pub tool_name: String,
    /// Identifier of the tool use.
    pub tool_use_id: String,
    /// Input of the tool call, as a JSON value.
    pub tool_input: serde_json::Value,
    /// Optional reason text; set through the `reason` builder method.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Optional UTC timestamp; `PermissionDenial::new` fills in the current time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<chrono::DateTime<chrono::Utc>>,
}
526
527impl PermissionDenial {
528 pub fn new(
530 tool_name: impl Into<String>,
531 tool_use_id: impl Into<String>,
532 tool_input: serde_json::Value,
533 ) -> Self {
534 Self {
535 tool_name: tool_name.into(),
536 tool_use_id: tool_use_id.into(),
537 tool_input,
538 reason: None,
539 timestamp: Some(chrono::Utc::now()),
540 }
541 }
542
543 pub fn reason(mut self, reason: impl Into<String>) -> Self {
545 self.reason = Some(reason.into());
546 self
547 }
548}
549
// Unit tests for the usage/accounting types defined in this file.
#[cfg(test)]
mod tests {
    use rust_decimal_macros::dec;

    use super::*;

    // `Usage::total` is input + output tokens.
    #[test]
    fn test_usage_total() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            ..Default::default()
        };
        assert_eq!(usage.total(), 150);
    }

    // The expected value depends on the project's global pricing table for
    // this model, which is not defined in this file.
    #[test]
    fn test_usage_cost() {
        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };
        let cost = usage.estimated_cost("claude-sonnet-4-5");
        assert_eq!(cost, dec!(21));
    }

    // `ModelUsage::from_usage` copies the counters and computes a non-zero cost.
    #[test]
    fn test_model_usage_from_usage() {
        let usage = Usage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_read_input_tokens: Some(100),
            cache_creation_input_tokens: Some(50),
            ..Default::default()
        };
        let model_usage = ModelUsage::from_usage(&usage, "claude-sonnet-4-5");
        assert_eq!(model_usage.input_tokens, 1000);
        assert_eq!(model_usage.output_tokens, 500);
        assert_eq!(model_usage.cache_read_input_tokens, 100);
        assert!(model_usage.cost_usd > Decimal::ZERO);
    }

    // `ModelUsage::add` sums counters and cost field by field.
    #[test]
    fn test_model_usage_add() {
        let mut usage1 = ModelUsage {
            input_tokens: 100,
            output_tokens: 50,
            cost_usd: dec!(0.01),
            ..Default::default()
        };
        let usage2 = ModelUsage {
            input_tokens: 200,
            output_tokens: 100,
            cost_usd: dec!(0.02),
            ..Default::default()
        };
        usage1.add(&usage2);
        assert_eq!(usage1.input_tokens, 300);
        assert_eq!(usage1.output_tokens, 150);
        assert_eq!(usage1.cost_usd, dec!(0.03));
    }

    // A default `ServerToolUse` has no usage; each `record_*` call bumps its counter by one.
    #[test]
    fn test_server_tool_use() {
        let mut stu = ServerToolUse::default();
        assert!(!stu.has_usage());

        stu.record_web_search();
        assert!(stu.has_usage());
        assert_eq!(stu.web_search_requests, 1);

        stu.record_web_fetch();
        assert_eq!(stu.web_fetch_requests, 1);
    }

    // `PermissionDenial::new` stamps a timestamp; `.reason(..)` sets the reason.
    #[test]
    fn test_permission_denial() {
        let denial = PermissionDenial::new(
            "WebSearch",
            "tool_123",
            serde_json::json!({"query": "test"}),
        )
        .reason("User denied");

        assert_eq!(denial.tool_name, "WebSearch");
        assert_eq!(denial.reason, Some("User denied".to_string()));
        assert!(denial.timestamp.is_some());
    }

    // A `server_tool_use` object in the JSON is parsed and exposed through the accessors.
    #[test]
    fn test_server_tool_use_usage_parsing() {
        let json = r#"{
            "input_tokens": 1000,
            "output_tokens": 500,
            "server_tool_use": {
                "web_search_requests": 3,
                "web_fetch_requests": 2
            }
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert!(usage.has_server_tool_use());
        assert_eq!(usage.server_web_search_requests(), 3);
        assert_eq!(usage.server_web_fetch_requests(), 2);
    }

    // Without a `server_tool_use` object the accessors report zero / false.
    #[test]
    fn test_server_tool_use_usage_empty() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert!(!usage.has_server_tool_use());
        assert_eq!(usage.server_web_search_requests(), 0);
        assert_eq!(usage.server_web_fetch_requests(), 0);
    }

    // `add_from_usage` accumulates: applying the same usage twice doubles the counters.
    #[test]
    fn test_server_tool_use_add_from_usage() {
        let mut stu = ServerToolUse::default();
        let usage = ServerToolUseUsage {
            web_search_requests: 2,
            web_fetch_requests: 1,
        };
        stu.add_from_usage(&usage);
        assert_eq!(stu.web_search_requests, 2);
        assert_eq!(stu.web_fetch_requests, 1);

        stu.add_from_usage(&usage);
        assert_eq!(stu.web_search_requests, 4);
        assert_eq!(stu.web_fetch_requests, 2);
    }
}