1use serde::{Deserialize, Serialize};
9
10use crate::error::JammiError;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
20#[serde(try_from = "String", into = "String")]
21pub enum ModelTask {
22 TextEmbedding,
24 ImageEmbedding,
26 AudioEmbedding,
28 Classification,
30 Ner,
32 Regression,
40}
41
42impl ModelTask {
43 pub const ALL: &'static [ModelTask] = &[
50 ModelTask::TextEmbedding,
51 ModelTask::ImageEmbedding,
52 ModelTask::AudioEmbedding,
53 ModelTask::Classification,
54 ModelTask::Ner,
55 ModelTask::Regression,
56 ];
57
58 pub fn as_db_str(&self) -> &'static str {
61 match self {
62 Self::TextEmbedding => "text_embedding",
63 Self::ImageEmbedding => "image_embedding",
64 Self::AudioEmbedding => "audio_embedding",
65 Self::Classification => "classification",
66 Self::Ner => "ner",
67 Self::Regression => "regression",
68 }
69 }
70
71 pub fn try_from_db_str(s: &str) -> Result<Self, JammiError> {
75 match s {
76 "text_embedding" => Ok(Self::TextEmbedding),
77 "image_embedding" => Ok(Self::ImageEmbedding),
78 "audio_embedding" => Ok(Self::AudioEmbedding),
79 "classification" => Ok(Self::Classification),
80 "ner" => Ok(Self::Ner),
81 "regression" => Ok(Self::Regression),
82 other => Err(JammiError::Other(format!(
83 "Unknown model task '{other}'. Expected: text_embedding, image_embedding, audio_embedding, classification, ner, regression"
84 ))),
85 }
86 }
87
88 pub fn is_embedding(&self) -> bool {
91 matches!(
92 self,
93 Self::TextEmbedding | Self::ImageEmbedding | Self::AudioEmbedding
94 )
95 }
96}
97
98impl std::fmt::Display for ModelTask {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.write_str(self.as_db_str())
101 }
102}
103
104impl std::str::FromStr for ModelTask {
105 type Err = JammiError;
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 Self::try_from_db_str(s)
108 }
109}
110
111impl TryFrom<String> for ModelTask {
112 type Error = JammiError;
113 fn try_from(s: String) -> Result<Self, Self::Error> {
114 Self::try_from_db_str(&s)
115 }
116}
117
118impl From<ModelTask> for String {
119 fn from(task: ModelTask) -> Self {
120 task.as_db_str().to_string()
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn db_str_round_trip_covers_every_variant() {
        for variant in ModelTask::ALL {
            let s = variant.as_db_str();
            assert_eq!(
                ModelTask::try_from_db_str(s).unwrap(),
                *variant,
                "round-trip failed for {variant:?} via '{s}'"
            );
        }
    }

    #[test]
    fn all_covers_every_variant_via_exhaustive_match() {
        // Compile-time tripwire: this match is exhaustive, so adding a
        // `ModelTask` variant breaks the build here until it is handled. At
        // that point the new variant must also be appended to `EXPECTED`.
        fn tripwire(t: ModelTask) {
            match t {
                ModelTask::TextEmbedding
                | ModelTask::ImageEmbedding
                | ModelTask::AudioEmbedding
                | ModelTask::Classification
                | ModelTask::Ner
                | ModelTask::Regression => {}
            }
        }

        // Enumerated independently of `ModelTask::ALL`, in declaration order.
        // Iterating `ALL` itself (as this test used to) can never detect an
        // entry missing from `ALL`.
        const EXPECTED: [ModelTask; 6] = [
            ModelTask::TextEmbedding,
            ModelTask::ImageEmbedding,
            ModelTask::AudioEmbedding,
            ModelTask::Classification,
            ModelTask::Ner,
            ModelTask::Regression,
        ];

        for (index, &task) in EXPECTED.iter().enumerate() {
            tripwire(task);
            // Implicit discriminants of a fieldless enum count up from 0 in
            // declaration order. A variant inserted mid-enum but left out of
            // `EXPECTED` therefore shifts the later ones and fails here.
            assert_eq!(
                task as usize, index,
                "EXPECTED is out of declaration order at {task:?}"
            );
            assert!(
                ModelTask::ALL.contains(&task),
                "ModelTask::ALL is missing {task:?}"
            );
        }
        // `ALL` contains every expected variant. With equal lengths, it also
        // has no duplicates or extra entries.
        assert_eq!(
            ModelTask::ALL.len(),
            EXPECTED.len(),
            "ModelTask::ALL has duplicate or unexpected entries"
        );
    }

    #[test]
    fn unknown_db_str_returns_typed_error() {
        let err = ModelTask::try_from_db_str("not_a_task").unwrap_err();
        // The hint must list every accepted value, so it cannot drift from `ALL`.
        let expected = ModelTask::ALL
            .iter()
            .map(ModelTask::as_db_str)
            .collect::<Vec<_>>()
            .join(", ");
        assert!(
            matches!(
                err,
                JammiError::Other(ref m)
                    if m.contains("not_a_task") && m.contains(expected.as_str())
            ),
            "unknown variant should surface as JammiError::Other naming the input and listing '{expected}', got {err:?}"
        );
    }

    #[test]
    fn display_matches_db_str() {
        assert_eq!(format!("{}", ModelTask::TextEmbedding), "text_embedding");
        assert_eq!(format!("{}", ModelTask::ImageEmbedding), "image_embedding");
        assert_eq!(format!("{}", ModelTask::AudioEmbedding), "audio_embedding");
        assert_eq!(format!("{}", ModelTask::Classification), "classification");
        assert_eq!(format!("{}", ModelTask::Ner), "ner");
        assert_eq!(format!("{}", ModelTask::Regression), "regression");
    }

    #[test]
    fn from_str_delegates_to_try_from_db_str() {
        use std::str::FromStr;
        assert_eq!(
            ModelTask::from_str("text_embedding").unwrap(),
            ModelTask::TextEmbedding
        );
        assert!(ModelTask::from_str("bogus").is_err());
    }

    #[test]
    fn is_embedding_is_true_only_for_embedding_variants() {
        assert!(ModelTask::TextEmbedding.is_embedding());
        assert!(ModelTask::ImageEmbedding.is_embedding());
        assert!(ModelTask::AudioEmbedding.is_embedding());
        assert!(!ModelTask::Classification.is_embedding());
        assert!(!ModelTask::Ner.is_embedding());
        assert!(!ModelTask::Regression.is_embedding());
    }

    #[test]
    fn serde_round_trips_via_canonical_string() {
        for variant in ModelTask::ALL {
            let json = serde_json::to_string(variant).unwrap();
            let decoded: ModelTask = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, *variant);
            assert_eq!(json, format!("\"{}\"", variant.as_db_str()));
        }
    }
}