use bitflags::bitflags;
use serde::{Deserialize, Serialize};
use std::fmt;
use strum::Display;
bitflags! {
// NOTE(review): flag names are CamelCase rather than the conventional
// SCREAMING_SNAKE_CASE for consts; they are public API, so renaming would
// break callers — confirm before normalizing.
/// Capability set of a model, stored as a bitmask so a single model can
/// advertise several capabilities at once (e.g. Chat | Completions).
#[derive(Copy, Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelType: u8 {
/// OpenAI-style chat (and related) endpoints.
const Chat = 1 << 0;
/// Plain text completions.
const Completions = 1 << 1;
/// Embedding generation.
const Embedding = 1 << 2;
/// Raw tensor in/out models.
const TensorBased = 1 << 3;
/// Prefill-only workers (disaggregated serving).
const Prefill = 1 << 4;
/// Image input/generation.
const Images = 1 << 5;
/// Audio input/generation.
const Audios = 1 << 6;
/// Video input/generation.
const Videos = 1 << 7;
}
}
impl ModelType {
    /// Every unit flag paired with its canonical lowercase name, in bit order.
    ///
    /// Single source of truth for [`ModelType::as_vec`] and
    /// [`ModelType::units`], which previously duplicated this mapping as two
    /// parallel if-chains that had to be kept in sync by hand.
    const UNIT_NAMES: [(ModelType, &'static str); 8] = [
        (ModelType::Chat, "chat"),
        (ModelType::Completions, "completions"),
        (ModelType::Embedding, "embedding"),
        (ModelType::TensorBased, "tensor"),
        (ModelType::Prefill, "prefill"),
        (ModelType::Images, "images"),
        (ModelType::Audios, "audios"),
        (ModelType::Videos, "videos"),
    ];

    /// Comma-joined list of the set capability names, e.g. `"chat,embedding"`.
    /// Returns an empty string when no flags are set.
    pub fn as_str(&self) -> String {
        self.as_vec().join(",")
    }
    /// True if the `Chat` flag is set.
    pub fn supports_chat(&self) -> bool {
        self.contains(ModelType::Chat)
    }
    /// True if the `Completions` flag is set.
    pub fn supports_completions(&self) -> bool {
        self.contains(ModelType::Completions)
    }
    /// True if the `Embedding` flag is set.
    pub fn supports_embedding(&self) -> bool {
        self.contains(ModelType::Embedding)
    }
    /// True if the `TensorBased` flag is set.
    pub fn supports_tensor(&self) -> bool {
        self.contains(ModelType::TensorBased)
    }
    /// True if the `Prefill` flag is set.
    pub fn supports_prefill(&self) -> bool {
        self.contains(ModelType::Prefill)
    }
    /// True if the `Images` flag is set.
    pub fn supports_images(&self) -> bool {
        self.contains(ModelType::Images)
    }
    /// True if the `Audios` flag is set.
    pub fn supports_audios(&self) -> bool {
        self.contains(ModelType::Audios)
    }
    /// True if the `Videos` flag is set.
    pub fn supports_videos(&self) -> bool {
        self.contains(ModelType::Videos)
    }
    /// Canonical names of the set flags, in bit order
    /// (`"chat"`, `"completions"`, …). Empty when no flags are set.
    pub fn as_vec(&self) -> Vec<&'static str> {
        Self::UNIT_NAMES
            .iter()
            .filter(|(flag, _)| self.contains(*flag))
            .map(|&(_, name)| name)
            .collect()
    }
    /// Decomposes this set into its individual single-flag values, in bit
    /// order. Empty when no flags are set.
    pub fn units(&self) -> Vec<ModelType> {
        Self::UNIT_NAMES
            .iter()
            .filter(|(flag, _)| self.contains(*flag))
            .map(|&(flag, _)| flag)
            .collect()
    }
    /// Maps the set capability flags to the HTTP endpoint types they expose.
    ///
    /// `Chat` fans out to both the chat and responses endpoints, plus the
    /// Anthropic messages endpoint when the corresponding env toggle is set.
    /// `TensorBased` and `Prefill` intentionally map to no endpoint type here
    /// — presumably they are not served over these HTTP routes; confirm if
    /// extending.
    pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
        let mut endpoint_types = Vec::new();
        if self.contains(Self::Chat) {
            endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
            endpoint_types.push(crate::endpoint_type::EndpointType::Responses);
            if dynamo_runtime::config::env_is_truthy(
                dynamo_runtime::config::environment_names::llm::DYN_ENABLE_ANTHROPIC_API,
            ) {
                endpoint_types.push(crate::endpoint_type::EndpointType::AnthropicMessages);
            }
        }
        if self.contains(Self::Completions) {
            endpoint_types.push(crate::endpoint_type::EndpointType::Completion);
        }
        if self.contains(Self::Embedding) {
            endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
        }
        if self.contains(Self::Images) {
            endpoint_types.push(crate::endpoint_type::EndpointType::Images);
        }
        if self.contains(Self::Audios) {
            endpoint_types.push(crate::endpoint_type::EndpointType::Audios);
        }
        if self.contains(Self::Videos) {
            endpoint_types.push(crate::endpoint_type::EndpointType::Videos);
        }
        endpoint_types
    }
}
impl fmt::Display for ModelType {
    /// Renders the flag set as its comma-separated capability names,
    /// identical to the text produced by `as_str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.as_str())
    }
}
/// The input representation a model consumes.
///
/// NOTE(review): the strum-derived `Display` prints the variant names
/// verbatim ("Text", "Tokens", "Tensor"), while `as_str` returns lowercase
/// ("text", "tokens", "tensor") — confirm this divergence is intentional
/// before relying on either for serialization.
#[derive(Copy, Debug, Default, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
pub enum ModelInput {
    // Plain text prompts (the default).
    #[default]
    Text,
    // Pre-tokenized input (token ids).
    Tokens,
    // Raw tensor input.
    Tensor,
}
impl ModelInput {
    /// Canonical lowercase name of this input kind.
    ///
    /// Returns `&'static str` — every arm is a string literal, so tying the
    /// result's lifetime to `&self` (as the previous elided `&str` did) was
    /// needlessly restrictive. `&'static str` coerces to any `&'a str`, so
    /// this widening is backward-compatible for all callers.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::Tokens => "tokens",
            Self::Tensor => "tensor",
        }
    }
}