1use bitflags::bitflags;
7use serde::{Deserialize, Serialize};
8
bitflags! {
    /// Set of capabilities a model advertises, stored as a `u16` bit mask.
    ///
    /// The single-bit constants describe individual capabilities; the
    /// multi-bit constants further down are presets built from them.
    #[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
    pub struct ModelType: u16 {
        /// Required by [`Endpoint::Chat`] (`/v1/chat/completions`).
        const CHAT = 1 << 0;
        /// Required by [`Endpoint::Completions`] (`/v1/completions`).
        const COMPLETIONS = 1 << 1;
        /// Required by [`Endpoint::Responses`] (`/v1/responses`).
        const RESPONSES = 1 << 2;
        /// Required by [`Endpoint::Embeddings`] (`/v1/embeddings`).
        const EMBEDDINGS = 1 << 3;
        /// Required by [`Endpoint::Rerank`] (`/v1/rerank`).
        const RERANK = 1 << 4;
        /// Required by [`Endpoint::Generate`] (`/generate`).
        const GENERATE = 1 << 5;
        /// Capability reported as `"vision"`; not required by any [`Endpoint`].
        const VISION = 1 << 6;
        /// Capability reported as `"tools"`; not required by any [`Endpoint`].
        const TOOLS = 1 << 7;
        /// Capability reported as `"reasoning"`; not required by any [`Endpoint`].
        const REASONING = 1 << 8;
        /// Capability reported as `"image_gen"`; not required by any [`Endpoint`].
        const IMAGE_GEN = 1 << 9;
        /// Capability reported as `"audio"`; not required by any [`Endpoint`].
        const AUDIO = 1 << 10;
        /// Capability reported as `"moderation"`; not required by any [`Endpoint`].
        const MODERATION = 1 << 11;

        /// Preset for a text LLM: `CHAT | COMPLETIONS | RESPONSES | TOOLS`.
        /// Note that it does not include `GENERATE`.
        const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
            | Self::RESPONSES.bits() | Self::TOOLS.bits();

        /// `LLM` plus `VISION`.
        const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();

        /// `LLM` plus `REASONING`.
        const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();

        /// `LLM` plus both `VISION` and `REASONING`.
        const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();

        /// Preset for an embedding model: just `EMBEDDINGS`.
        const EMBED_MODEL = Self::EMBEDDINGS.bits();

        /// Preset for a reranker: just `RERANK`.
        const RERANK_MODEL = Self::RERANK.bits();

        /// Preset for an image-generation model: just `IMAGE_GEN`.
        const IMAGE_MODEL = Self::IMAGE_GEN.bits();

        /// Preset for an audio model: just `AUDIO`.
        const AUDIO_MODEL = Self::AUDIO.bits();

        /// Preset for a moderation model: just `MODERATION`.
        const MODERATION_MODEL = Self::MODERATION.bits();
    }
}
66
/// Display name for each single-bit `ModelType` flag.
///
/// The order of this table is the order in which names are produced by
/// `ModelType::as_capability_names` (and therefore by `ModelType`'s `Display`).
/// Multi-bit presets such as `LLM` are not listed, so they are rendered as
/// their constituent single-bit flags.
const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
    (ModelType::CHAT, "chat"),
    (ModelType::COMPLETIONS, "completions"),
    (ModelType::RESPONSES, "responses"),
    (ModelType::EMBEDDINGS, "embeddings"),
    (ModelType::RERANK, "rerank"),
    (ModelType::GENERATE, "generate"),
    (ModelType::VISION, "vision"),
    (ModelType::TOOLS, "tools"),
    (ModelType::REASONING, "reasoning"),
    (ModelType::IMAGE_GEN, "image_gen"),
    (ModelType::AUDIO, "audio"),
    (ModelType::MODERATION, "moderation"),
];
82
83impl ModelType {
84 #[inline]
86 pub fn supports_chat(&self) -> bool {
87 self.contains(Self::CHAT)
88 }
89
90 #[inline]
92 pub fn supports_completions(&self) -> bool {
93 self.contains(Self::COMPLETIONS)
94 }
95
96 #[inline]
98 pub fn supports_responses(&self) -> bool {
99 self.contains(Self::RESPONSES)
100 }
101
102 #[inline]
104 pub fn supports_embeddings(&self) -> bool {
105 self.contains(Self::EMBEDDINGS)
106 }
107
108 #[inline]
110 pub fn supports_rerank(&self) -> bool {
111 self.contains(Self::RERANK)
112 }
113
114 #[inline]
116 pub fn supports_generate(&self) -> bool {
117 self.contains(Self::GENERATE)
118 }
119
120 #[inline]
122 pub fn supports_vision(&self) -> bool {
123 self.contains(Self::VISION)
124 }
125
126 #[inline]
128 pub fn supports_tools(&self) -> bool {
129 self.contains(Self::TOOLS)
130 }
131
132 #[inline]
134 pub fn supports_reasoning(&self) -> bool {
135 self.contains(Self::REASONING)
136 }
137
138 #[inline]
140 pub fn supports_image_gen(&self) -> bool {
141 self.contains(Self::IMAGE_GEN)
142 }
143
144 #[inline]
146 pub fn supports_audio(&self) -> bool {
147 self.contains(Self::AUDIO)
148 }
149
150 #[inline]
152 pub fn supports_moderation(&self) -> bool {
153 self.contains(Self::MODERATION)
154 }
155
156 pub fn supports_endpoint(&self, endpoint: Endpoint) -> bool {
158 match endpoint {
159 Endpoint::Chat => self.supports_chat(),
160 Endpoint::Completions => self.supports_completions(),
161 Endpoint::Responses => self.supports_responses(),
162 Endpoint::Embeddings => self.supports_embeddings(),
163 Endpoint::Rerank => self.supports_rerank(),
164 Endpoint::Generate => self.supports_generate(),
165 Endpoint::Models => true,
166 }
167 }
168
169 pub fn as_capability_names(&self) -> Vec<&'static str> {
171 let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
172 for &(flag, name) in CAPABILITY_NAMES {
173 if self.contains(flag) {
174 result.push(name);
175 }
176 }
177 result
178 }
179
180 #[inline]
182 pub fn is_llm(&self) -> bool {
183 self.supports_chat()
184 }
185
186 #[inline]
188 pub fn is_embedding_model(&self) -> bool {
189 self.supports_embeddings() && !self.supports_chat()
190 }
191
192 #[inline]
194 pub fn is_reranker(&self) -> bool {
195 self.supports_rerank() && !self.supports_chat()
196 }
197
198 #[inline]
200 pub fn is_image_model(&self) -> bool {
201 self.supports_image_gen() && !self.supports_chat()
202 }
203
204 #[inline]
206 pub fn is_audio_model(&self) -> bool {
207 self.supports_audio() && !self.supports_chat()
208 }
209
210 #[inline]
212 pub fn is_moderation_model(&self) -> bool {
213 self.supports_moderation() && !self.supports_chat()
214 }
215}
216
217impl std::fmt::Display for ModelType {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 let names = self.as_capability_names();
220 if names.is_empty() {
221 write!(f, "none")
222 } else {
223 write!(f, "{}", names.join(","))
224 }
225 }
226}
227
228impl Serialize for ModelType {
229 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
230 where
231 S: serde::Serializer,
232 {
233 serializer.serialize_u16(self.bits())
234 }
235}
236
237impl<'de> Deserialize<'de> for ModelType {
238 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
239 where
240 D: serde::Deserializer<'de>,
241 {
242 let bits = u16::deserialize(deserializer)?;
243 ModelType::from_bits(bits)
244 .ok_or_else(|| serde::de::Error::custom(format!("invalid ModelType bits: {}", bits)))
245 }
246}
247
/// An HTTP API endpoint, identified by its request path (see [`Endpoint::path`]).
///
/// Serialized with serde as the lowercase variant name, e.g. `"chat"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Endpoint {
    /// `/v1/chat/completions`; requires [`ModelType::CHAT`].
    Chat,
    /// `/v1/completions`; requires [`ModelType::COMPLETIONS`].
    Completions,
    /// `/v1/responses`; requires [`ModelType::RESPONSES`].
    Responses,
    /// `/v1/embeddings`; requires [`ModelType::EMBEDDINGS`].
    Embeddings,
    /// `/v1/rerank`; requires [`ModelType::RERANK`].
    Rerank,
    /// `/generate` (no `/v1` prefix); requires [`ModelType::GENERATE`].
    Generate,
    /// `/v1/models`; requires no capability, so every model supports it.
    Models,
}
267
268impl Endpoint {
269 pub fn path(&self) -> &'static str {
271 match self {
272 Endpoint::Chat => "/v1/chat/completions",
273 Endpoint::Completions => "/v1/completions",
274 Endpoint::Responses => "/v1/responses",
275 Endpoint::Embeddings => "/v1/embeddings",
276 Endpoint::Rerank => "/v1/rerank",
277 Endpoint::Generate => "/generate",
278 Endpoint::Models => "/v1/models",
279 }
280 }
281
282 pub fn from_path(path: &str) -> Option<Self> {
284 let path = path.trim_end_matches('/');
285 match path {
286 "/v1/chat/completions" => Some(Endpoint::Chat),
287 "/v1/completions" => Some(Endpoint::Completions),
288 "/v1/responses" => Some(Endpoint::Responses),
289 "/v1/embeddings" => Some(Endpoint::Embeddings),
290 "/v1/rerank" => Some(Endpoint::Rerank),
291 "/generate" => Some(Endpoint::Generate),
292 "/v1/models" => Some(Endpoint::Models),
293 _ => None,
294 }
295 }
296
297 pub fn required_capability(&self) -> Option<ModelType> {
299 match self {
300 Endpoint::Chat => Some(ModelType::CHAT),
301 Endpoint::Completions => Some(ModelType::COMPLETIONS),
302 Endpoint::Responses => Some(ModelType::RESPONSES),
303 Endpoint::Embeddings => Some(ModelType::EMBEDDINGS),
304 Endpoint::Rerank => Some(ModelType::RERANK),
305 Endpoint::Generate => Some(ModelType::GENERATE),
306 Endpoint::Models => None,
307 }
308 }
309}
310
311impl std::fmt::Display for Endpoint {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 match self {
314 Endpoint::Chat => write!(f, "chat"),
315 Endpoint::Completions => write!(f, "completions"),
316 Endpoint::Responses => write!(f, "responses"),
317 Endpoint::Embeddings => write!(f, "embeddings"),
318 Endpoint::Rerank => write!(f, "rerank"),
319 Endpoint::Generate => write!(f, "generate"),
320 Endpoint::Models => write!(f, "models"),
321 }
322 }
323}