1use bitflags::bitflags;
7use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
8use serde::{Deserialize, Serialize};
9
10bitflags! {
11 #[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
12 pub struct ModelType: u16 {
13 const CHAT = 1 << 0;
15 const COMPLETIONS = 1 << 1;
17 const RESPONSES = 1 << 2;
19 const EMBEDDINGS = 1 << 3;
21 const RERANK = 1 << 4;
23 const GENERATE = 1 << 5;
25 const VISION = 1 << 6;
27 const TOOLS = 1 << 7;
29 const REASONING = 1 << 8;
31 const IMAGE_GEN = 1 << 9;
33 const AUDIO = 1 << 10;
35 const MODERATION = 1 << 11;
37
38 const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
40 | Self::RESPONSES.bits() | Self::TOOLS.bits();
41
42 const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();
44
45 const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();
47
48 const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();
50
51 const EMBED_MODEL = Self::EMBEDDINGS.bits();
53
54 const RERANK_MODEL = Self::RERANK.bits();
56
57 const IMAGE_MODEL = Self::IMAGE_GEN.bits();
59
60 const AUDIO_MODEL = Self::AUDIO.bits();
62
63 const MODERATION_MODEL = Self::MODERATION.bits();
65 }
66}
67
68const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
70 (ModelType::CHAT, "chat"),
71 (ModelType::COMPLETIONS, "completions"),
72 (ModelType::RESPONSES, "responses"),
73 (ModelType::EMBEDDINGS, "embeddings"),
74 (ModelType::RERANK, "rerank"),
75 (ModelType::GENERATE, "generate"),
76 (ModelType::VISION, "vision"),
77 (ModelType::TOOLS, "tools"),
78 (ModelType::REASONING, "reasoning"),
79 (ModelType::IMAGE_GEN, "image_gen"),
80 (ModelType::AUDIO, "audio"),
81 (ModelType::MODERATION, "moderation"),
82];
83
84impl ModelType {
85 #[inline]
87 pub fn supports_chat(self) -> bool {
88 self.contains(Self::CHAT)
89 }
90
91 #[inline]
93 pub fn supports_completions(self) -> bool {
94 self.contains(Self::COMPLETIONS)
95 }
96
97 #[inline]
99 pub fn supports_responses(self) -> bool {
100 self.contains(Self::RESPONSES)
101 }
102
103 #[inline]
105 pub fn supports_embeddings(self) -> bool {
106 self.contains(Self::EMBEDDINGS)
107 }
108
109 #[inline]
111 pub fn supports_rerank(self) -> bool {
112 self.contains(Self::RERANK)
113 }
114
115 #[inline]
117 pub fn supports_generate(self) -> bool {
118 self.contains(Self::GENERATE)
119 }
120
121 #[inline]
123 pub fn supports_vision(self) -> bool {
124 self.contains(Self::VISION)
125 }
126
127 #[inline]
129 pub fn supports_tools(self) -> bool {
130 self.contains(Self::TOOLS)
131 }
132
133 #[inline]
135 pub fn supports_reasoning(self) -> bool {
136 self.contains(Self::REASONING)
137 }
138
139 #[inline]
141 pub fn supports_image_gen(self) -> bool {
142 self.contains(Self::IMAGE_GEN)
143 }
144
145 #[inline]
147 pub fn supports_audio(self) -> bool {
148 self.contains(Self::AUDIO)
149 }
150
151 #[inline]
153 pub fn supports_moderation(self) -> bool {
154 self.contains(Self::MODERATION)
155 }
156
157 pub fn supports_endpoint(self, endpoint: Endpoint) -> bool {
159 match endpoint {
160 Endpoint::Chat => self.supports_chat(),
161 Endpoint::Completions => self.supports_completions(),
162 Endpoint::Responses => self.supports_responses(),
163 Endpoint::Embeddings => self.supports_embeddings(),
164 Endpoint::Rerank => self.supports_rerank(),
165 Endpoint::Generate => self.supports_generate(),
166 Endpoint::Models => true,
167 }
168 }
169
170 pub fn as_capability_names(self) -> Vec<&'static str> {
172 let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
173 for &(flag, name) in CAPABILITY_NAMES {
174 if self.contains(flag) {
175 result.push(name);
176 }
177 }
178 result
179 }
180
181 #[inline]
183 pub fn is_llm(self) -> bool {
184 self.supports_chat()
185 }
186
187 #[inline]
189 pub fn is_embedding_model(self) -> bool {
190 self.supports_embeddings() && !self.supports_chat()
191 }
192
193 #[inline]
195 pub fn is_reranker(self) -> bool {
196 self.supports_rerank() && !self.supports_chat()
197 }
198
199 #[inline]
201 pub fn is_image_model(self) -> bool {
202 self.supports_image_gen() && !self.supports_chat()
203 }
204
205 #[inline]
207 pub fn is_audio_model(self) -> bool {
208 self.supports_audio() && !self.supports_chat()
209 }
210
211 #[inline]
213 pub fn is_moderation_model(self) -> bool {
214 self.supports_moderation() && !self.supports_chat()
215 }
216}
217
218impl std::fmt::Display for ModelType {
219 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220 let names = self.as_capability_names();
221 if names.is_empty() {
222 write!(f, "none")
223 } else {
224 write!(f, "{}", names.join(","))
225 }
226 }
227}
228
229impl Serialize for ModelType {
230 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
231 where
232 S: serde::Serializer,
233 {
234 use serde::ser::SerializeSeq;
235 let names = self.as_capability_names();
236 let mut seq = serializer.serialize_seq(Some(names.len()))?;
237 for name in names {
238 seq.serialize_element(name)?;
239 }
240 seq.end()
241 }
242}
243
244impl<'de> Deserialize<'de> for ModelType {
245 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
246 where
247 D: serde::Deserializer<'de>,
248 {
249 use serde::de;
250
251 struct ModelTypeVisitor;
252
253 impl<'de> de::Visitor<'de> for ModelTypeVisitor {
254 type Value = ModelType;
255
256 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 f.write_str("an array of capability names or a u16 bitfield")
258 }
259
260 fn visit_u64<E: de::Error>(self, v: u64) -> Result<ModelType, E> {
262 let bits = u16::try_from(v)
263 .map_err(|_| E::custom(format!("ModelType bits out of u16 range: {v}")))?;
264 ModelType::from_bits(bits)
265 .ok_or_else(|| E::custom(format!("invalid ModelType bits: {bits}")))
266 }
267
268 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<ModelType, A::Error> {
270 let mut model_type = ModelType::empty();
271 while let Some(name) = seq.next_element::<String>()? {
272 let flag = CAPABILITY_NAMES
273 .iter()
274 .find(|(_, n)| *n == name.as_str())
275 .map(|(f, _)| *f)
276 .ok_or_else(|| {
277 de::Error::custom(format!("unknown ModelType capability: {name}"))
278 })?;
279 model_type |= flag;
280 }
281 Ok(model_type)
282 }
283 }
284
285 deserializer.deserialize_any(ModelTypeVisitor)
286 }
287}
288
289impl JsonSchema for ModelType {
291 fn schema_name() -> String {
292 "ModelType".to_string()
293 }
294
295 fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
296 use schemars::schema::*;
297 let items = SchemaObject {
298 instance_type: Some(InstanceType::String.into()),
299 enum_values: Some(vec![
300 "chat".into(),
301 "completions".into(),
302 "responses".into(),
303 "embeddings".into(),
304 "rerank".into(),
305 "generate".into(),
306 "vision".into(),
307 "tools".into(),
308 "reasoning".into(),
309 "image_gen".into(),
310 "audio".into(),
311 "moderation".into(),
312 ]),
313 ..Default::default()
314 };
315 SchemaObject {
316 instance_type: Some(InstanceType::Array.into()),
317 array: Some(Box::new(ArrayValidation {
318 items: Some(SingleOrVec::Single(Box::new(items.into()))),
319 ..Default::default()
320 })),
321 metadata: Some(Box::new(Metadata {
322 description: Some(
323 "Bitflag capabilities serialized as an array of capability names".to_string(),
324 ),
325 ..Default::default()
326 })),
327 ..Default::default()
328 }
329 .into()
330 }
331}
332
333#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
335#[serde(rename_all = "lowercase")]
336pub enum Endpoint {
337 Chat,
339 Completions,
341 Responses,
343 Embeddings,
345 Rerank,
347 Generate,
349 Models,
351}
352
353impl Endpoint {
354 pub fn path(self) -> &'static str {
356 match self {
357 Endpoint::Chat => "/v1/chat/completions",
358 Endpoint::Completions => "/v1/completions",
359 Endpoint::Responses => "/v1/responses",
360 Endpoint::Embeddings => "/v1/embeddings",
361 Endpoint::Rerank => "/v1/rerank",
362 Endpoint::Generate => "/generate",
363 Endpoint::Models => "/v1/models",
364 }
365 }
366
367 pub fn from_path(path: &str) -> Option<Self> {
369 let path = path.trim_end_matches('/');
370 match path {
371 "/v1/chat/completions" => Some(Endpoint::Chat),
372 "/v1/completions" => Some(Endpoint::Completions),
373 "/v1/responses" => Some(Endpoint::Responses),
374 "/v1/embeddings" => Some(Endpoint::Embeddings),
375 "/v1/rerank" => Some(Endpoint::Rerank),
376 "/generate" => Some(Endpoint::Generate),
377 "/v1/models" => Some(Endpoint::Models),
378 _ => None,
379 }
380 }
381
382 pub fn required_capability(self) -> Option<ModelType> {
384 match self {
385 Endpoint::Chat => Some(ModelType::CHAT),
386 Endpoint::Completions => Some(ModelType::COMPLETIONS),
387 Endpoint::Responses => Some(ModelType::RESPONSES),
388 Endpoint::Embeddings => Some(ModelType::EMBEDDINGS),
389 Endpoint::Rerank => Some(ModelType::RERANK),
390 Endpoint::Generate => Some(ModelType::GENERATE),
391 Endpoint::Models => None,
392 }
393 }
394}
395
396impl std::fmt::Display for Endpoint {
397 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398 match self {
399 Endpoint::Chat => write!(f, "chat"),
400 Endpoint::Completions => write!(f, "completions"),
401 Endpoint::Responses => write!(f, "responses"),
402 Endpoint::Embeddings => write!(f, "embeddings"),
403 Endpoint::Rerank => write!(f, "rerank"),
404 Endpoint::Generate => write!(f, "generate"),
405 Endpoint::Models => write!(f, "models"),
406 }
407 }
408}