1use bitflags::bitflags;
7use serde::{Deserialize, Serialize};
8
9bitflags! {
10 #[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
11 pub struct ModelType: u16 {
12 const CHAT = 1 << 0;
14 const COMPLETIONS = 1 << 1;
16 const RESPONSES = 1 << 2;
18 const EMBEDDINGS = 1 << 3;
20 const RERANK = 1 << 4;
22 const GENERATE = 1 << 5;
24 const VISION = 1 << 6;
26 const TOOLS = 1 << 7;
28 const REASONING = 1 << 8;
30 const IMAGE_GEN = 1 << 9;
32 const AUDIO = 1 << 10;
34 const MODERATION = 1 << 11;
36
37 const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
39 | Self::RESPONSES.bits() | Self::TOOLS.bits();
40
41 const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();
43
44 const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();
46
47 const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();
49
50 const EMBED_MODEL = Self::EMBEDDINGS.bits();
52
53 const RERANK_MODEL = Self::RERANK.bits();
55
56 const IMAGE_MODEL = Self::IMAGE_GEN.bits();
58
59 const AUDIO_MODEL = Self::AUDIO.bits();
61
62 const MODERATION_MODEL = Self::MODERATION.bits();
64 }
65}
66
67const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
69 (ModelType::CHAT, "chat"),
70 (ModelType::COMPLETIONS, "completions"),
71 (ModelType::RESPONSES, "responses"),
72 (ModelType::EMBEDDINGS, "embeddings"),
73 (ModelType::RERANK, "rerank"),
74 (ModelType::GENERATE, "generate"),
75 (ModelType::VISION, "vision"),
76 (ModelType::TOOLS, "tools"),
77 (ModelType::REASONING, "reasoning"),
78 (ModelType::IMAGE_GEN, "image_gen"),
79 (ModelType::AUDIO, "audio"),
80 (ModelType::MODERATION, "moderation"),
81];
82
83impl ModelType {
84 #[inline]
86 pub fn supports_chat(self) -> bool {
87 self.contains(Self::CHAT)
88 }
89
90 #[inline]
92 pub fn supports_completions(self) -> bool {
93 self.contains(Self::COMPLETIONS)
94 }
95
96 #[inline]
98 pub fn supports_responses(self) -> bool {
99 self.contains(Self::RESPONSES)
100 }
101
102 #[inline]
104 pub fn supports_embeddings(self) -> bool {
105 self.contains(Self::EMBEDDINGS)
106 }
107
108 #[inline]
110 pub fn supports_rerank(self) -> bool {
111 self.contains(Self::RERANK)
112 }
113
114 #[inline]
116 pub fn supports_generate(self) -> bool {
117 self.contains(Self::GENERATE)
118 }
119
120 #[inline]
122 pub fn supports_vision(self) -> bool {
123 self.contains(Self::VISION)
124 }
125
126 #[inline]
128 pub fn supports_tools(self) -> bool {
129 self.contains(Self::TOOLS)
130 }
131
132 #[inline]
134 pub fn supports_reasoning(self) -> bool {
135 self.contains(Self::REASONING)
136 }
137
138 #[inline]
140 pub fn supports_image_gen(self) -> bool {
141 self.contains(Self::IMAGE_GEN)
142 }
143
144 #[inline]
146 pub fn supports_audio(self) -> bool {
147 self.contains(Self::AUDIO)
148 }
149
150 #[inline]
152 pub fn supports_moderation(self) -> bool {
153 self.contains(Self::MODERATION)
154 }
155
156 pub fn supports_endpoint(self, endpoint: Endpoint) -> bool {
158 match endpoint {
159 Endpoint::Chat => self.supports_chat(),
160 Endpoint::Completions => self.supports_completions(),
161 Endpoint::Responses => self.supports_responses(),
162 Endpoint::Embeddings => self.supports_embeddings(),
163 Endpoint::Rerank => self.supports_rerank(),
164 Endpoint::Generate => self.supports_generate(),
165 Endpoint::Models => true,
166 }
167 }
168
169 pub fn as_capability_names(self) -> Vec<&'static str> {
171 let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
172 for &(flag, name) in CAPABILITY_NAMES {
173 if self.contains(flag) {
174 result.push(name);
175 }
176 }
177 result
178 }
179
180 #[inline]
182 pub fn is_llm(self) -> bool {
183 self.supports_chat()
184 }
185
186 #[inline]
188 pub fn is_embedding_model(self) -> bool {
189 self.supports_embeddings() && !self.supports_chat()
190 }
191
192 #[inline]
194 pub fn is_reranker(self) -> bool {
195 self.supports_rerank() && !self.supports_chat()
196 }
197
198 #[inline]
200 pub fn is_image_model(self) -> bool {
201 self.supports_image_gen() && !self.supports_chat()
202 }
203
204 #[inline]
206 pub fn is_audio_model(self) -> bool {
207 self.supports_audio() && !self.supports_chat()
208 }
209
210 #[inline]
212 pub fn is_moderation_model(self) -> bool {
213 self.supports_moderation() && !self.supports_chat()
214 }
215}
216
217impl std::fmt::Display for ModelType {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 let names = self.as_capability_names();
220 if names.is_empty() {
221 write!(f, "none")
222 } else {
223 write!(f, "{}", names.join(","))
224 }
225 }
226}
227
228impl Serialize for ModelType {
229 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
230 where
231 S: serde::Serializer,
232 {
233 use serde::ser::SerializeSeq;
234 let names = self.as_capability_names();
235 let mut seq = serializer.serialize_seq(Some(names.len()))?;
236 for name in names {
237 seq.serialize_element(name)?;
238 }
239 seq.end()
240 }
241}
242
243impl<'de> Deserialize<'de> for ModelType {
244 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
245 where
246 D: serde::Deserializer<'de>,
247 {
248 use serde::de;
249
250 struct ModelTypeVisitor;
251
252 impl<'de> de::Visitor<'de> for ModelTypeVisitor {
253 type Value = ModelType;
254
255 fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 f.write_str("an array of capability names or a u16 bitfield")
257 }
258
259 fn visit_u64<E: de::Error>(self, v: u64) -> Result<ModelType, E> {
261 let bits = u16::try_from(v)
262 .map_err(|_| E::custom(format!("ModelType bits out of u16 range: {v}")))?;
263 ModelType::from_bits(bits)
264 .ok_or_else(|| E::custom(format!("invalid ModelType bits: {bits}")))
265 }
266
267 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<ModelType, A::Error> {
269 let mut model_type = ModelType::empty();
270 while let Some(name) = seq.next_element::<String>()? {
271 let flag = CAPABILITY_NAMES
272 .iter()
273 .find(|(_, n)| *n == name.as_str())
274 .map(|(f, _)| *f)
275 .ok_or_else(|| {
276 de::Error::custom(format!("unknown ModelType capability: {name}"))
277 })?;
278 model_type |= flag;
279 }
280 Ok(model_type)
281 }
282 }
283
284 deserializer.deserialize_any(ModelTypeVisitor)
285 }
286}
287
288#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
290#[serde(rename_all = "lowercase")]
291pub enum Endpoint {
292 Chat,
294 Completions,
296 Responses,
298 Embeddings,
300 Rerank,
302 Generate,
304 Models,
306}
307
308impl Endpoint {
309 pub fn path(self) -> &'static str {
311 match self {
312 Endpoint::Chat => "/v1/chat/completions",
313 Endpoint::Completions => "/v1/completions",
314 Endpoint::Responses => "/v1/responses",
315 Endpoint::Embeddings => "/v1/embeddings",
316 Endpoint::Rerank => "/v1/rerank",
317 Endpoint::Generate => "/generate",
318 Endpoint::Models => "/v1/models",
319 }
320 }
321
322 pub fn from_path(path: &str) -> Option<Self> {
324 let path = path.trim_end_matches('/');
325 match path {
326 "/v1/chat/completions" => Some(Endpoint::Chat),
327 "/v1/completions" => Some(Endpoint::Completions),
328 "/v1/responses" => Some(Endpoint::Responses),
329 "/v1/embeddings" => Some(Endpoint::Embeddings),
330 "/v1/rerank" => Some(Endpoint::Rerank),
331 "/generate" => Some(Endpoint::Generate),
332 "/v1/models" => Some(Endpoint::Models),
333 _ => None,
334 }
335 }
336
337 pub fn required_capability(self) -> Option<ModelType> {
339 match self {
340 Endpoint::Chat => Some(ModelType::CHAT),
341 Endpoint::Completions => Some(ModelType::COMPLETIONS),
342 Endpoint::Responses => Some(ModelType::RESPONSES),
343 Endpoint::Embeddings => Some(ModelType::EMBEDDINGS),
344 Endpoint::Rerank => Some(ModelType::RERANK),
345 Endpoint::Generate => Some(ModelType::GENERATE),
346 Endpoint::Models => None,
347 }
348 }
349}
350
351impl std::fmt::Display for Endpoint {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353 match self {
354 Endpoint::Chat => write!(f, "chat"),
355 Endpoint::Completions => write!(f, "completions"),
356 Endpoint::Responses => write!(f, "responses"),
357 Endpoint::Embeddings => write!(f, "embeddings"),
358 Endpoint::Rerank => write!(f, "rerank"),
359 Endpoint::Generate => write!(f, "generate"),
360 Endpoint::Models => write!(f, "models"),
361 }
362 }
363}