Skip to main content

openai_protocol/
model_type.rs

1//! Model type definitions using bitflags for endpoint support.
2//!
3//! Defines [`ModelType`] using bitflags to represent which endpoints a model
4//! can support, and [`Endpoint`] for routing decisions.
5
6use bitflags::bitflags;
7use serde::{Deserialize, Serialize};
8
bitflags! {
    /// Set of endpoints and capabilities a model supports, packed into a `u16`.
    ///
    /// Individual capabilities occupy bits 0..=11; the remaining names are
    /// convenience unions of those bits. NOTE: bit positions are part of the
    /// wire format (see the manual `Serialize`/`Deserialize` impls below,
    /// which round-trip the raw bits) — do not renumber existing flags.
    #[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
    pub struct ModelType: u16 {
        /// OpenAI Chat Completions API (/v1/chat/completions)
        const CHAT        = 1 << 0;
        /// OpenAI Completions API - legacy (/v1/completions)
        const COMPLETIONS = 1 << 1;
        /// OpenAI Responses API (/v1/responses)
        const RESPONSES   = 1 << 2;
        /// Embeddings API (/v1/embeddings)
        const EMBEDDINGS  = 1 << 3;
        /// Rerank API (/v1/rerank)
        const RERANK      = 1 << 4;
        /// SGLang Generate API (/generate)
        const GENERATE    = 1 << 5;
        /// Vision/multimodal support (images in input)
        const VISION      = 1 << 6;
        /// Tool/function calling support
        const TOOLS       = 1 << 7;
        /// Reasoning/thinking support (e.g., o1, DeepSeek-R1)
        const REASONING   = 1 << 8;
        /// Image generation (DALL-E, Sora, gpt-image)
        const IMAGE_GEN   = 1 << 9;
        /// Audio models (TTS, Whisper, realtime, transcribe)
        const AUDIO       = 1 << 10;
        /// Content moderation models
        const MODERATION  = 1 << 11;

        /// Standard LLM: chat + completions + responses + tools
        const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
                  | Self::RESPONSES.bits() | Self::TOOLS.bits();

        /// Vision-capable LLM: LLM + vision
        const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();

        /// Reasoning LLM: LLM + reasoning (e.g., o1, o3, DeepSeek-R1)
        const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();

        /// Full-featured LLM: all text generation capabilities
        /// (LLM + VISION + REASONING; deliberately excludes GENERATE,
        /// EMBEDDINGS, RERANK, IMAGE_GEN, AUDIO, MODERATION).
        const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();

        /// Embedding model only
        const EMBED_MODEL = Self::EMBEDDINGS.bits();

        /// Reranker model only
        const RERANK_MODEL = Self::RERANK.bits();

        /// Image generation model only (DALL-E, Sora, gpt-image)
        const IMAGE_MODEL = Self::IMAGE_GEN.bits();

        /// Audio model only (TTS, Whisper, realtime)
        const AUDIO_MODEL = Self::AUDIO.bits();

        /// Content moderation model only
        const MODERATION_MODEL = Self::MODERATION.bits();
    }
}
66
67/// Mapping of individual capability flags to their names.
/// Mapping of individual capability flags to their names.
///
/// The order of entries here determines the output order of
/// [`ModelType::as_capability_names`] and therefore of the `Display`
/// rendering of a `ModelType`. Only single-bit flags belong in this table;
/// composite flags (e.g. `LLM`) are expanded into their constituent names.
const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
    (ModelType::CHAT, "chat"),
    (ModelType::COMPLETIONS, "completions"),
    (ModelType::RESPONSES, "responses"),
    (ModelType::EMBEDDINGS, "embeddings"),
    (ModelType::RERANK, "rerank"),
    (ModelType::GENERATE, "generate"),
    (ModelType::VISION, "vision"),
    (ModelType::TOOLS, "tools"),
    (ModelType::REASONING, "reasoning"),
    (ModelType::IMAGE_GEN, "image_gen"),
    (ModelType::AUDIO, "audio"),
    (ModelType::MODERATION, "moderation"),
];
82
83impl ModelType {
84    /// Check if this model type supports the chat completions endpoint
85    #[inline]
86    pub fn supports_chat(&self) -> bool {
87        self.contains(Self::CHAT)
88    }
89
90    /// Check if this model type supports the legacy completions endpoint
91    #[inline]
92    pub fn supports_completions(&self) -> bool {
93        self.contains(Self::COMPLETIONS)
94    }
95
96    /// Check if this model type supports the responses endpoint
97    #[inline]
98    pub fn supports_responses(&self) -> bool {
99        self.contains(Self::RESPONSES)
100    }
101
102    /// Check if this model type supports the embeddings endpoint
103    #[inline]
104    pub fn supports_embeddings(&self) -> bool {
105        self.contains(Self::EMBEDDINGS)
106    }
107
108    /// Check if this model type supports the rerank endpoint
109    #[inline]
110    pub fn supports_rerank(&self) -> bool {
111        self.contains(Self::RERANK)
112    }
113
114    /// Check if this model type supports the generate endpoint
115    #[inline]
116    pub fn supports_generate(&self) -> bool {
117        self.contains(Self::GENERATE)
118    }
119
120    /// Check if this model type supports vision/multimodal input
121    #[inline]
122    pub fn supports_vision(&self) -> bool {
123        self.contains(Self::VISION)
124    }
125
126    /// Check if this model type supports tool/function calling
127    #[inline]
128    pub fn supports_tools(&self) -> bool {
129        self.contains(Self::TOOLS)
130    }
131
132    /// Check if this model type supports reasoning/thinking
133    #[inline]
134    pub fn supports_reasoning(&self) -> bool {
135        self.contains(Self::REASONING)
136    }
137
138    /// Check if this model type supports image generation
139    #[inline]
140    pub fn supports_image_gen(&self) -> bool {
141        self.contains(Self::IMAGE_GEN)
142    }
143
144    /// Check if this model type supports audio (TTS, Whisper, etc.)
145    #[inline]
146    pub fn supports_audio(&self) -> bool {
147        self.contains(Self::AUDIO)
148    }
149
150    /// Check if this model type supports content moderation
151    #[inline]
152    pub fn supports_moderation(&self) -> bool {
153        self.contains(Self::MODERATION)
154    }
155
156    /// Check if this model type supports a given endpoint
157    pub fn supports_endpoint(&self, endpoint: Endpoint) -> bool {
158        match endpoint {
159            Endpoint::Chat => self.supports_chat(),
160            Endpoint::Completions => self.supports_completions(),
161            Endpoint::Responses => self.supports_responses(),
162            Endpoint::Embeddings => self.supports_embeddings(),
163            Endpoint::Rerank => self.supports_rerank(),
164            Endpoint::Generate => self.supports_generate(),
165            Endpoint::Models => true,
166        }
167    }
168
169    /// Convert to a list of supported capability names
170    pub fn as_capability_names(&self) -> Vec<&'static str> {
171        let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
172        for &(flag, name) in CAPABILITY_NAMES {
173            if self.contains(flag) {
174                result.push(name);
175            }
176        }
177        result
178    }
179
180    /// Check if this is an LLM (supports at least chat)
181    #[inline]
182    pub fn is_llm(&self) -> bool {
183        self.supports_chat()
184    }
185
186    /// Check if this is an embedding model
187    #[inline]
188    pub fn is_embedding_model(&self) -> bool {
189        self.supports_embeddings() && !self.supports_chat()
190    }
191
192    /// Check if this is a reranker model
193    #[inline]
194    pub fn is_reranker(&self) -> bool {
195        self.supports_rerank() && !self.supports_chat()
196    }
197
198    /// Check if this is an image generation model
199    #[inline]
200    pub fn is_image_model(&self) -> bool {
201        self.supports_image_gen() && !self.supports_chat()
202    }
203
204    /// Check if this is an audio model
205    #[inline]
206    pub fn is_audio_model(&self) -> bool {
207        self.supports_audio() && !self.supports_chat()
208    }
209
210    /// Check if this is a moderation model
211    #[inline]
212    pub fn is_moderation_model(&self) -> bool {
213        self.supports_moderation() && !self.supports_chat()
214    }
215}
216
217impl std::fmt::Display for ModelType {
218    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219        let names = self.as_capability_names();
220        if names.is_empty() {
221            write!(f, "none")
222        } else {
223            write!(f, "{}", names.join(","))
224        }
225    }
226}
227
228impl Serialize for ModelType {
229    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
230    where
231        S: serde::Serializer,
232    {
233        serializer.serialize_u16(self.bits())
234    }
235}
236
237impl<'de> Deserialize<'de> for ModelType {
238    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
239    where
240        D: serde::Deserializer<'de>,
241    {
242        let bits = u16::deserialize(deserializer)?;
243        ModelType::from_bits(bits)
244            .ok_or_else(|| serde::de::Error::custom(format!("invalid ModelType bits: {}", bits)))
245    }
246}
247
248/// Endpoint types for routing decisions.
/// Endpoint types for routing decisions.
///
/// Serde serializes variants in lowercase (`"chat"`, `"models"`, ...), which
/// matches the `Display` impl below — keep the two in sync when adding
/// variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Endpoint {
    /// Chat completions endpoint (/v1/chat/completions)
    Chat,
    /// Legacy completions endpoint (/v1/completions)
    Completions,
    /// Responses endpoint (/v1/responses)
    Responses,
    /// Embeddings endpoint (/v1/embeddings)
    Embeddings,
    /// Rerank endpoint (/v1/rerank)
    Rerank,
    /// SGLang generate endpoint (/generate)
    Generate,
    /// Models listing endpoint (/v1/models)
    Models,
}
267
268impl Endpoint {
269    /// Get the URL path for this endpoint
270    pub fn path(&self) -> &'static str {
271        match self {
272            Endpoint::Chat => "/v1/chat/completions",
273            Endpoint::Completions => "/v1/completions",
274            Endpoint::Responses => "/v1/responses",
275            Endpoint::Embeddings => "/v1/embeddings",
276            Endpoint::Rerank => "/v1/rerank",
277            Endpoint::Generate => "/generate",
278            Endpoint::Models => "/v1/models",
279        }
280    }
281
282    /// Parse an endpoint from a URL path
283    pub fn from_path(path: &str) -> Option<Self> {
284        let path = path.trim_end_matches('/');
285        match path {
286            "/v1/chat/completions" => Some(Endpoint::Chat),
287            "/v1/completions" => Some(Endpoint::Completions),
288            "/v1/responses" => Some(Endpoint::Responses),
289            "/v1/embeddings" => Some(Endpoint::Embeddings),
290            "/v1/rerank" => Some(Endpoint::Rerank),
291            "/generate" => Some(Endpoint::Generate),
292            "/v1/models" => Some(Endpoint::Models),
293            _ => None,
294        }
295    }
296
297    /// Get the required ModelType flag for this endpoint
298    pub fn required_capability(&self) -> Option<ModelType> {
299        match self {
300            Endpoint::Chat => Some(ModelType::CHAT),
301            Endpoint::Completions => Some(ModelType::COMPLETIONS),
302            Endpoint::Responses => Some(ModelType::RESPONSES),
303            Endpoint::Embeddings => Some(ModelType::EMBEDDINGS),
304            Endpoint::Rerank => Some(ModelType::RERANK),
305            Endpoint::Generate => Some(ModelType::GENERATE),
306            Endpoint::Models => None,
307        }
308    }
309}
310
311impl std::fmt::Display for Endpoint {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        match self {
314            Endpoint::Chat => write!(f, "chat"),
315            Endpoint::Completions => write!(f, "completions"),
316            Endpoint::Responses => write!(f, "responses"),
317            Endpoint::Embeddings => write!(f, "embeddings"),
318            Endpoint::Rerank => write!(f, "rerank"),
319            Endpoint::Generate => write!(f, "generate"),
320            Endpoint::Models => write!(f, "models"),
321        }
322    }
323}