use bitflags::bitflags;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
bitflags! {
/// Capability set of a model, stored as a `u16` bitfield.
///
/// Bits 0 through 11 each stand for one capability. The constants declared
/// after `MODERATION` are named unions of those single bits. Bits 12-15 are
/// not assigned to any flag.
#[derive(Copy, Debug, Default, Clone, Eq, PartialEq, Hash)]
pub struct ModelType: u16 {
/// Bit 0. Named `"chat"` in `CAPABILITY_NAMES`; required by `Endpoint::Chat`.
const CHAT = 1 << 0;
/// Bit 1. Named `"completions"`; required by `Endpoint::Completions`.
const COMPLETIONS = 1 << 1;
/// Bit 2. Named `"responses"`; required by `Endpoint::Responses`.
const RESPONSES = 1 << 2;
/// Bit 3. Named `"embeddings"`; required by `Endpoint::Embeddings`.
const EMBEDDINGS = 1 << 3;
/// Bit 4. Named `"rerank"`; required by `Endpoint::Rerank`.
const RERANK = 1 << 4;
/// Bit 5. Named `"generate"`; required by `Endpoint::Generate`.
const GENERATE = 1 << 5;
/// Bit 6. Named `"vision"`. No endpoint in this file requires it.
const VISION = 1 << 6;
/// Bit 7. Named `"tools"`. No endpoint in this file requires it.
const TOOLS = 1 << 7;
/// Bit 8. Named `"reasoning"`. No endpoint in this file requires it.
const REASONING = 1 << 8;
/// Bit 9. Named `"image_gen"`. No endpoint in this file requires it.
const IMAGE_GEN = 1 << 9;
/// Bit 10. Named `"audio"`. No endpoint in this file requires it.
const AUDIO = 1 << 10;
/// Bit 11. Named `"moderation"`. No endpoint in this file requires it.
const MODERATION = 1 << 11;
/// Union of `CHAT`, `COMPLETIONS`, `RESPONSES` and `TOOLS`.
const LLM = Self::CHAT.bits() | Self::COMPLETIONS.bits()
| Self::RESPONSES.bits() | Self::TOOLS.bits();
/// `LLM` plus `VISION`.
const VISION_LLM = Self::LLM.bits() | Self::VISION.bits();
/// `LLM` plus `REASONING`.
const REASONING_LLM = Self::LLM.bits() | Self::REASONING.bits();
/// `VISION_LLM` plus `REASONING`, i.e. `LLM` with both extras.
const FULL_LLM = Self::VISION_LLM.bits() | Self::REASONING.bits();
/// Same bits as `EMBEDDINGS` alone.
const EMBED_MODEL = Self::EMBEDDINGS.bits();
/// Same bits as `RERANK` alone.
const RERANK_MODEL = Self::RERANK.bits();
/// Same bits as `IMAGE_GEN` alone.
const IMAGE_MODEL = Self::IMAGE_GEN.bits();
/// Same bits as `AUDIO` alone.
const AUDIO_MODEL = Self::AUDIO.bits();
/// Same bits as `MODERATION` alone.
const MODERATION_MODEL = Self::MODERATION.bits();
}
}
/// Pairs every single-bit `ModelType` flag with its textual name.
///
/// `ModelType::as_capability_names` walks this table front to back, so the
/// entry order here is the order in which names appear in the `Display`
/// output and in the serialized array. The deserializer looks names up in
/// this table as well.
///
/// Only the twelve single-bit flags are listed; composite constants such as
/// `ModelType::LLM` have no entry. The `enum` list inside
/// `ModelType::json_schema` repeats these strings literally, so both lists
/// have to be updated together.
const CAPABILITY_NAMES: &[(ModelType, &str)] = &[
(ModelType::CHAT, "chat"),
(ModelType::COMPLETIONS, "completions"),
(ModelType::RESPONSES, "responses"),
(ModelType::EMBEDDINGS, "embeddings"),
(ModelType::RERANK, "rerank"),
(ModelType::GENERATE, "generate"),
(ModelType::VISION, "vision"),
(ModelType::TOOLS, "tools"),
(ModelType::REASONING, "reasoning"),
(ModelType::IMAGE_GEN, "image_gen"),
(ModelType::AUDIO, "audio"),
(ModelType::MODERATION, "moderation"),
];
impl ModelType {
#[inline]
pub fn supports_chat(self) -> bool {
self.contains(Self::CHAT)
}
#[inline]
pub fn supports_completions(self) -> bool {
self.contains(Self::COMPLETIONS)
}
#[inline]
pub fn supports_responses(self) -> bool {
self.contains(Self::RESPONSES)
}
#[inline]
pub fn supports_embeddings(self) -> bool {
self.contains(Self::EMBEDDINGS)
}
#[inline]
pub fn supports_rerank(self) -> bool {
self.contains(Self::RERANK)
}
#[inline]
pub fn supports_generate(self) -> bool {
self.contains(Self::GENERATE)
}
#[inline]
pub fn supports_vision(self) -> bool {
self.contains(Self::VISION)
}
#[inline]
pub fn supports_tools(self) -> bool {
self.contains(Self::TOOLS)
}
#[inline]
pub fn supports_reasoning(self) -> bool {
self.contains(Self::REASONING)
}
#[inline]
pub fn supports_image_gen(self) -> bool {
self.contains(Self::IMAGE_GEN)
}
#[inline]
pub fn supports_audio(self) -> bool {
self.contains(Self::AUDIO)
}
#[inline]
pub fn supports_moderation(self) -> bool {
self.contains(Self::MODERATION)
}
pub fn supports_endpoint(self, endpoint: Endpoint) -> bool {
match endpoint {
Endpoint::Chat => self.supports_chat(),
Endpoint::Completions => self.supports_completions(),
Endpoint::Responses => self.supports_responses(),
Endpoint::Embeddings => self.supports_embeddings(),
Endpoint::Rerank => self.supports_rerank(),
Endpoint::Generate => self.supports_generate(),
Endpoint::Models => true,
}
}
pub fn as_capability_names(self) -> Vec<&'static str> {
let mut result = Vec::with_capacity(CAPABILITY_NAMES.len());
for &(flag, name) in CAPABILITY_NAMES {
if self.contains(flag) {
result.push(name);
}
}
result
}
#[inline]
pub fn is_llm(self) -> bool {
self.supports_chat()
}
#[inline]
pub fn is_embedding_model(self) -> bool {
self.supports_embeddings() && !self.supports_chat()
}
#[inline]
pub fn is_reranker(self) -> bool {
self.supports_rerank() && !self.supports_chat()
}
#[inline]
pub fn is_image_model(self) -> bool {
self.supports_image_gen() && !self.supports_chat()
}
#[inline]
pub fn is_audio_model(self) -> bool {
self.supports_audio() && !self.supports_chat()
}
#[inline]
pub fn is_moderation_model(self) -> bool {
self.supports_moderation() && !self.supports_chat()
}
}
/// Renders the set as its capability names joined by commas (no spaces),
/// or as the literal `none` when no named capability is present.
impl std::fmt::Display for ModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let names = self.as_capability_names();
        match names.as_slice() {
            [] => f.write_str("none"),
            present => f.write_str(&present.join(",")),
        }
    }
}
/// Serializes the set as a sequence of capability-name strings, in
/// `CAPABILITY_NAMES` order. An empty set becomes an empty sequence.
impl Serialize for ModelType {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `collect_seq` opens a sequence sized from the iterator, writes every
        // element and closes it, which is the same sequence of serializer
        // calls as doing it by hand.
        serializer.collect_seq(self.as_capability_names())
    }
}
impl<'de> Deserialize<'de> for ModelType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
struct ModelTypeVisitor;
impl<'de> de::Visitor<'de> for ModelTypeVisitor {
type Value = ModelType;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("an array of capability names or a u16 bitfield")
}
fn visit_u64<E: de::Error>(self, v: u64) -> Result<ModelType, E> {
let bits = u16::try_from(v)
.map_err(|_| E::custom(format!("ModelType bits out of u16 range: {v}")))?;
ModelType::from_bits(bits)
.ok_or_else(|| E::custom(format!("invalid ModelType bits: {bits}")))
}
fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<ModelType, A::Error> {
let mut model_type = ModelType::empty();
while let Some(name) = seq.next_element::<String>()? {
let flag = CAPABILITY_NAMES
.iter()
.find(|(_, n)| *n == name.as_str())
.map(|(f, _)| *f)
.ok_or_else(|| {
de::Error::custom(format!("unknown ModelType capability: {name}"))
})?;
model_type |= flag;
}
Ok(model_type)
}
}
deserializer.deserialize_any(ModelTypeVisitor)
}
}
/// Hand-written schema: an array whose items are strings restricted to the
/// known capability names.
///
/// Only the array form is described; the raw integer form that the
/// deserializer also accepts is not part of the schema. The name list below
/// repeats the strings from `CAPABILITY_NAMES` and has to be kept in step
/// with that table.
impl JsonSchema for ModelType {
    fn schema_name() -> String {
        String::from("ModelType")
    }

    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        use schemars::schema::*;

