// openai_protocol/model_type.rs
//! Model type definitions using bitflags for endpoint support.
//!
//! Defines [`ModelType`] using bitflags to represent which endpoints a model
//! can support, and [`Endpoint`] for routing decisions.

use bitflags::bitflags;
use serde::{Deserialize, Serialize};
bitflags! {
    /// Capability bitfield describing which endpoints/features a model supports.
    ///
    /// The individual bit positions (CHAT..MODERATION) are part of the legacy
    /// numeric serialization format accepted by `Deserialize` — do not renumber
    /// them. The composite constants below are plain unions of these bits.
    #[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
    pub struct ModelType: u16 {
        /// OpenAI Chat Completions API (/v1/chat/completions)
        const CHAT        = 1 << 0;
        /// OpenAI Completions API - legacy (/v1/completions)
        const COMPLETIONS = 1 << 1;
        /// OpenAI Responses API (/v1/responses)
        const RESPONSES   = 1 << 2;
        /// Embeddings API (/v1/embeddings)
        const EMBEDDINGS  = 1 << 3;
        /// Rerank API (/v1/rerank)
        const RERANK      = 1 << 4;
        /// SGLang Generate API (/generate)
        const GENERATE    = 1 << 5;
        /// Vision/multimodal support (images in input)
        const VISION      = 1 << 6;
        /// Tool/function calling support
        const TOOLS       = 1 << 7;
        /// Reasoning/thinking support (e.g., o1, DeepSeek-R1)
        const REASONING   = 1 << 8;
        /// Image generation (DALL-E, Sora, gpt-image)
        const IMAGE_GEN   = 1 << 9;
        /// Audio models (TTS, Whisper, realtime, transcribe)
        const AUDIO       = 1 << 10;
        /// Content moderation models
        const MODERATION  = 1 << 11;

        /// Standard LLM: chat + completions + responses + tools
        const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
                  | Self::RESPONSES.bits() | Self::TOOLS.bits();

        /// Vision-capable LLM: LLM + vision
        const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();

        /// Reasoning LLM: LLM + reasoning (e.g., o1, o3, DeepSeek-R1)
        const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();

        /// Full-featured LLM: all text generation capabilities
        /// (note: does not include GENERATE, which is SGLang-specific)
        const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();

        /// Embedding model only
        const EMBED_MODEL = Self::EMBEDDINGS.bits();

        /// Reranker model only
        const RERANK_MODEL = Self::RERANK.bits();

        /// Image generation model only (DALL-E, Sora, gpt-image)
        const IMAGE_MODEL = Self::IMAGE_GEN.bits();

        /// Audio model only (TTS, Whisper, realtime)
        const AUDIO_MODEL = Self::AUDIO.bits();

        /// Content moderation model only
        const MODERATION_MODEL = Self::MODERATION.bits();
    }
}
66
/// Mapping of individual capability flags to their names.
///
/// The order here determines the order names appear in `as_capability_names`,
/// `Display`, and the serialized array form. The name strings are also the
/// accepted input for `Deserialize`, so entries must stay in sync with the
/// flags declared in the `bitflags!` block (composite flags are excluded).
const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
    (ModelType::CHAT, "chat"),
    (ModelType::COMPLETIONS, "completions"),
    (ModelType::RESPONSES, "responses"),
    (ModelType::EMBEDDINGS, "embeddings"),
    (ModelType::RERANK, "rerank"),
    (ModelType::GENERATE, "generate"),
    (ModelType::VISION, "vision"),
    (ModelType::TOOLS, "tools"),
    (ModelType::REASONING, "reasoning"),
    (ModelType::IMAGE_GEN, "image_gen"),
    (ModelType::AUDIO, "audio"),
    (ModelType::MODERATION, "moderation"),
];
82
83impl ModelType {
84    /// Check if this model type supports the chat completions endpoint
85    #[inline]
86    pub fn supports_chat(self) -> bool {
87        self.contains(Self::CHAT)
88    }
89
90    /// Check if this model type supports the legacy completions endpoint
91    #[inline]
92    pub fn supports_completions(self) -> bool {
93        self.contains(Self::COMPLETIONS)
94    }
95
96    /// Check if this model type supports the responses endpoint
97    #[inline]
98    pub fn supports_responses(self) -> bool {
99        self.contains(Self::RESPONSES)
100    }
101
102    /// Check if this model type supports the embeddings endpoint
103    #[inline]
104    pub fn supports_embeddings(self) -> bool {
105        self.contains(Self::EMBEDDINGS)
106    }
107
108    /// Check if this model type supports the rerank endpoint
109    #[inline]
110    pub fn supports_rerank(self) -> bool {
111        self.contains(Self::RERANK)
112    }
113
114    /// Check if this model type supports the generate endpoint
115    #[inline]
116    pub fn supports_generate(self) -> bool {
117        self.contains(Self::GENERATE)
118    }
119
120    /// Check if this model type supports vision/multimodal input
121    #[inline]
122    pub fn supports_vision(self) -> bool {
123        self.contains(Self::VISION)
124    }
125
126    /// Check if this model type supports tool/function calling
127    #[inline]
128    pub fn supports_tools(self) -> bool {
129        self.contains(Self::TOOLS)
130    }
131
132    /// Check if this model type supports reasoning/thinking
133    #[inline]
134    pub fn supports_reasoning(self) -> bool {
135        self.contains(Self::REASONING)
136    }
137
138    /// Check if this model type supports image generation
139    #[inline]
140    pub fn supports_image_gen(self) -> bool {
141        self.contains(Self::IMAGE_GEN)
142    }
143
144    /// Check if this model type supports audio (TTS, Whisper, etc.)
145    #[inline]
146    pub fn supports_audio(self) -> bool {
147        self.contains(Self::AUDIO)
148    }
149
150    /// Check if this model type supports content moderation
151    #[inline]
152    pub fn supports_moderation(self) -> bool {
153        self.contains(Self::MODERATION)
154    }
155
156    /// Check if this model type supports a given endpoint
157    pub fn supports_endpoint(self, endpoint: Endpoint) -> bool {
158        match endpoint {
159            Endpoint::Chat => self.supports_chat(),
160            Endpoint::Completions => self.supports_completions(),
161            Endpoint::Responses => self.supports_responses(),
162            Endpoint::Embeddings => self.supports_embeddings(),
163            Endpoint::Rerank => self.supports_rerank(),
164            Endpoint::Generate => self.supports_generate(),
165            Endpoint::Models => true,
166        }
167    }
168
169    /// Convert to a list of supported capability names
170    pub fn as_capability_names(self) -> Vec<&'static str> {
171        let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
172        for &(flag, name) in CAPABILITY_NAMES {
173            if self.contains(flag) {
174                result.push(name);
175            }
176        }
177        result
178    }
179
180    /// Check if this is an LLM (supports at least chat)
181    #[inline]
182    pub fn is_llm(self) -> bool {
183        self.supports_chat()
184    }
185
186    /// Check if this is an embedding model
187    #[inline]
188    pub fn is_embedding_model(self) -> bool {
189        self.supports_embeddings() && !self.supports_chat()
190    }
191
192    /// Check if this is a reranker model
193    #[inline]
194    pub fn is_reranker(self) -> bool {
195        self.supports_rerank() && !self.supports_chat()
196    }
197
198    /// Check if this is an image generation model
199    #[inline]
200    pub fn is_image_model(self) -> bool {
201        self.supports_image_gen() && !self.supports_chat()
202    }
203
204    /// Check if this is an audio model
205    #[inline]
206    pub fn is_audio_model(self) -> bool {
207        self.supports_audio() && !self.supports_chat()
208    }
209
210    /// Check if this is a moderation model
211    #[inline]
212    pub fn is_moderation_model(self) -> bool {
213        self.supports_moderation() && !self.supports_chat()
214    }
215}
216
217impl std::fmt::Display for ModelType {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        let names = self.as_capability_names();
220        if names.is_empty() {
221            write!(f, "none")
222        } else {
223            write!(f, "{}", names.join(","))
224        }
225    }
226}
227
228impl Serialize for ModelType {
229    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
230    where
231        S: serde::Serializer,
232    {
233        use serde::ser::SerializeSeq;
234        let names = self.as_capability_names();
235        let mut seq = serializer.serialize_seq(Some(names.len()))?;
236        for name in names {
237            seq.serialize_element(name)?;
238        }
239        seq.end()
240    }
241}
242
243impl<'de> Deserialize<'de> for ModelType {
244    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
245    where
246        D: serde::Deserializer<'de>,
247    {
248        use serde::de;
249
250        struct ModelTypeVisitor;
251
252        impl<'de> de::Visitor<'de> for ModelTypeVisitor {
253            type Value = ModelType;
254
255            fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256                f.write_str("an array of capability names or a u16 bitfield")
257            }
258
259            // Backward compat: accept numeric u16 bitfield
260            fn visit_u64<E: de::Error>(self, v: u64) -> Result<ModelType, E> {
261                let bits = u16::try_from(v)
262                    .map_err(|_| E::custom(format!("ModelType bits out of u16 range: {v}")))?;
263                ModelType::from_bits(bits)
264                    .ok_or_else(|| E::custom(format!("invalid ModelType bits: {bits}")))
265            }
266
267            // New format: array of capability name strings
268            fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<ModelType, A::Error> {
269                let mut model_type = ModelType::empty();
270                while let Some(name) = seq.next_element::<String>()? {
271                    let flag = CAPABILITY_NAMES
272                        .iter()
273                        .find(|(_, n)| *n == name.as_str())
274                        .map(|(f, _)| *f)
275                        .ok_or_else(|| {
276                            de::Error::custom(format!("unknown ModelType capability: {name}"))
277                        })?;
278                    model_type |= flag;
279                }
280                Ok(model_type)
281            }
282        }
283
284        deserializer.deserialize_any(ModelTypeVisitor)
285    }
286}
287
/// Endpoint types for routing decisions.
///
/// Serialized (via serde) as the lowercase variant name, matching `Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Endpoint {
    /// Chat completions endpoint (/v1/chat/completions)
    Chat,
    /// Legacy completions endpoint (/v1/completions)
    Completions,
    /// Responses endpoint (/v1/responses)
    Responses,
    /// Embeddings endpoint (/v1/embeddings)
    Embeddings,
    /// Rerank endpoint (/v1/rerank)
    Rerank,
    /// SGLang generate endpoint (/generate)
    Generate,
    /// Models listing endpoint (/v1/models)
    Models,
}
307
308impl Endpoint {
309    /// Get the URL path for this endpoint
310    pub fn path(self) -> &'static str {
311        match self {
312            Endpoint::Chat => "/v1/chat/completions",
313            Endpoint::Completions => "/v1/completions",
314            Endpoint::Responses => "/v1/responses",
315            Endpoint::Embeddings => "/v1/embeddings",
316            Endpoint::Rerank => "/v1/rerank",
317            Endpoint::Generate => "/generate",
318            Endpoint::Models => "/v1/models",
319        }
320    }
321
322    /// Parse an endpoint from a URL path
323    pub fn from_path(path: &str) -> Option<Self> {
324        let path = path.trim_end_matches('/');
325        match path {
326            "/v1/chat/completions" => Some(Endpoint::Chat),
327            "/v1/completions" => Some(Endpoint::Completions),
328            "/v1/responses" => Some(Endpoint::Responses),
329            "/v1/embeddings" => Some(Endpoint::Embeddings),
330            "/v1/rerank" => Some(Endpoint::Rerank),
331            "/generate" => Some(Endpoint::Generate),
332            "/v1/models" => Some(Endpoint::Models),
333            _ => None,
334        }
335    }
336
337    /// Get the required ModelType flag for this endpoint
338    pub fn required_capability(self) -> Option<ModelType> {
339        match self {
340            Endpoint::Chat => Some(ModelType::CHAT),
341            Endpoint::Completions => Some(ModelType::COMPLETIONS),
342            Endpoint::Responses => Some(ModelType::RESPONSES),
343            Endpoint::Embeddings => Some(ModelType::EMBEDDINGS),
344            Endpoint::Rerank => Some(ModelType::RERANK),
345            Endpoint::Generate => Some(ModelType::GENERATE),
346            Endpoint::Models => None,
347        }
348    }
349}
350
351impl std::fmt::Display for Endpoint {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        match self {
354            Endpoint::Chat => write!(f, "chat"),
355            Endpoint::Completions => write!(f, "completions"),
356            Endpoint::Responses => write!(f, "responses"),
357            Endpoint::Embeddings => write!(f, "embeddings"),
358            Endpoint::Rerank => write!(f, "rerank"),
359            Endpoint::Generate => write!(f, "generate"),
360            Endpoint::Models => write!(f, "models"),
361        }
362    }
363}