1use std::borrow::Cow;
4use std::fmt;
5
6use serde::{Deserialize, Serialize};
7
/// Identifier of a Claude model.
///
/// Well-known models are exposed as associated constants that borrow a
/// `'static` string, so they cost nothing to create or clone. Any other
/// identifier can be wrapped with [`ModelId::custom`], which stores an owned
/// `String`. Equality and hashing look only at the string contents, so a
/// constant and an equivalent custom id compare equal.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ModelId(Cow<'static, str>);

impl ModelId {
    /// The `claude-opus-4-7` model.
    pub const OPUS_4_7: ModelId = Self(Cow::Borrowed("claude-opus-4-7"));
    /// The `claude-sonnet-4-6` model.
    pub const SONNET_4_6: ModelId = Self(Cow::Borrowed("claude-sonnet-4-6"));
    /// The `claude-haiku-4-5-20251001` model.
    pub const HAIKU_4_5: ModelId = Self(Cow::Borrowed("claude-haiku-4-5-20251001"));

    /// Wraps an arbitrary model identifier that has no predefined constant.
    ///
    /// The value is converted into an owned `String` and stored as-is; no
    /// validation of its format is performed.
    pub fn custom(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        Self(Cow::Owned(owned))
    }

    /// Borrows the identifier as a plain string slice.
    pub fn as_str(&self) -> &str {
        let Self(inner) = self;
        inner
    }
}
43
44impl fmt::Display for ModelId {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.write_str(&self.0)
47 }
48}
49
50impl AsRef<str> for ModelId {
51 fn as_ref(&self) -> &str {
52 &self.0
53 }
54}
55
56impl From<&'static str> for ModelId {
57 fn from(s: &'static str) -> Self {
58 Self(Cow::Borrowed(s))
59 }
60}
61
62impl From<String> for ModelId {
63 fn from(s: String) -> Self {
64 Self(Cow::Owned(s))
65 }
66}
67
68impl Serialize for ModelId {
69 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
70 s.serialize_str(&self.0)
71 }
72}
73
74impl<'de> Deserialize<'de> for ModelId {
75 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
76 String::deserialize(d).map(Self::from)
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
82#[serde(rename_all = "lowercase")]
83pub enum Role {
84 User,
86 Assistant,
88}
89
90#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
95#[serde(rename_all = "snake_case")]
96pub enum StopReason {
97 EndTurn,
99 MaxTokens,
101 StopSequence,
103 ToolUse,
105 PauseTurn,
107 Refusal,
109 #[serde(other)]
111 Other,
112}
113
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
119#[serde(rename_all = "snake_case")]
120pub enum ServiceTier {
121 Standard,
123 Priority,
125 Batch,
127 #[serde(other)]
129 Other,
130}
131
132#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
137#[non_exhaustive]
138pub struct Usage {
139 pub input_tokens: u32,
141 pub output_tokens: u32,
143 #[serde(default, skip_serializing_if = "Option::is_none")]
145 pub cache_creation_input_tokens: Option<u32>,
146 #[serde(default, skip_serializing_if = "Option::is_none")]
148 pub cache_read_input_tokens: Option<u32>,
149 #[serde(default, skip_serializing_if = "Option::is_none")]
151 pub cache_creation: Option<CacheCreationBreakdown>,
152 #[serde(default, skip_serializing_if = "Option::is_none")]
154 pub server_tool_use: Option<ServerToolUseUsage>,
155 #[serde(default, skip_serializing_if = "Option::is_none")]
157 pub service_tier: Option<ServiceTier>,
158 #[serde(default, skip_serializing_if = "Option::is_none")]
162 pub inference_geo: Option<String>,
163}
164
165#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
167#[non_exhaustive]
168pub struct CacheCreationBreakdown {
169 #[serde(default)]
171 pub ephemeral_5m_input_tokens: u32,
172 #[serde(default)]
174 pub ephemeral_1h_input_tokens: u32,
175}
176
177#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
179#[non_exhaustive]
180pub struct ServerToolUseUsage {
181 #[serde(default)]
183 pub web_search_requests: u32,
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde::de::DeserializeOwned;

    /// Serializes `value`, checks the JSON text equals `expected_json`, then
    /// parses that text back and checks it equals `value`.
    fn round_trip<T>(value: &T, expected_json: &str)
    where
        T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let json = serde_json::to_string(value).expect("serialize");
        assert_eq!(json, expected_json, "serialized form mismatch");
        let parsed: T = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(&parsed, value, "round-trip mismatch");
    }

    #[test]
    fn model_id_serializes_as_string() {
        round_trip(&ModelId::OPUS_4_7, "\"claude-opus-4-7\"");
        round_trip(&ModelId::SONNET_4_6, "\"claude-sonnet-4-6\"");
        round_trip(&ModelId::HAIKU_4_5, "\"claude-haiku-4-5-20251001\"");
        round_trip(
            &ModelId::custom("claude-future-foo"),
            "\"claude-future-foo\"",
        );
    }

    #[test]
    fn model_id_const_equals_custom() {
        assert_eq!(ModelId::OPUS_4_7, ModelId::custom("claude-opus-4-7"));
    }

    #[test]
    fn model_id_from_conversions_match_constants() {
        // Borrowed and owned storage must compare equal by contents.
        assert_eq!(ModelId::from("claude-opus-4-7"), ModelId::OPUS_4_7);
        assert_eq!(
            ModelId::from(String::from("claude-sonnet-4-6")),
            ModelId::SONNET_4_6
        );
        assert_eq!(
            ModelId::custom("claude-future-foo").as_str(),
            "claude-future-foo"
        );
    }

    #[test]
    fn model_id_display_and_as_ref() {
        assert_eq!(ModelId::SONNET_4_6.to_string(), "claude-sonnet-4-6");
        assert_eq!(
            <ModelId as AsRef<str>>::as_ref(&ModelId::SONNET_4_6),
            "claude-sonnet-4-6"
        );
    }

    #[test]
    fn role_serializes_lowercase() {
        round_trip(&Role::User, "\"user\"");
        round_trip(&Role::Assistant, "\"assistant\"");
    }

    #[test]
    fn stop_reason_round_trips_known_variants() {
        round_trip(&StopReason::EndTurn, "\"end_turn\"");
        round_trip(&StopReason::MaxTokens, "\"max_tokens\"");
        round_trip(&StopReason::StopSequence, "\"stop_sequence\"");
        round_trip(&StopReason::ToolUse, "\"tool_use\"");
        round_trip(&StopReason::PauseTurn, "\"pause_turn\"");
        round_trip(&StopReason::Refusal, "\"refusal\"");
    }

    #[test]
    fn stop_reason_unknown_falls_back_to_other() {
        let parsed: StopReason = serde_json::from_str("\"some_new_reason\"").expect("deserialize");
        assert_eq!(parsed, StopReason::Other);
    }

    #[test]
    fn fallback_variants_round_trip_as_other() {
        // `Other` serializes under its snake_case name; reading that name back
        // must land on `Other` again rather than erroring.
        round_trip(&StopReason::Other, "\"other\"");
        round_trip(&ServiceTier::Other, "\"other\"");
    }

    #[test]
    fn service_tier_unknown_falls_back_to_other() {
        let parsed: ServiceTier = serde_json::from_str("\"enterprise\"").expect("deserialize");
        assert_eq!(parsed, ServiceTier::Other);
        round_trip(&ServiceTier::Standard, "\"standard\"");
        round_trip(&ServiceTier::Priority, "\"priority\"");
        round_trip(&ServiceTier::Batch, "\"batch\"");
    }

    #[test]
    fn usage_minimal_payload_round_trips() {
        let u = Usage {
            input_tokens: 12,
            output_tokens: 34,
            ..Usage::default()
        };
        round_trip(&u, r#"{"input_tokens":12,"output_tokens":34}"#);
    }

    #[test]
    fn usage_full_payload_round_trips() {
        let u = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: Some(20),
            cache_read_input_tokens: Some(80),
            cache_creation: Some(CacheCreationBreakdown {
                ephemeral_5m_input_tokens: 10,
                ephemeral_1h_input_tokens: 10,
            }),
            server_tool_use: Some(ServerToolUseUsage {
                web_search_requests: 3,
            }),
            service_tier: Some(ServiceTier::Standard),
            inference_geo: Some("us-east-1".into()),
        };
        // Pin the exact wire form so that field renames or reordering are caught.
        round_trip(
            &u,
            concat!(
                r#"{"input_tokens":100,"output_tokens":50,"#,
                r#""cache_creation_input_tokens":20,"cache_read_input_tokens":80,"#,
                r#""cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":10},"#,
                r#""server_tool_use":{"web_search_requests":3},"#,
                r#""service_tier":"standard","inference_geo":"us-east-1"}"#,
            ),
        );
    }

    #[test]
    fn usage_treats_null_optionals_as_absent() {
        let json = r#"{
            "input_tokens": 1,
            "output_tokens": 2,
            "cache_creation_input_tokens": null,
            "cache_read_input_tokens": null,
            "cache_creation": null,
            "server_tool_use": null,
            "service_tier": null,
            "inference_geo": null
        }"#;
        let parsed: Usage = serde_json::from_str(json).expect("deserialize");
        assert_eq!(
            parsed,
            Usage {
                input_tokens: 1,
                output_tokens: 2,
                ..Usage::default()
            }
        );
    }

    #[test]
    fn usage_tolerates_unknown_fields() {
        let json = r#"{
            "input_tokens": 5,
            "output_tokens": 7,
            "future_field": "ignored"
        }"#;
        let parsed: Usage = serde_json::from_str(json).expect("deserialize");
        assert_eq!(parsed.input_tokens, 5);
        assert_eq!(parsed.output_tokens, 7);
    }

    #[test]
    fn nested_breakdowns_default_missing_counters_to_zero() {
        let cache: CacheCreationBreakdown = serde_json::from_str("{}").expect("deserialize");
        assert_eq!(cache, CacheCreationBreakdown::default());
        assert_eq!(cache.ephemeral_5m_input_tokens, 0);
        assert_eq!(cache.ephemeral_1h_input_tokens, 0);

        let tools: ServerToolUseUsage = serde_json::from_str("{}").expect("deserialize");
        assert_eq!(tools.web_search_requests, 0);
    }
}