        // Allowed values for each array element.
        const NAMES: [&str; 12] = [
            "chat",
            "completions",
            "responses",
            "embeddings",
            "rerank",
            "generate",
            "vision",
            "tools",
            "reasoning",
            "image_gen",
            "audio",
            "moderation",
        ];

        let element = SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            enum_values: Some(NAMES.iter().map(|&name| name.into()).collect()),
            ..Default::default()
        };

        let array = ArrayValidation {
            items: Some(SingleOrVec::Single(Box::new(element.into()))),
            ..Default::default()
        };

        let metadata = Metadata {
            description: Some(
                "Bitflag capabilities serialized as an array of capability names".to_string(),
            ),
            ..Default::default()
        };

        let schema = SchemaObject {
            instance_type: Some(InstanceType::Array.into()),
            array: Some(Box::new(array)),
            metadata: Some(Box::new(metadata)),
            ..Default::default()
        };
        schema.into()
    }
}
// API endpoints known to this module. Each variant is tied to a URL path by
// `Endpoint::path` and to at most one required capability flag by
// `Endpoint::required_capability`.
//
// With `rename_all = "lowercase"`, serde reads and writes each variant as its
// lowercase name ("chat", "completions", ..., "models").
//
// NOTE(review): plain `//` comments are used here on purpose. The derived
// `JsonSchema` picks up `///` doc comments as schema descriptions, so turning
// these into doc comments would change the generated schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Endpoint {
// Path `/v1/chat/completions`; requires `ModelType::CHAT`.
Chat,
// Path `/v1/completions`; requires `ModelType::COMPLETIONS`.
Completions,
// Path `/v1/responses`; requires `ModelType::RESPONSES`.
Responses,
// Path `/v1/embeddings`; requires `ModelType::EMBEDDINGS`.
Embeddings,
// Path `/v1/rerank`; requires `ModelType::RERANK`.
Rerank,
// Path `/generate` (no `/v1` prefix); requires `ModelType::GENERATE`.
Generate,
// Path `/v1/models`; requires no capability.
Models,
}
impl Endpoint {
    /// All variants in declaration order. `from_path` searches this list so
    /// that the path strings live in `path` only.
    const ALL: [Endpoint; 7] = [
        Endpoint::Chat,
        Endpoint::Completions,
        Endpoint::Responses,
        Endpoint::Embeddings,
        Endpoint::Rerank,
        Endpoint::Generate,
        Endpoint::Models,
    ];

