1use std::borrow::Cow;
12use std::fmt;
13
14use serde::{Deserialize, Serialize};
15
/// Identifier of a Claude model, e.g. `"claude-sonnet-4-6"`.
///
/// Well-known models are exposed as associated constants that borrow a
/// `'static` string; any other identifier can be held as an owned string.
/// Equality and hashing look only at the identifier text, so a borrowed and
/// an owned `ModelId` with the same contents compare equal and hash alike.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ModelId(Cow<'static, str>);
32
33impl ModelId {
34 pub const OPUS_4_7: ModelId = ModelId(Cow::Borrowed("claude-opus-4-7"));
36 pub const SONNET_4_6: ModelId = ModelId(Cow::Borrowed("claude-sonnet-4-6"));
38 pub const HAIKU_4_5: ModelId = ModelId(Cow::Borrowed("claude-haiku-4-5-20251001"));
40
41 pub fn custom(s: impl Into<String>) -> Self {
43 Self(Cow::Owned(s.into()))
44 }
45
46 pub fn as_str(&self) -> &str {
48 &self.0
49 }
50}
51
52impl fmt::Display for ModelId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.write_str(&self.0)
55 }
56}
57
58impl AsRef<str> for ModelId {
59 fn as_ref(&self) -> &str {
60 &self.0
61 }
62}
63
64impl From<&'static str> for ModelId {
65 fn from(s: &'static str) -> Self {
66 Self(Cow::Borrowed(s))
67 }
68}
69
70impl From<String> for ModelId {
71 fn from(s: String) -> Self {
72 Self(Cow::Owned(s))
73 }
74}
75
76impl Serialize for ModelId {
77 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
78 s.serialize_str(&self.0)
79 }
80}
81
82impl<'de> Deserialize<'de> for ModelId {
83 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
84 String::deserialize(d).map(Self::from)
85 }
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
90#[serde(rename_all = "lowercase")]
91pub enum Role {
92 User,
94 Assistant,
96}
97
98#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
103#[serde(rename_all = "snake_case")]
104pub enum StopReason {
105 EndTurn,
107 MaxTokens,
109 StopSequence,
111 ToolUse,
113 PauseTurn,
115 Refusal,
117 #[serde(other)]
119 Other,
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
127#[serde(rename_all = "snake_case")]
128pub enum ServiceTier {
129 Standard,
131 Priority,
133 Batch,
135 #[serde(other)]
137 Other,
138}
139
140#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
145#[non_exhaustive]
146pub struct Usage {
147 pub input_tokens: u32,
149 pub output_tokens: u32,
151 #[serde(default, skip_serializing_if = "Option::is_none")]
153 pub cache_creation_input_tokens: Option<u32>,
154 #[serde(default, skip_serializing_if = "Option::is_none")]
156 pub cache_read_input_tokens: Option<u32>,
157 #[serde(default, skip_serializing_if = "Option::is_none")]
159 pub cache_creation: Option<CacheCreationBreakdown>,
160 #[serde(default, skip_serializing_if = "Option::is_none")]
162 pub server_tool_use: Option<ServerToolUseUsage>,
163 #[serde(default, skip_serializing_if = "Option::is_none")]
165 pub service_tier: Option<ServiceTier>,
166 #[serde(default, skip_serializing_if = "Option::is_none")]
170 pub inference_geo: Option<String>,
171}
172
173#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
175#[non_exhaustive]
176pub struct CacheCreationBreakdown {
177 #[serde(default)]
179 pub ephemeral_5m_input_tokens: u32,
180 #[serde(default)]
182 pub ephemeral_1h_input_tokens: u32,
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
187#[non_exhaustive]
188pub struct ServerToolUseUsage {
189 #[serde(default)]
191 pub web_search_requests: u32,
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde::de::DeserializeOwned;

    /// Serializes `value`, checks the exact JSON text against `expected_json`,
    /// then parses that text back and checks the result equals `value`.
    fn round_trip<T>(value: &T, expected_json: &str)
    where
        T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let json = serde_json::to_string(value).expect("serialize");
        assert_eq!(json, expected_json, "serialized form mismatch");
        let parsed: T = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(&parsed, value, "round-trip mismatch");
    }

    #[test]
    fn model_id_serializes_as_string() {
        round_trip(&ModelId::OPUS_4_7, "\"claude-opus-4-7\"");
        round_trip(&ModelId::SONNET_4_6, "\"claude-sonnet-4-6\"");
        round_trip(&ModelId::HAIKU_4_5, "\"claude-haiku-4-5-20251001\"");
        round_trip(
            &ModelId::custom("claude-future-foo"),
            "\"claude-future-foo\"",
        );
    }

    #[test]
    fn model_id_const_equals_custom() {
        assert_eq!(ModelId::OPUS_4_7, ModelId::custom("claude-opus-4-7"));
    }

    #[test]
    fn model_id_conversions_agree_with_custom() {
        let custom = ModelId::custom("claude-future-foo");
        assert_eq!(ModelId::from("claude-future-foo"), custom);
        assert_eq!(ModelId::from(String::from("claude-future-foo")), custom);
        assert_eq!(custom.as_str(), "claude-future-foo");
    }

    #[test]
    fn model_id_rejects_non_string_json() {
        assert!(serde_json::from_str::<ModelId>("42").is_err());
        assert!(serde_json::from_str::<ModelId>("null").is_err());
    }

    #[test]
    fn model_id_display_and_as_ref() {
        assert_eq!(ModelId::SONNET_4_6.to_string(), "claude-sonnet-4-6");
        assert_eq!(
            <ModelId as AsRef<str>>::as_ref(&ModelId::SONNET_4_6),
            "claude-sonnet-4-6"
        );
    }

    #[test]
    fn role_serializes_lowercase() {
        round_trip(&Role::User, "\"user\"");
        round_trip(&Role::Assistant, "\"assistant\"");
    }

    #[test]
    fn stop_reason_round_trips_known_variants() {
        round_trip(&StopReason::EndTurn, "\"end_turn\"");
        round_trip(&StopReason::MaxTokens, "\"max_tokens\"");
        round_trip(&StopReason::StopSequence, "\"stop_sequence\"");
        round_trip(&StopReason::ToolUse, "\"tool_use\"");
        round_trip(&StopReason::PauseTurn, "\"pause_turn\"");
        round_trip(&StopReason::Refusal, "\"refusal\"");
    }

    #[test]
    fn stop_reason_unknown_falls_back_to_other() {
        let parsed: StopReason = serde_json::from_str("\"some_new_reason\"").expect("deserialize");
        assert_eq!(parsed, StopReason::Other);
        // The unknown text is not preserved: `Other` always serializes as "other".
        round_trip(&StopReason::Other, "\"other\"");
    }

    #[test]
    fn service_tier_unknown_falls_back_to_other() {
        let parsed: ServiceTier = serde_json::from_str("\"enterprise\"").expect("deserialize");
        assert_eq!(parsed, ServiceTier::Other);
        round_trip(&ServiceTier::Standard, "\"standard\"");
        round_trip(&ServiceTier::Priority, "\"priority\"");
        round_trip(&ServiceTier::Batch, "\"batch\"");
        round_trip(&ServiceTier::Other, "\"other\"");
    }

    #[test]
    fn usage_minimal_payload_round_trips() {
        let u = Usage {
            input_tokens: 12,
            output_tokens: 34,
            ..Usage::default()
        };
        round_trip(&u, r#"{"input_tokens":12,"output_tokens":34}"#);
    }

    #[test]
    fn usage_full_payload_round_trips() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: Some(20),
            cache_read_input_tokens: Some(80),
            cache_creation: Some(CacheCreationBreakdown {
                ephemeral_5m_input_tokens: 10,
                ephemeral_1h_input_tokens: 10,
            }),
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 3,
            }),
            service_tier: Some(ServiceTier::Standard),
            inference_geo: Some("us-east-1".into()),
        };
        // Pin the wire format (key names and order), not just round-trip equality.
        round_trip(
            &u,
            r#"{"input_tokens":100,"output_tokens":50,"cache_creation_input_tokens":20,"cache_read_input_tokens":80,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":10},"server_tool_use":{"web_search_requests":3},"service_tier":"standard","inference_geo":"us-east-1"}"#,
        );
    }

    #[test]
    fn usage_tolerates_unknown_fields() {
        let json = r#"{
            "input_tokens": 5,
            "output_tokens": 7,
            "future_field": "ignored"
        }"#;
        let parsed: Usage = serde_json::from_str(json).expect("deserialize");
        assert_eq!(parsed.input_tokens, 5);
        assert_eq!(parsed.output_tokens, 7);
        assert_eq!(parsed.cache_creation_input_tokens, None);
        assert_eq!(parsed.service_tier, None);
    }

    #[test]
    fn usage_breakdowns_default_missing_counters_to_zero() {
        let cache: CacheCreationBreakdown = serde_json::from_str("{}").expect("deserialize");
        assert_eq!(cache, CacheCreationBreakdown::default());
        let tools: ServerToolUseUsage = serde_json::from_str("{}").expect("deserialize");
        assert_eq!(tools.web_search_requests, 0);
    }
}