1use std::num::NonZeroUsize;
2
3#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
4#[serde(deny_unknown_fields)]
5pub struct ModelSpec {
6 pub id: String,
7 #[serde(default, skip_serializing_if = "Option::is_none")]
8 pub variant: Option<String>,
9 pub limits: ModelLimits,
10}
11
12impl ModelSpec {
13 pub fn new(id: impl Into<String>, context_window_tokens: NonZeroUsize) -> Self {
14 Self {
15 id: id.into(),
16 variant: None,
17 limits: ModelLimits {
18 context_window_tokens,
19 output_token_capacity: None,
20 },
21 }
22 }
23
24 pub fn with_limits(
25 id: impl Into<String>,
26 variant: Option<String>,
27 limits: ModelLimits,
28 ) -> Self {
29 Self {
30 id: id.into(),
31 variant,
32 limits,
33 }
34 }
35
36 pub fn with_variant(mut self, variant: Option<String>) -> Self {
37 self.variant = variant;
38 self
39 }
40
41 pub fn from_token_limits(
46 id: impl Into<String>,
47 variant: Option<String>,
48 context_window_tokens: usize,
49 output_token_capacity: Option<usize>,
50 ) -> Result<Self, String> {
51 Ok(Self::with_limits(
52 id,
53 variant,
54 ModelLimits::from_token_limits(context_window_tokens, output_token_capacity)?,
55 ))
56 }
57
58 pub fn context_window_tokens(&self) -> usize {
59 self.limits.context_window_tokens.get()
60 }
61}
62
63impl Default for ModelSpec {
64 fn default() -> Self {
65 Self::new(
66 String::new(),
67 NonZeroUsize::new(1).expect("one is non-zero"),
68 )
69 }
70}
71
72#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
73#[serde(deny_unknown_fields)]
74pub struct ModelLimits {
75 pub context_window_tokens: NonZeroUsize,
80 #[serde(default, skip_serializing_if = "Option::is_none")]
81 pub output_token_capacity: Option<NonZeroUsize>,
82}
83
84impl ModelLimits {
85 pub fn from_token_limits(
86 context_window_tokens: usize,
87 output_token_capacity: Option<usize>,
88 ) -> Result<Self, String> {
89 Ok(Self {
90 context_window_tokens: nonzero_token_limit(
91 "context_window_tokens",
92 context_window_tokens,
93 )?,
94 output_token_capacity: optional_nonzero_token_limit(
95 "output_token_capacity",
96 output_token_capacity,
97 )?,
98 })
99 }
100}
101
102impl Default for ModelLimits {
103 fn default() -> Self {
104 Self {
105 context_window_tokens: NonZeroUsize::new(1).expect("one is non-zero"),
106 output_token_capacity: None,
107 }
108 }
109}
110
/// Converts `value` into a `NonZeroUsize`.
///
/// Returns `Err("{name} must be greater than zero")` when `value` is zero,
/// where `name` identifies the offending field for the caller.
fn nonzero_token_limit(name: &str, value: usize) -> Result<NonZeroUsize, String> {
    match NonZeroUsize::new(value) {
        Some(limit) => Ok(limit),
        None => Err(format!("{name} must be greater than zero")),
    }
}
114
115fn optional_nonzero_token_limit(
116 name: &str,
117 value: Option<usize>,
118) -> Result<Option<NonZeroUsize>, String> {
119 value
120 .map(|value| nonzero_token_limit(name, value))
121 .transpose()
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// `with_limits` and `with_variant` keep the id and limits intact while
    /// the variant is set, replaced, and then cleared.
    #[test]
    fn model_spec_constructors_preserve_identity_variant_and_limits() {
        let expected_limits =
            ModelLimits::from_token_limits(8_192, Some(1_024)).expect("valid limits");

        let original = ModelSpec::with_limits(
            "provider/model",
            Some("fast".to_string()),
            expected_limits.clone(),
        );
        assert_eq!(original.id, "provider/model");
        assert_eq!(original.variant.as_deref(), Some("fast"));
        assert_eq!(original.limits, expected_limits);

        // Replacing the variant must not touch the id or the limits.
        let retagged = original
            .clone()
            .with_variant(Some("accurate".to_string()));
        assert_eq!(retagged.id, "provider/model");
        assert_eq!(retagged.variant.as_deref(), Some("accurate"));
        assert_eq!(retagged.limits, original.limits);

        // Clearing the variant leaves the rest of the spec as it was.
        let untagged = retagged.with_variant(None);
        assert_eq!(untagged.id, "provider/model");
        assert_eq!(untagged.variant, None);
        assert_eq!(untagged.context_window_tokens(), 8_192);
    }

    /// The `usize`-based constructors accept positive limits, keep the output
    /// capacity, and name the offending field when a limit is zero.
    #[test]
    fn model_token_limit_constructors_reject_zero_and_preserve_output_cap() {
        let built = ModelSpec::from_token_limits(
            "provider/model",
            Some("variant-a".to_string()),
            200_000,
            Some(4_096),
        )
        .expect("valid token limits");

        assert_eq!(built.id, "provider/model");
        assert_eq!(built.variant.as_deref(), Some("variant-a"));
        assert_eq!(built.context_window_tokens(), 200_000);
        let output_cap = built.limits.output_token_capacity.map(NonZeroUsize::get);
        assert_eq!(output_cap, Some(4_096));

        // A zero context window is rejected at the `ModelSpec` level.
        let context_error = ModelSpec::from_token_limits("bad-context", None, 0, Some(1))
            .expect_err("zero context");
        assert!(
            context_error.contains("context_window_tokens"),
            "context error should name the invalid field: {context_error}"
        );

        // A zero output capacity is rejected at the `ModelLimits` level.
        let output_error =
            ModelLimits::from_token_limits(1, Some(0)).expect_err("zero output cap");
        assert!(
            output_error.contains("output_token_capacity"),
            "output error should name the invalid field: {output_error}"
        );
    }
}