// openai_protocol/model_type.rs

1//! Model type definitions using bitflags for endpoint support.
2//!
3//! Defines [`ModelType`] using bitflags to represent which endpoints a model
4//! can support, and [`Endpoint`] for routing decisions.
5
6use bitflags::bitflags;
7use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
8use serde::{Deserialize, Serialize};
9
10bitflags! {
11    #[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
12    pub struct ModelType: u16 {
13        /// OpenAI Chat Completions API (/v1/chat/completions)
14        const CHAT        = 1 << 0;
15        /// OpenAI Completions API - legacy (/v1/completions)
16        const COMPLETIONS = 1 << 1;
17        /// OpenAI Responses API (/v1/responses)
18        const RESPONSES   = 1 << 2;
19        /// Embeddings API (/v1/embeddings)
20        const EMBEDDINGS  = 1 << 3;
21        /// Rerank API (/v1/rerank)
22        const RERANK      = 1 << 4;
23        /// SGLang Generate API (/generate)
24        const GENERATE    = 1 << 5;
25        /// Vision/multimodal support (images in input)
26        const VISION      = 1 << 6;
27        /// Tool/function calling support
28        const TOOLS       = 1 << 7;
29        /// Reasoning/thinking support (e.g., o1, DeepSeek-R1)
30        const REASONING   = 1 << 8;
31        /// Image generation (DALL-E, Sora, gpt-image)
32        const IMAGE_GEN   = 1 << 9;
33        /// Audio models (TTS, Whisper, realtime, transcribe)
34        const AUDIO       = 1 << 10;
35        /// Content moderation models
36        const MODERATION  = 1 << 11;
37
38        /// Standard LLM: chat + completions + responses + tools
39        const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
40                  | Self::RESPONSES.bits() | Self::TOOLS.bits();
41
42        /// Vision-capable LLM: LLM + vision
43        const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();
44
45        /// Reasoning LLM: LLM + reasoning (e.g., o1, o3, DeepSeek-R1)
46        const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();
47
48        /// Full-featured LLM: all text generation capabilities
49        const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();
50
51        /// Embedding model only
52        const EMBED_MODEL = Self::EMBEDDINGS.bits();
53
54        /// Reranker model only
55        const RERANK_MODEL = Self::RERANK.bits();
56
57        /// Image generation model only (DALL-E, Sora, gpt-image)
58        const IMAGE_MODEL = Self::IMAGE_GEN.bits();
59
60        /// Audio model only (TTS, Whisper, realtime)
61        const AUDIO_MODEL = Self::AUDIO.bits();
62
63        /// Content moderation model only
64        const MODERATION_MODEL = Self::MODERATION.bits();
65    }
66}
67
/// Mapping of individual capability flags to their names.
///
/// Order here determines the order of names produced by
/// `ModelType::as_capability_names` (and therefore the order of entries
/// in the serialized form).
///
/// NOTE(review): must stay in sync with the single-bit flags declared in
/// the `bitflags!` block above and with the enum strings exposed by the
/// `JsonSchema` impl below — confirm all three lists agree when adding a
/// capability.
const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
    (ModelType::CHAT, "chat"),
    (ModelType::COMPLETIONS, "completions"),
    (ModelType::RESPONSES, "responses"),
    (ModelType::EMBEDDINGS, "embeddings"),
    (ModelType::RERANK, "rerank"),
    (ModelType::GENERATE, "generate"),
    (ModelType::VISION, "vision"),
    (ModelType::TOOLS, "tools"),
    (ModelType::REASONING, "reasoning"),
    (ModelType::IMAGE_GEN, "image_gen"),
    (ModelType::AUDIO, "audio"),
    (ModelType::MODERATION, "moderation"),
];
83
84impl ModelType {
85    /// Check if this model type supports the chat completions endpoint
86    #[inline]
87    pub fn supports_chat(self) -> bool {
88        self.contains(Self::CHAT)
89    }
90
91    /// Check if this model type supports the legacy completions endpoint
92    #[inline]
93    pub fn supports_completions(self) -> bool {
94        self.contains(Self::COMPLETIONS)
95    }
96
97    /// Check if this model type supports the responses endpoint
98    #[inline]
99    pub fn supports_responses(self) -> bool {
100        self.contains(Self::RESPONSES)
101    }
102
103    /// Check if this model type supports the embeddings endpoint
104    #[inline]
105    pub fn supports_embeddings(self) -> bool {
106        self.contains(Self::EMBEDDINGS)
107    }
108
109    /// Check if this model type supports the rerank endpoint
110    #[inline]
111    pub fn supports_rerank(self) -> bool {
112        self.contains(Self::RERANK)
113    }
114
115    /// Check if this model type supports the generate endpoint
116    #[inline]
117    pub fn supports_generate(self) -> bool {
118        self.contains(Self::GENERATE)
119    }
120
121    /// Check if this model type supports vision/multimodal input
122    #[inline]
123    pub fn supports_vision(self) -> bool {
124        self.contains(Self::VISION)
125    }
126
127    /// Check if this model type supports tool/function calling
128    #[inline]
129    pub fn supports_tools(self) -> bool {
130        self.contains(Self::TOOLS)
131    }
132
133    /// Check if this model type supports reasoning/thinking
134    #[inline]
135    pub fn supports_reasoning(self) -> bool {
136        self.contains(Self::REASONING)
137    }
138
139    /// Check if this model type supports image generation
140    #[inline]
141    pub fn supports_image_gen(self) -> bool {
142        self.contains(Self::IMAGE_GEN)
143    }
144
145    /// Check if this model type supports audio (TTS, Whisper, etc.)
146    #[inline]
147    pub fn supports_audio(self) -> bool {
148        self.contains(Self::AUDIO)
149    }
150
151    /// Check if this model type supports content moderation
152    #[inline]
153    pub fn supports_moderation(self) -> bool {
154        self.contains(Self::MODERATION)
155    }
156
157    /// Check if this model type supports a given endpoint
158    pub fn supports_endpoint(self, endpoint: Endpoint) -> bool {
159        match endpoint {
160            Endpoint::Chat => self.supports_chat(),
161            Endpoint::Completions => self.supports_completions(),
162            Endpoint::Responses => self.supports_responses(),
163            Endpoint::Embeddings => self.supports_embeddings(),
164            Endpoint::Rerank => self.supports_rerank(),
165            Endpoint::Generate => self.supports_generate(),
166            Endpoint::Models => true,
167        }
168    }
169
170    /// Convert to a list of supported capability names
171    pub fn as_capability_names(self) -> Vec<&'static str> {
172        let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
173        for &(flag, name) in CAPABILITY_NAMES {
174            if self.contains(flag) {
175                result.push(name);
176            }
177        }
178        result
179    }
180
181    /// Check if this is an LLM (supports at least chat)
182    #[inline]
183    pub fn is_llm(self) -> bool {
184        self.supports_chat()
185    }
186
187    /// Check if this is an embedding model
188    #[inline]
189    pub fn is_embedding_model(self) -> bool {
190        self.supports_embeddings() && !self.supports_chat()
191    }
192
193    /// Check if this is a reranker model
194    #[inline]
195    pub fn is_reranker(self) -> bool {
196        self.supports_rerank() && !self.supports_chat()
197    }
198
199    /// Check if this is an image generation model
200    #[inline]
201    pub fn is_image_model(self) -> bool {
202        self.supports_image_gen() && !self.supports_chat()
203    }
204
205    /// Check if this is an audio model
206    #[inline]
207    pub fn is_audio_model(self) -> bool {
208        self.supports_audio() && !self.supports_chat()
209    }
210
211    /// Check if this is a moderation model
212    #[inline]
213    pub fn is_moderation_model(self) -> bool {
214        self.supports_moderation() && !self.supports_chat()
215    }
216}
217
218impl std::fmt::Display for ModelType {
219    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220        let names = self.as_capability_names();
221        if names.is_empty() {
222            write!(f, "none")
223        } else {
224            write!(f, "{}", names.join(","))
225        }
226    }
227}
228
229impl Serialize for ModelType {
230    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
231    where
232        S: serde::Serializer,
233    {
234        use serde::ser::SerializeSeq;
235        let names = self.as_capability_names();
236        let mut seq = serializer.serialize_seq(Some(names.len()))?;
237        for name in names {
238            seq.serialize_element(name)?;
239        }
240        seq.end()
241    }
242}
243
244impl<'de> Deserialize<'de> for ModelType {
245    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
246    where
247        D: serde::Deserializer<'de>,
248    {
249        use serde::de;
250
251        struct ModelTypeVisitor;
252
253        impl<'de> de::Visitor<'de> for ModelTypeVisitor {
254            type Value = ModelType;
255
256            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257                f.write_str("an array of capability names or a u16 bitfield")
258            }
259
260            // Backward compat: accept numeric u16 bitfield
261            fn visit_u64<E: de::Error>(self, v: u64) -> Result<ModelType, E> {
262                let bits = u16::try_from(v)
263                    .map_err(|_| E::custom(format!("ModelType bits out of u16 range: {v}")))?;
264                ModelType::from_bits(bits)
265                    .ok_or_else(|| E::custom(format!("invalid ModelType bits: {bits}")))
266            }
267
268            // New format: array of capability name strings
269            fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<ModelType, A::Error> {
270                let mut model_type = ModelType::empty();
271                while let Some(name) = seq.next_element::<String>()? {
272                    let flag = CAPABILITY_NAMES
273                        .iter()
274                        .find(|(_, n)| *n == name.as_str())
275                        .map(|(f, _)| *f)
276                        .ok_or_else(|| {
277                            de::Error::custom(format!("unknown ModelType capability: {name}"))
278                        })?;
279                    model_type |= flag;
280                }
281                Ok(model_type)
282            }
283        }
284
285        deserializer.deserialize_any(ModelTypeVisitor)
286    }
287}
288
289/// Manual JsonSchema impl for `ModelType` — serialized as an array of capability name strings.
290impl JsonSchema for ModelType {
291    fn schema_name() -> String {
292        "ModelType".to_string()
293    }
294
295    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
296        use schemars::schema::*;
297        let items = SchemaObject {
298            instance_type: Some(InstanceType::String.into()),
299            enum_values: Some(vec![
300                "chat".into(),
301                "completions".into(),
302                "responses".into(),
303                "embeddings".into(),
304                "rerank".into(),
305                "generate".into(),
306                "vision".into(),
307                "tools".into(),
308                "reasoning".into(),
309                "image_gen".into(),
310                "audio".into(),
311                "moderation".into(),
312            ]),
313            ..Default::default()
314        };
315        SchemaObject {
316            instance_type: Some(InstanceType::Array.into()),
317            array: Some(Box::new(ArrayValidation {
318                items: Some(SingleOrVec::Single(Box::new(items.into()))),
319                ..Default::default()
320            })),
321            metadata: Some(Box::new(Metadata {
322                description: Some(
323                    "Bitflag capabilities serialized as an array of capability names".to_string(),
324                ),
325                ..Default::default()
326            })),
327            ..Default::default()
328        }
329        .into()
330    }
331}
332
/// Endpoint types for routing decisions.
///
/// Serialized with lowercase variant names (e.g. `"chat"`), matching
/// the `Display` impl below.
///
/// NOTE(review): when adding a variant, keep `Endpoint::path`,
/// `Endpoint::from_path`, `Endpoint::required_capability`, and
/// `ModelType::supports_endpoint` in sync.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Endpoint {
    /// Chat completions endpoint (/v1/chat/completions)
    Chat,
    /// Legacy completions endpoint (/v1/completions)
    Completions,
    /// Responses endpoint (/v1/responses)
    Responses,
    /// Embeddings endpoint (/v1/embeddings)
    Embeddings,
    /// Rerank endpoint (/v1/rerank)
    Rerank,
    /// SGLang generate endpoint (/generate)
    Generate,
    /// Models listing endpoint (/v1/models)
    Models,
}
352
353impl Endpoint {
354    /// Get the URL path for this endpoint
355    pub fn path(self) -> &'static str {
356        match self {
357            Endpoint::Chat => "/v1/chat/completions",
358            Endpoint::Completions => "/v1/completions",
359            Endpoint::Responses => "/v1/responses",
360            Endpoint::Embeddings => "/v1/embeddings",
361            Endpoint::Rerank => "/v1/rerank",
362            Endpoint::Generate => "/generate",
363            Endpoint::Models => "/v1/models",
364        }
365    }
366
367    /// Parse an endpoint from a URL path
368    pub fn from_path(path: &str) -> Option<Self> {
369        let path = path.trim_end_matches('/');
370        match path {
371            "/v1/chat/completions" => Some(Endpoint::Chat),
372            "/v1/completions" => Some(Endpoint::Completions),
373            "/v1/responses" => Some(Endpoint::Responses),
374            "/v1/embeddings" => Some(Endpoint::Embeddings),
375            "/v1/rerank" => Some(Endpoint::Rerank),
376            "/generate" => Some(Endpoint::Generate),
377            "/v1/models" => Some(Endpoint::Models),
378            _ => None,
379        }
380    }
381
382    /// Get the required ModelType flag for this endpoint
383    pub fn required_capability(self) -> Option<ModelType> {
384        match self {
385            Endpoint::Chat => Some(ModelType::CHAT),
386            Endpoint::Completions => Some(ModelType::COMPLETIONS),
387            Endpoint::Responses => Some(ModelType::RESPONSES),
388            Endpoint::Embeddings => Some(ModelType::EMBEDDINGS),
389            Endpoint::Rerank => Some(ModelType::RERANK),
390            Endpoint::Generate => Some(ModelType::GENERATE),
391            Endpoint::Models => None,
392        }
393    }
394}
395
396impl std::fmt::Display for Endpoint {
397    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
398        match self {
399            Endpoint::Chat => write!(f, "chat"),
400            Endpoint::Completions => write!(f, "completions"),
401            Endpoint::Responses => write!(f, "responses"),
402            Endpoint::Embeddings => write!(f, "embeddings"),
403            Endpoint::Rerank => write!(f, "rerank"),
404            Endpoint::Generate => write!(f, "generate"),
405            Endpoint::Models => write!(f, "models"),
406        }
407    }
408}