Skip to main content

jammi_db/
model_task.rs

1//! ML task taxonomy shared across the catalog, store, cache, and inference
2//! call sites. Lives in `jammi-db` because `jammi-db` owns the
3//! catalog tables that persist it (`models.task`, `result_tables.task`) and
4//! the on-disk strings must agree across every crate that reads or writes
5//! them. `jammi-ai` re-exports the type for callers that consume the
6//! higher-level inference surface.
7
8use serde::{Deserialize, Serialize};
9
10use crate::error::JammiError;
11
/// What inference task a model performs.
///
/// The catalog persists this as a snake-case `TEXT` column; in-process call
/// sites should pass the enum directly. The
/// [`as_db_str`](Self::as_db_str) / [`try_from_db_str`](Self::try_from_db_str)
/// pair is the authoritative database mapping — `Display`, `FromStr`, and
/// serde all delegate to it so there is exactly one spelling per variant.
///
/// Serde goes through `String` in both directions (`into = "String"` on
/// serialize, `try_from = "String"` on deserialize), so the serialized form is
/// the same snake-case string the catalog stores, and an unknown spelling is a
/// deserialization error rather than a silent fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(try_from = "String", into = "String")]
pub enum ModelTask {
    /// Produce dense vector representations of input text.
    TextEmbedding,
    /// Produce dense vector representations of input images.
    ImageEmbedding,
    /// Produce dense vector representations of input audio clips.
    AudioEmbedding,
    /// Assign a label and confidence score to input text.
    Classification,
    /// Extract named entities (person, org, location, etc.) from text.
    Ner,
    /// Predict a continuous outcome as a *distribution* — a Gaussian
    /// `(mean, std)` or a set of quantiles — rather than a point. In
    /// `jammi-ai`, the distributional decoder (`DistributionAdapter`, under
    /// `inference/adapter/distribution.rs`) and the proper-scoring objectives
    /// (NLL/CRPS/pinball) train and serve it.
    /// Unlike a similarity edge — a *derivation* computed over embeddings,
    /// which has no variant of its own — this is a genuine model output type,
    /// so it belongs in [`Self::ALL`] and the resolution path.
    Regression,
}
41
42impl ModelTask {
43    /// Every variant in declaration order. The single source of truth for
44    /// "what tasks exist" — `ResultStore`, the catalog SQL builders, and
45    /// any future caller that needs to fan over the full set must read it
46    /// here rather than re-listing variants. Kept consistent with the
47    /// `enum` body by `all_covers_every_variant_via_exhaustive_match` in
48    /// `tests` below.
49    pub const ALL: &'static [ModelTask] = &[
50        ModelTask::TextEmbedding,
51        ModelTask::ImageEmbedding,
52        ModelTask::AudioEmbedding,
53        ModelTask::Classification,
54        ModelTask::Ner,
55        ModelTask::Regression,
56    ];
57
58    /// Canonical snake-case string stored in the catalog. The single source
59    /// of truth — `Display`, `FromStr`, serde all route through this.
60    pub fn as_db_str(&self) -> &'static str {
61        match self {
62            Self::TextEmbedding => "text_embedding",
63            Self::ImageEmbedding => "image_embedding",
64            Self::AudioEmbedding => "audio_embedding",
65            Self::Classification => "classification",
66            Self::Ner => "ner",
67            Self::Regression => "regression",
68        }
69    }
70
71    /// Decode the canonical snake-case string back into a [`ModelTask`].
72    /// Unknown spellings raise [`JammiError::Other`] naming the offending
73    /// value and the accepted set.
74    pub fn try_from_db_str(s: &str) -> Result<Self, JammiError> {
75        match s {
76            "text_embedding" => Ok(Self::TextEmbedding),
77            "image_embedding" => Ok(Self::ImageEmbedding),
78            "audio_embedding" => Ok(Self::AudioEmbedding),
79            "classification" => Ok(Self::Classification),
80            "ner" => Ok(Self::Ner),
81            "regression" => Ok(Self::Regression),
82            other => Err(JammiError::Other(format!(
83                "Unknown model task '{other}'. Expected: text_embedding, image_embedding, audio_embedding, classification, ner, regression"
84            ))),
85        }
86    }
87
88    /// `true` for the two embedding variants that participate in vector
89    /// search and ANN sidecar indexes; `false` for inference-only tasks.
90    pub fn is_embedding(&self) -> bool {
91        matches!(
92            self,
93            Self::TextEmbedding | Self::ImageEmbedding | Self::AudioEmbedding
94        )
95    }
96}
97
98impl std::fmt::Display for ModelTask {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.write_str(self.as_db_str())
101    }
102}
103
104impl std::str::FromStr for ModelTask {
105    type Err = JammiError;
106    fn from_str(s: &str) -> Result<Self, Self::Err> {
107        Self::try_from_db_str(s)
108    }
109}
110
111impl TryFrom<String> for ModelTask {
112    type Error = JammiError;
113    fn try_from(s: String) -> Result<Self, Self::Error> {
114        Self::try_from_db_str(&s)
115    }
116}
117
118impl From<ModelTask> for String {
119    fn from(task: ModelTask) -> Self {
120        task.as_db_str().to_string()
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn db_str_round_trip_covers_every_variant() {
        for variant in ModelTask::ALL {
            let s = variant.as_db_str();
            assert_eq!(
                ModelTask::try_from_db_str(s).unwrap(),
                *variant,
                "round-trip failed for {variant:?} via '{s}'"
            );
        }
    }

    #[test]
    fn all_covers_every_variant_via_exhaustive_match() {
        // Hand-written list of every variant in declaration order. It is kept
        // independent of `ModelTask::ALL`, so the final assertion compares
        // two separate sources instead of checking `ALL` against itself.
        let expected = [
            ModelTask::TextEmbedding,
            ModelTask::ImageEmbedding,
            ModelTask::AudioEmbedding,
            ModelTask::Classification,
            ModelTask::Ner,
            ModelTask::Regression,
        ];
        for (position, task) in expected.iter().enumerate() {
            // Exhaustive match with no `_` arm. Adding a variant to the enum
            // fails to compile here until an arm is added, which directs the
            // author to extend `expected` above (and, through the assertion
            // below, `ALL`). The match cannot enumerate variants itself, so
            // that last step still relies on the author following the
            // compile error here.
            let declared_position = match task {
                ModelTask::TextEmbedding => 0,
                ModelTask::ImageEmbedding => 1,
                ModelTask::AudioEmbedding => 2,
                ModelTask::Classification => 3,
                ModelTask::Ner => 4,
                ModelTask::Regression => 5,
            };
            assert_eq!(
                declared_position, position,
                "`expected` is out of declaration order at {task:?}"
            );
        }
        assert_eq!(
            ModelTask::ALL,
            &expected[..],
            "ModelTask::ALL must list every variant exactly once, in declaration order"
        );
    }

    #[test]
    fn unknown_db_str_returns_typed_error() {
        let err = ModelTask::try_from_db_str("not_a_task").unwrap_err();
        assert!(
            matches!(err, JammiError::Other(ref m) if m.contains("not_a_task")),
            "unknown variant should surface as JammiError::Other naming the input, got {err:?}"
        );
        // The message must also tell the caller what *is* accepted.
        if let JammiError::Other(ref message) = err {
            for task in ModelTask::ALL {
                assert!(
                    message.contains(task.as_db_str()),
                    "error message should list '{}', got: {message}",
                    task.as_db_str()
                );
            }
        }
    }

    #[test]
    fn display_matches_db_str() {
        assert_eq!(format!("{}", ModelTask::TextEmbedding), "text_embedding");
        assert_eq!(format!("{}", ModelTask::ImageEmbedding), "image_embedding");
        assert_eq!(format!("{}", ModelTask::AudioEmbedding), "audio_embedding");
        assert_eq!(format!("{}", ModelTask::Classification), "classification");
        assert_eq!(format!("{}", ModelTask::Ner), "ner");
        assert_eq!(format!("{}", ModelTask::Regression), "regression");
    }

    #[test]
    fn from_str_delegates_to_try_from_db_str() {
        use std::str::FromStr;
        assert_eq!(
            ModelTask::from_str("text_embedding").unwrap(),
            ModelTask::TextEmbedding
        );
        assert!(ModelTask::from_str("bogus").is_err());
    }

    #[test]
    fn is_embedding_is_true_only_for_embedding_variants() {
        assert!(ModelTask::TextEmbedding.is_embedding());
        assert!(ModelTask::ImageEmbedding.is_embedding());
        assert!(ModelTask::AudioEmbedding.is_embedding());
        assert!(!ModelTask::Classification.is_embedding());
        assert!(!ModelTask::Ner.is_embedding());
        assert!(!ModelTask::Regression.is_embedding());
    }

    #[test]
    fn serde_round_trips_via_canonical_string() {
        for variant in ModelTask::ALL {
            let json = serde_json::to_string(variant).unwrap();
            let decoded: ModelTask = serde_json::from_str(&json).unwrap();
            assert_eq!(decoded, *variant);
            // `#[serde(into = "String", try_from = "String")]` routes serde
            // through `String`, so the JSON is the canonical snake-case
            // spelling as a JSON string literal.
            assert_eq!(json, format!("\"{}\"", variant.as_db_str()));
        }
    }
}