1use bitflags::bitflags;
5use serde::{Deserialize, Serialize};
6use std::fmt;
7use strum::Display;
8
9bitflags! {
10 #[derive(Copy, Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
34 pub struct ModelType: u8 {
35 const Chat = 1 << 0;
36 const Completions = 1 << 1;
37 const Embedding = 1 << 2;
38 const TensorBased = 1 << 3;
39 }
40}
41
42impl ModelType {
43 pub fn as_str(&self) -> String {
44 self.as_vec().join(",")
45 }
46
47 pub fn supports_chat(&self) -> bool {
48 self.contains(ModelType::Chat)
49 }
50 pub fn supports_completions(&self) -> bool {
51 self.contains(ModelType::Completions)
52 }
53 pub fn supports_embedding(&self) -> bool {
54 self.contains(ModelType::Embedding)
55 }
56 pub fn supports_tensor(&self) -> bool {
57 self.contains(ModelType::TensorBased)
58 }
59
60 pub fn as_vec(&self) -> Vec<&'static str> {
61 let mut result = Vec::new();
62 if self.supports_chat() {
63 result.push("chat");
64 }
65 if self.supports_completions() {
66 result.push("completions");
67 }
68 if self.supports_embedding() {
69 result.push("embedding");
70 }
71 if self.supports_tensor() {
72 result.push("tensor");
73 }
74 result
75 }
76
77 pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
80 let mut endpoint_types = Vec::new();
81 if self.contains(Self::Chat) {
82 endpoint_types.push(crate::endpoint_type::EndpointType::Chat);
83 }
84 if self.contains(Self::Completions) {
85 endpoint_types.push(crate::endpoint_type::EndpointType::Completion);
86 }
87 if self.contains(Self::Embedding) {
88 endpoint_types.push(crate::endpoint_type::EndpointType::Embedding);
89 }
90 endpoint_types
94 }
95}
96
97impl fmt::Display for ModelType {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 write!(f, "{}", self.as_str())
100 }
101}
102
103#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
104pub enum ModelInput {
105 Text,
107 Tokens,
109 Tensor,
111}
112
113impl ModelInput {
114 pub fn as_str(&self) -> &str {
115 match self {
116 Self::Text => "text",
117 Self::Tokens => "tokens",
118 Self::Tensor => "tensor",
119 }
120 }
121}