Skip to main content

pacha/model/
card.rs

1//! Model Card for standardized model documentation.
2//!
3//! Based on "Model Cards for Model Reporting" (Mitchell et al., 2019).
4
5use crate::data::DatasetReference;
6use crate::recipe::RecipeReference;
7use chrono::{DateTime, Duration, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use super::ModelReference;
12
13/// Model Card with standardized documentation.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ModelCard {
16    /// Model description.
17    pub description: String,
18
19    /// Reference to training dataset.
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub training_data: Option<DatasetReference>,
22    /// Reference to training recipe.
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub training_recipe: Option<RecipeReference>,
25    /// When training was performed.
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub training_date: Option<DateTime<Utc>>,
28    /// Training duration in seconds.
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub training_duration_secs: Option<i64>,
31
32    /// Performance metrics (e.g., accuracy, F1 score).
33    #[serde(default)]
34    pub metrics: HashMap<String, f64>,
35    /// Reference to evaluation dataset.
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub evaluation_data: Option<DatasetReference>,
38
39    /// Primary intended use cases.
40    #[serde(default)]
41    pub primary_uses: Vec<String>,
42    /// Out-of-scope use cases.
43    #[serde(default)]
44    pub out_of_scope_uses: Vec<String>,
45
46    /// Known limitations.
47    #[serde(default)]
48    pub limitations: Vec<String>,
49    /// Ethical considerations.
50    #[serde(default)]
51    pub ethical_considerations: Vec<String>,
52
53    /// Parent model (if fine-tuned).
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub parent_model: Option<ModelReference>,
56    /// Models this was derived from.
57    #[serde(default)]
58    pub derived_from: Vec<ModelReference>,
59
60    /// Additional metadata.
61    #[serde(default)]
62    pub extra: HashMap<String, serde_json::Value>,
63}
64
65impl ModelCard {
66    /// Create a new model card builder.
67    #[must_use]
68    pub fn builder() -> ModelCardBuilder {
69        ModelCardBuilder::new()
70    }
71
72    /// Create a minimal model card with just a description.
73    #[must_use]
74    pub fn new(description: impl Into<String>) -> Self {
75        Self {
76            description: description.into(),
77            training_data: None,
78            training_recipe: None,
79            training_date: None,
80            training_duration_secs: None,
81            metrics: HashMap::new(),
82            evaluation_data: None,
83            primary_uses: Vec::new(),
84            out_of_scope_uses: Vec::new(),
85            limitations: Vec::new(),
86            ethical_considerations: Vec::new(),
87            parent_model: None,
88            derived_from: Vec::new(),
89            extra: HashMap::new(),
90        }
91    }
92
93    /// Get training duration as a Duration.
94    #[must_use]
95    pub fn training_duration(&self) -> Option<Duration> {
96        self.training_duration_secs.map(Duration::seconds)
97    }
98
99    /// Add a metric.
100    pub fn add_metric(&mut self, name: impl Into<String>, value: f64) {
101        self.metrics.insert(name.into(), value);
102    }
103
104    /// Add a primary use case.
105    pub fn add_primary_use(&mut self, use_case: impl Into<String>) {
106        self.primary_uses.push(use_case.into());
107    }
108
109    /// Add a limitation.
110    pub fn add_limitation(&mut self, limitation: impl Into<String>) {
111        self.limitations.push(limitation.into());
112    }
113}
114
115impl Default for ModelCard {
116    fn default() -> Self {
117        Self::new("")
118    }
119}
120
121/// Builder for creating model cards.
122#[derive(Debug, Default)]
123pub struct ModelCardBuilder {
124    card: ModelCard,
125}
126
127impl ModelCardBuilder {
128    /// Create a new builder.
129    #[must_use]
130    pub fn new() -> Self {
131        Self { card: ModelCard::default() }
132    }
133
134    /// Set the description.
135    #[must_use]
136    pub fn description(mut self, description: impl Into<String>) -> Self {
137        self.card.description = description.into();
138        self
139    }
140
141    /// Set the training data reference.
142    #[must_use]
143    pub fn training_data(mut self, data: DatasetReference) -> Self {
144        self.card.training_data = Some(data);
145        self
146    }
147
148    /// Set the training recipe reference.
149    #[must_use]
150    pub fn training_recipe(mut self, recipe: RecipeReference) -> Self {
151        self.card.training_recipe = Some(recipe);
152        self
153    }
154
155    /// Set the training date.
156    #[must_use]
157    pub fn training_date(mut self, date: DateTime<Utc>) -> Self {
158        self.card.training_date = Some(date);
159        self
160    }
161
162    /// Set the training duration.
163    #[must_use]
164    pub fn training_duration(mut self, duration: Duration) -> Self {
165        self.card.training_duration_secs = Some(duration.num_seconds());
166        self
167    }
168
169    /// Add metrics from an iterator.
170    #[must_use]
171    pub fn metrics<I, K>(mut self, metrics: I) -> Self
172    where
173        I: IntoIterator<Item = (K, f64)>,
174        K: Into<String>,
175    {
176        for (k, v) in metrics {
177            self.card.metrics.insert(k.into(), v);
178        }
179        self
180    }
181
182    /// Set evaluation data reference.
183    #[must_use]
184    pub fn evaluation_data(mut self, data: DatasetReference) -> Self {
185        self.card.evaluation_data = Some(data);
186        self
187    }
188
189    /// Add primary uses.
190    #[must_use]
191    pub fn primary_uses<I, S>(mut self, uses: I) -> Self
192    where
193        I: IntoIterator<Item = S>,
194        S: Into<String>,
195    {
196        self.card.primary_uses = uses.into_iter().map(Into::into).collect();
197        self
198    }
199
200    /// Add out-of-scope uses.
201    #[must_use]
202    pub fn out_of_scope_uses<I, S>(mut self, uses: I) -> Self
203    where
204        I: IntoIterator<Item = S>,
205        S: Into<String>,
206    {
207        self.card.out_of_scope_uses = uses.into_iter().map(Into::into).collect();
208        self
209    }
210
211    /// Add limitations.
212    #[must_use]
213    pub fn limitations<I, S>(mut self, limitations: I) -> Self
214    where
215        I: IntoIterator<Item = S>,
216        S: Into<String>,
217    {
218        self.card.limitations = limitations.into_iter().map(Into::into).collect();
219        self
220    }
221
222    /// Add ethical considerations.
223    #[must_use]
224    pub fn ethical_considerations<I, S>(mut self, considerations: I) -> Self
225    where
226        I: IntoIterator<Item = S>,
227        S: Into<String>,
228    {
229        self.card.ethical_considerations = considerations.into_iter().map(Into::into).collect();
230        self
231    }
232
233    /// Set parent model reference.
234    #[must_use]
235    pub fn parent_model(mut self, parent: ModelReference) -> Self {
236        self.card.parent_model = Some(parent);
237        self
238    }
239
240    /// Set derived-from models.
241    #[must_use]
242    pub fn derived_from(mut self, models: Vec<ModelReference>) -> Self {
243        self.card.derived_from = models;
244        self
245    }
246
247    /// Build the model card.
248    #[must_use]
249    pub fn build(self) -> ModelCard {
250        self.card
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::DatasetVersion;
    use crate::model::ModelVersion;
    use crate::recipe::RecipeVersion;

    /// `ModelCard::new` keeps the description and starts with no metrics.
    #[test]
    fn test_model_card_new() {
        let fresh = ModelCard::new("A fraud detection model");

        assert_eq!(fresh.description, "A fraud detection model");
        assert!(fresh.metrics.is_empty());
    }

    /// The builder carries description, metrics, uses and limitations through.
    #[test]
    fn test_model_card_builder() {
        let builder = ModelCard::builder()
            .description("Fraud detector v1")
            .metrics([("auc", 0.95), ("f1", 0.88)])
            .primary_uses(["Fraud detection in payment transactions"])
            .limitations(["May have reduced accuracy on international transactions"]);
        let built = builder.build();

        assert_eq!(built.description, "Fraud detector v1");
        assert_eq!(built.metrics.get("auc").copied(), Some(0.95));
        assert_eq!(built.metrics.get("f1").copied(), Some(0.88));
        assert_eq!(built.primary_uses.len(), 1);
        assert_eq!(built.limitations.len(), 1);
    }

    /// Dataset, recipe and parent-model references end up on the built card.
    #[test]
    fn test_model_card_with_references() {
        let built = ModelCard::builder()
            .description("Fine-tuned fraud detector")
            .training_data(DatasetReference::new("transactions", DatasetVersion::new(1, 0, 0)))
            .training_recipe(RecipeReference::new("fraud-training", RecipeVersion::new(1, 0, 0)))
            .parent_model(ModelReference::new("base-classifier", ModelVersion::new(1, 0, 0)))
            .build();

        let training_data = built.training_data.expect("training data was set");
        let training_recipe = built.training_recipe.expect("training recipe was set");
        let parent_model = built.parent_model.expect("parent model was set");

        assert_eq!(training_data.name, "transactions");
        assert_eq!(training_recipe.name, "fraud-training");
        assert_eq!(parent_model.name, "base-classifier");
    }

    /// The in-place `add_*` helpers mutate the matching collections.
    #[test]
    fn test_model_card_add_methods() {
        let mut subject = ModelCard::new("Test model");

        subject.add_metric("accuracy", 0.92);
        subject.add_primary_use("Classification");
        subject.add_limitation("Requires normalized inputs");

        assert_eq!(subject.metrics.get("accuracy").copied(), Some(0.92));
        assert_eq!(subject.primary_uses, ["Classification"]);
        assert_eq!(subject.limitations, ["Requires normalized inputs"]);
    }

    /// A JSON round trip preserves the description and the metrics.
    #[test]
    fn test_model_card_serialization() {
        let original =
            ModelCard::builder().description("Test model").metrics([("accuracy", 0.95)]).build();

        let encoded = serde_json::to_string(&original).expect("card serializes to JSON");
        let decoded: ModelCard = serde_json::from_str(&encoded).expect("card parses back");

        assert_eq!(original.description, decoded.description);
        assert_eq!(original.metrics, decoded.metrics);
    }

    /// A duration set through the builder is readable back from the card.
    #[test]
    fn test_training_duration() {
        let two_hours = Duration::hours(2);
        let built = ModelCard::builder().description("Model").training_duration(two_hours).build();

        let recorded = built.training_duration().expect("training duration was set");
        assert_eq!(recorded.num_hours(), 2);
    }
}