Skip to main content

opi_ai/
stream.rs

1//! Streaming response events (S7.3).
2
3use serde::{Deserialize, Serialize};
4
5#[non_exhaustive]
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum StopReason {
8    #[serde(rename = "stop")]
9    Stop,
10    #[serde(rename = "length")]
11    Length,
12    #[serde(rename = "tool_use")]
13    ToolUse,
14    #[serde(rename = "error")]
15    Error,
16    #[serde(rename = "aborted")]
17    Aborted,
18}
19
20impl StopReason {
21    pub fn as_str(self) -> &'static str {
22        match self {
23            Self::Stop => "stop",
24            Self::Length => "length",
25            Self::ToolUse => "tool_use",
26            Self::Error => "error",
27            Self::Aborted => "aborted",
28        }
29    }
30}
31
32#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
33pub struct Usage {
34    pub input_tokens: u32,
35    pub output_tokens: u32,
36    #[serde(default)]
37    pub cache_read_tokens: u32,
38    #[serde(default)]
39    pub cache_write_tokens: u32,
40}
41
42impl Usage {
43    pub fn total_tokens(&self) -> u64 {
44        self.input_tokens as u64
45            + self.output_tokens as u64
46            + self.cache_read_tokens as u64
47            + self.cache_write_tokens as u64
48    }
49}
50
51/// Accumulated usage across multiple turns.
52#[derive(Debug, Clone, Default, PartialEq)]
53pub struct CumulativeUsage {
54    input_tokens: u64,
55    output_tokens: u64,
56    cache_read_tokens: u64,
57    cache_write_tokens: u64,
58    turns: u32,
59}
60
61impl CumulativeUsage {
62    /// Construct from pre-computed totals (e.g. when replaying a session).
63    pub fn from_totals(
64        input_tokens: u64,
65        output_tokens: u64,
66        cache_read_tokens: u64,
67        cache_write_tokens: u64,
68        turns: u32,
69    ) -> Self {
70        Self {
71            input_tokens,
72            output_tokens,
73            cache_read_tokens,
74            cache_write_tokens,
75            turns,
76        }
77    }
78
79    pub fn total_input_tokens(&self) -> u64 {
80        self.input_tokens
81    }
82
83    pub fn total_output_tokens(&self) -> u64 {
84        self.output_tokens
85    }
86
87    pub fn total_cache_read_tokens(&self) -> u64 {
88        self.cache_read_tokens
89    }
90
91    pub fn total_cache_write_tokens(&self) -> u64 {
92        self.cache_write_tokens
93    }
94
95    pub fn turn_count(&self) -> u32 {
96        self.turns
97    }
98
99    pub fn accumulate(&mut self, turn: &Usage) {
100        self.input_tokens += turn.input_tokens as u64;
101        self.output_tokens += turn.output_tokens as u64;
102        self.cache_read_tokens += turn.cache_read_tokens as u64;
103        self.cache_write_tokens += turn.cache_write_tokens as u64;
104        self.turns += 1;
105    }
106
107    pub fn as_usage(&self) -> Usage {
108        Usage {
109            input_tokens: self.input_tokens as u32,
110            output_tokens: self.output_tokens as u32,
111            cache_read_tokens: self.cache_read_tokens as u32,
112            cache_write_tokens: self.cache_write_tokens as u32,
113        }
114    }
115}
116
/// Per-million-token pricing for a model (USD).
///
/// Each field is a price in US dollars for one million tokens of its
/// category. `calculate_cost` divides by 1,000,000 to obtain a per-token rate.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Pricing {
    /// USD per million input tokens.
    pub input_cost_per_mtok: f64,
    /// USD per million output tokens.
    pub output_cost_per_mtok: f64,
    /// USD per million cache-read tokens.
    pub cache_read_cost_per_mtok: f64,
    /// USD per million cache-write tokens.
    pub cache_write_cost_per_mtok: f64,
}
125
/// Cost breakdown from a usage + pricing calculation.
///
/// Each field is the cost, in the same currency as `Pricing` (USD), for one
/// token category.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct CostBreakdown {
    /// Cost of input tokens.
    pub input_cost: f64,
    /// Cost of output tokens.
    pub output_cost: f64,
    /// Cost of cache-read tokens.
    pub cache_read_cost: f64,
    /// Cost of cache-write tokens.
    pub cache_write_cost: f64,
}

impl CostBreakdown {
    /// Sum of the four per-category costs.
    pub fn total_cost(&self) -> f64 {
        let Self {
            input_cost,
            output_cost,
            cache_read_cost,
            cache_write_cost,
        } = *self;
        // Added left to right in field order.
        input_cost + output_cost + cache_read_cost + cache_write_cost
    }
}
140
141/// Calculate cost from usage and pricing.
142pub fn calculate_cost(usage: &Usage, pricing: &Pricing) -> CostBreakdown {
143    let per_tok = |cost_per_mtok: f64| cost_per_mtok / 1_000_000.0;
144    CostBreakdown {
145        input_cost: usage.input_tokens as f64 * per_tok(pricing.input_cost_per_mtok),
146        output_cost: usage.output_tokens as f64 * per_tok(pricing.output_cost_per_mtok),
147        cache_read_cost: usage.cache_read_tokens as f64 * per_tok(pricing.cache_read_cost_per_mtok),
148        cache_write_cost: usage.cache_write_tokens as f64
149            * per_tok(pricing.cache_write_cost_per_mtok),
150    }
151}
152
153#[non_exhaustive]
154#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
155#[serde(tag = "type")]
156pub enum AssistantStreamEvent {
157    #[serde(rename = "start")]
158    Start {
159        partial: crate::message::AssistantMessage,
160    },
161    #[serde(rename = "text_start")]
162    TextStart {
163        content_index: usize,
164        partial: crate::message::AssistantMessage,
165    },
166    #[serde(rename = "text_delta")]
167    TextDelta {
168        content_index: usize,
169        delta: String,
170        partial: crate::message::AssistantMessage,
171    },
172    #[serde(rename = "text_end")]
173    TextEnd {
174        content_index: usize,
175        content: String,
176        partial: crate::message::AssistantMessage,
177    },
178    #[serde(rename = "thinking_start")]
179    ThinkingStart {
180        content_index: usize,
181        partial: crate::message::AssistantMessage,
182    },
183    #[serde(rename = "thinking_delta")]
184    ThinkingDelta {
185        content_index: usize,
186        delta: String,
187        partial: crate::message::AssistantMessage,
188    },
189    #[serde(rename = "thinking_end")]
190    ThinkingEnd {
191        content_index: usize,
192        content: String,
193        partial: crate::message::AssistantMessage,
194    },
195    #[serde(rename = "tool_call_start")]
196    ToolCallStart {
197        content_index: usize,
198        partial: crate::message::AssistantMessage,
199    },
200    #[serde(rename = "tool_call_delta")]
201    ToolCallDelta {
202        content_index: usize,
203        delta: String,
204        partial: crate::message::AssistantMessage,
205    },
206    #[serde(rename = "tool_call_end")]
207    ToolCallEnd {
208        content_index: usize,
209        tool_call: crate::message::ToolCall,
210        partial: crate::message::AssistantMessage,
211    },
212    #[serde(rename = "done")]
213    Done {
214        reason: StopReason,
215        message: crate::message::AssistantMessage,
216    },
217    #[serde(rename = "error")]
218    Error {
219        reason: StopReason,
220        message: crate::message::AssistantMessage,
221    },
222}
223
224impl AssistantStreamEvent {
225    pub fn is_terminal(&self) -> bool {
226        matches!(self, Self::Done { .. } | Self::Error { .. })
227    }
228}