use std::borrow::Cow;
use std::fmt;
use serde::{Deserialize, Serialize};
/// String identifier of a model.
///
/// Wraps a `Cow<'static, str>`: the associated constants borrow a static
/// literal, while identifiers built from a `String` own their text.
/// `PartialEq`, `Eq` and `Hash` are derived from the wrapped `Cow`, so a
/// borrowed and an owned id holding the same text compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelId(Cow<'static, str>);
impl ModelId {
    /// Identifier `"claude-opus-4-7"`, borrowed from a static literal.
    pub const OPUS_4_7: Self = Self(Cow::Borrowed("claude-opus-4-7"));
    /// Identifier `"claude-sonnet-4-6"`, borrowed from a static literal.
    pub const SONNET_4_6: Self = Self(Cow::Borrowed("claude-sonnet-4-6"));
    /// Identifier `"claude-haiku-4-5-20251001"`, borrowed from a static literal.
    pub const HAIKU_4_5: Self = Self(Cow::Borrowed("claude-haiku-4-5-20251001"));

    /// Builds an identifier from any value convertible into a `String`.
    ///
    /// The text is always stored as an owned `String`, so this is the
    /// constructor to use for identifiers that are only known at runtime.
    pub fn custom(s: impl Into<String>) -> Self {
        let owned: String = s.into();
        ModelId(Cow::Owned(owned))
    }