    /// Returns the URL path of this endpoint, without a trailing slash.
    pub fn path(self) -> &'static str {
        match self {
            Self::Chat => "/v1/chat/completions",
            Self::Completions => "/v1/completions",
            Self::Responses => "/v1/responses",
            Self::Embeddings => "/v1/embeddings",
            Self::Rerank => "/v1/rerank",
            Self::Generate => "/generate",
            Self::Models => "/v1/models",
        }
    }

    /// Looks up the endpoint whose path equals `path`.
    ///
    /// Any number of trailing `/` characters is ignored. The comparison is
    /// otherwise exact (case-sensitive, no query-string handling). Returns
    /// `None` when nothing matches.
    pub fn from_path(path: &str) -> Option<Self> {
        let wanted = path.trim_end_matches('/');
        Self::ALL
            .iter()
            .copied()
            .find(|endpoint| endpoint.path() == wanted)
    }

    /// Returns the capability flag a model needs to serve this endpoint, or
    /// `None` for `Endpoint::Models`, which needs none.
    pub fn required_capability(self) -> Option<ModelType> {
        let flag = match self {
            Self::Models => return None,
            Self::Chat => ModelType::CHAT,
            Self::Completions => ModelType::COMPLETIONS,
            Self::Responses => ModelType::RESPONSES,
            Self::Embeddings => ModelType::EMBEDDINGS,
            Self::Rerank => ModelType::RERANK,
            Self::Generate => ModelType::GENERATE,
        };
        Some(flag)
    }
}
/// Writes the endpoint's short lowercase name (`chat`, `completions`, ...,
/// `models`). Width and alignment flags are not applied.
impl std::fmt::Display for Endpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Endpoint::Chat => "chat",
            Endpoint::Completions => "completions",
            Endpoint::Responses => "responses",
            Endpoint::Embeddings => "embeddings",
            Endpoint::Rerank => "rerank",
            Endpoint::Generate => "generate",
            Endpoint::Models => "models",
        };
        f.write_str(label)
    }
}