    /// Returns the identifier text as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_ref()
    }
}
impl fmt::Display for ModelId {
    /// Writes the identifier text.
    ///
    /// Uses `Formatter::pad` rather than `write_str` so that width, fill,
    /// alignment and precision in the format spec (e.g. `{:<30}`) are applied,
    /// the same way they are for a plain `&str`. With a bare `{}` the output
    /// is exactly the identifier text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}
impl AsRef<str> for ModelId {
fn as_ref(&self) -> &str {
&self.0
}
}
impl From<&'static str> for ModelId {
fn from(s: &'static str) -> Self {
Self(Cow::Borrowed(s))
}
}
impl From<String> for ModelId {
fn from(s: String) -> Self {
Self(Cow::Owned(s))
}
}
impl Serialize for ModelId {
    /// Serializes the identifier as a plain string value.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text: &str = &self.0;
        serializer.serialize_str(text)
    }
}
impl<'de> Deserialize<'de> for ModelId {
    /// Deserializes a string value into an owned identifier.
    ///
    /// The text is read as a `String` and converted through `From<String>`;
    /// any error from the underlying deserializer is returned unchanged.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let owned = String::deserialize(deserializer)?;
        Ok(Self::from(owned))
    }
}
/// Role tag with two variants, serialized as the lowercase strings
/// `"user"` and `"assistant"`.
///
/// There is no catch-all variant, so any other string is rejected during
/// deserialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Wire value `"user"`.
User,
/// Wire value `"assistant"`.
Assistant,
}
/// Stop-reason code, serialized as a snake_case string.
///
/// Strings that match no named variant deserialize to [`StopReason::Other`]
/// instead of failing. The original text is not kept: serializing `Other`
/// writes the literal `"other"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
/// Wire value `"end_turn"`.
EndTurn,
/// Wire value `"max_tokens"`.
MaxTokens,
/// Wire value `"stop_sequence"`.
StopSequence,
/// Wire value `"tool_use"`.
ToolUse,
/// Wire value `"pause_turn"`.
PauseTurn,
/// Wire value `"refusal"`.
Refusal,
/// Fallback chosen when deserializing any unrecognized string.
#[serde(other)]
Other,
}
/// Service-tier code, serialized as a snake_case string.
///
/// Strings that match no named variant deserialize to [`ServiceTier::Other`]
/// instead of failing. The original text is not kept: serializing `Other`
/// writes the literal `"other"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
/// Wire value `"standard"`.
Standard,
/// Wire value `"priority"`.
Priority,
/// Wire value `"batch"`.
Batch,
/// Fallback chosen when deserializing any unrecognized string.
#[serde(other)]
Other,
}
/// Token-usage counters and related metadata.
///
/// When deserializing, `input_tokens` and `output_tokens` must be present;
/// every `Option` field falls back to `None` when absent and is left out of
/// the serialized output when it is `None`. Unknown fields in the input are
/// ignored. `#[non_exhaustive]` stops code outside this crate from building
/// the struct with a literal, so fields can be added later without breaking
/// such code; `Default` is derived as a starting point.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Usage {
/// Input token count (required on input).
pub input_tokens: u32,
/// Output token count (required on input).
pub output_tokens: u32,
/// Optional cache-creation input token count.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_creation_input_tokens: Option<u32>,
/// Optional cache-read input token count.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_read_input_tokens: Option<u32>,
/// Optional breakdown of cache-creation tokens; see [`CacheCreationBreakdown`].
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_creation: Option<CacheCreationBreakdown>,
/// Optional server-tool counters; see [`ServerToolUseUsage`].
#[serde(default, skip_serializing_if = "Option::is_none")]
pub server_tool_use: Option<ServerToolUseUsage>,
/// Optional service tier; unrecognized values become [`ServiceTier::Other`].
#[serde(default, skip_serializing_if = "Option::is_none")]
pub service_tier: Option<ServiceTier>,
/// Optional free-form string, stored as-is.
/// NOTE(review): exact meaning and value set are not visible here — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub inference_geo: Option<String>,
}
/// Cache-creation token counts split into two named buckets.
///
/// Both fields default to `0` when missing from the input and are always
/// written on serialization. `#[non_exhaustive]` leaves room for more fields.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct CacheCreationBreakdown {
/// Token count for the `ephemeral_5m` bucket; `0` if absent.
#[serde(default)]
pub ephemeral_5m_input_tokens: u32,
/// Token count for the `ephemeral_1h` bucket; `0` if absent.
#[serde(default)]
pub ephemeral_1h_input_tokens: u32,
}
/// Counters for server-side tool usage.
///
/// `#[non_exhaustive]` leaves room for more counters; unknown fields in the
/// input are ignored.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ServerToolUseUsage {
/// Number of web search requests; `0` if absent from the input.
#[serde(default)]
pub web_search_requests: u32,
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde::de::DeserializeOwned;

    /// Serializes `value`, checks the JSON text against `expected_json`, then
    /// parses that text back and checks the result equals `value`.
    fn round_trip<T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug>(
        value: &T,
        expected_json: &str,
    ) {
        let encoded = serde_json::to_string(value).expect("serialize");
        assert_eq!(encoded, expected_json, "serialized form mismatch");
        let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(&decoded, value, "round-trip mismatch");
    }

    #[test]
    fn model_id_serializes_as_string() {
        let cases = [
            (ModelId::OPUS_4_7, "\"claude-opus-4-7\""),
            (ModelId::SONNET_4_6, "\"claude-sonnet-4-6\""),
            (ModelId::HAIKU_4_5, "\"claude-haiku-4-5-20251001\""),
            (
                ModelId::custom("claude-future-foo"),
                "\"claude-future-foo\"",
            ),
        ];
        for (id, wire) in cases {
            round_trip(&id, wire);
        }
    }

    #[test]
    fn model_id_const_equals_custom() {
        let from_const = ModelId::OPUS_4_7;
        let from_custom = ModelId::custom("claude-opus-4-7");
        assert_eq!(from_const, from_custom);
    }

    #[test]
    fn model_id_display_and_as_ref() {
        let sonnet = ModelId::SONNET_4_6;
        assert_eq!(sonnet.to_string(), "claude-sonnet-4-6");
        let borrowed: &str = sonnet.as_ref();
        assert_eq!(borrowed, "claude-sonnet-4-6");
    }

    #[test]
    fn role_serializes_lowercase() {
        for (role, wire) in [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
        ] {
            round_trip(&role, wire);
        }
    }

    #[test]
    fn stop_reason_round_trips_known_variants() {
        let known = [
            (StopReason::EndTurn, "\"end_turn\""),
            (StopReason::MaxTokens, "\"max_tokens\""),
            (StopReason::StopSequence, "\"stop_sequence\""),
            (StopReason::ToolUse, "\"tool_use\""),
            (StopReason::PauseTurn, "\"pause_turn\""),
            (StopReason::Refusal, "\"refusal\""),
        ];
        for (reason, wire) in known {
            round_trip(&reason, wire);
        }
    }

    #[test]
    fn stop_reason_unknown_falls_back_to_other() {
        let decoded =
            serde_json::from_str::<StopReason>("\"some_new_reason\"").expect("deserialize");
        assert_eq!(decoded, StopReason::Other);
    }

    #[test]
    fn service_tier_unknown_falls_back_to_other() {
        // The unknown value is checked first, then the named tiers.
        let decoded = serde_json::from_str::<ServiceTier>("\"enterprise\"").expect("deserialize");
        assert_eq!(decoded, ServiceTier::Other);

        let known = [
            (ServiceTier::Standard, "\"standard\""),
            (ServiceTier::Priority, "\"priority\""),
            (ServiceTier::Batch, "\"batch\""),
        ];
        for (tier, wire) in known {
            round_trip(&tier, wire);
        }
    }

    #[test]
    fn usage_minimal_payload_round_trips() {
        let usage = Usage {
            input_tokens: 12,
            output_tokens: 34,
            ..Usage::default()
        };
        round_trip(&usage, r#"{"input_tokens":12,"output_tokens":34}"#);
    }

    #[test]
    fn usage_full_payload_round_trips() {
        let breakdown = CacheCreationBreakdown {
            ephemeral_5m_input_tokens: 10,
            ephemeral_1h_input_tokens: 10,
        };
        let tools = ServerToolUseUsage {
            web_search_requests: 3,
        };
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: Some(20),
            cache_read_input_tokens: Some(80),
            cache_creation: Some(breakdown),
            server_tool_use: Some(tools),
            service_tier: Some(ServiceTier::Standard),
            inference_geo: Some(String::from("us-east-1")),
        };
        let encoded = serde_json::to_string(&usage).expect("serialize");
        let decoded: Usage = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, usage);
    }

    #[test]
    fn usage_tolerates_unknown_fields() {
        // The raw string is kept exactly as written, including its line layout.
        let payload = r#"{
"input_tokens": 5,
"output_tokens": 7,
"future_field": "ignored"
}"#;
        let decoded: Usage = serde_json::from_str(payload).expect("deserialize");
        let Usage {
            input_tokens,
            output_tokens,
            ..
        } = decoded;
        assert_eq!(input_tokens, 5);
        assert_eq!(output_tokens, 7);
    }
}