1use crate::data::DatasetReference;
6use crate::recipe::RecipeReference;
7use chrono::{DateTime, Duration, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use super::ModelReference;
12
/// Descriptive metadata for a trained model: a description, training and
/// evaluation provenance, intended and out-of-scope uses, known limitations,
/// ethical considerations, and lineage.
///
/// Serialization contract:
/// - `Option` fields are omitted from the output when `None`.
/// - Collection fields are always written, even when empty.
/// - Collection fields default to empty when missing from the input.
/// - `description` is required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCard {
    /// Human-readable description of the model.
    pub description: String,

    /// Reference to the dataset used for training, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_data: Option<DatasetReference>,
    /// Reference to the training recipe, if recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_recipe: Option<RecipeReference>,
    /// Training timestamp in UTC, if recorded.
    /// NOTE(review): whether this marks the start or end of training is not
    /// specified here — confirm with producers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_date: Option<DateTime<Utc>>,
    /// Training duration as an integer number of seconds, if recorded.
    /// `ModelCard::training_duration` converts it to a `Duration`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_duration_secs: Option<i64>,

    /// Named metric values, keyed by metric name (e.g. `"auc"` -> `0.95`).
    /// NOTE(review): serde_json writes non-finite floats (NaN/inf) as `null`,
    /// which will not deserialize back into `f64` — confirm whether such
    /// values can occur here.
    #[serde(default)]
    pub metrics: HashMap<String, f64>,
    /// Reference to an evaluation dataset, if recorded (presumably the data
    /// `metrics` were computed on — not enforced here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub evaluation_data: Option<DatasetReference>,

    /// Intended use cases.
    #[serde(default)]
    pub primary_uses: Vec<String>,
    /// Use cases the model is not intended for.
    #[serde(default)]
    pub out_of_scope_uses: Vec<String>,

    /// Known limitations of the model.
    #[serde(default)]
    pub limitations: Vec<String>,
    /// Ethical considerations relevant to using the model.
    #[serde(default)]
    pub ethical_considerations: Vec<String>,

    /// Direct parent model (e.g. the base model a fine-tune started from).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_model: Option<ModelReference>,
    /// Other models this one is derived from. The relationship to
    /// `parent_model` (overlap or exclusivity) is not enforced here.
    #[serde(default)]
    pub derived_from: Vec<ModelReference>,

    /// Free-form additional metadata not covered by the fields above.
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}
64
65impl ModelCard {
66 #[must_use]
68 pub fn builder() -> ModelCardBuilder {
69 ModelCardBuilder::new()
70 }
71
72 #[must_use]
74 pub fn new(description: impl Into<String>) -> Self {
75 Self {
76 description: description.into(),
77 training_data: None,
78 training_recipe: None,
79 training_date: None,
80 training_duration_secs: None,
81 metrics: HashMap::new(),
82 evaluation_data: None,
83 primary_uses: Vec::new(),
84 out_of_scope_uses: Vec::new(),
85 limitations: Vec::new(),
86 ethical_considerations: Vec::new(),
87 parent_model: None,
88 derived_from: Vec::new(),
89 extra: HashMap::new(),
90 }
91 }
92
93 #[must_use]
95 pub fn training_duration(&self) -> Option<Duration> {
96 self.training_duration_secs.map(Duration::seconds)
97 }
98
99 pub fn add_metric(&mut self, name: impl Into<String>, value: f64) {
101 self.metrics.insert(name.into(), value);
102 }
103
104 pub fn add_primary_use(&mut self, use_case: impl Into<String>) {
106 self.primary_uses.push(use_case.into());
107 }
108
109 pub fn add_limitation(&mut self, limitation: impl Into<String>) {
111 self.limitations.push(limitation.into());
112 }
113}
114
115impl Default for ModelCard {
116 fn default() -> Self {
117 Self::new("")
118 }
119}
120
121#[derive(Debug, Default)]
123pub struct ModelCardBuilder {
124 card: ModelCard,
125}
126
127impl ModelCardBuilder {
128 #[must_use]
130 pub fn new() -> Self {
131 Self { card: ModelCard::default() }
132 }
133
134 #[must_use]
136 pub fn description(mut self, description: impl Into<String>) -> Self {
137 self.card.description = description.into();
138 self
139 }
140
141 #[must_use]
143 pub fn training_data(mut self, data: DatasetReference) -> Self {
144 self.card.training_data = Some(data);
145 self
146 }
147
148 #[must_use]
150 pub fn training_recipe(mut self, recipe: RecipeReference) -> Self {
151 self.card.training_recipe = Some(recipe);
152 self
153 }
154
155 #[must_use]
157 pub fn training_date(mut self, date: DateTime<Utc>) -> Self {
158 self.card.training_date = Some(date);
159 self
160 }
161
162 #[must_use]
164 pub fn training_duration(mut self, duration: Duration) -> Self {
165 self.card.training_duration_secs = Some(duration.num_seconds());
166 self
167 }
168
169 #[must_use]
171 pub fn metrics<I, K>(mut self, metrics: I) -> Self
172 where
173 I: IntoIterator<Item = (K, f64)>,
174 K: Into<String>,
175 {
176 for (k, v) in metrics {
177 self.card.metrics.insert(k.into(), v);
178 }
179 self
180 }
181
182 #[must_use]
184 pub fn evaluation_data(mut self, data: DatasetReference) -> Self {
185 self.card.evaluation_data = Some(data);
186 self
187 }
188
189 #[must_use]
191 pub fn primary_uses<I, S>(mut self, uses: I) -> Self
192 where
193 I: IntoIterator<Item = S>,
194 S: Into<String>,
195 {
196 self.card.primary_uses = uses.into_iter().map(Into::into).collect();
197 self
198 }
199
200 #[must_use]
202 pub fn out_of_scope_uses<I, S>(mut self, uses: I) -> Self
203 where
204 I: IntoIterator<Item = S>,
205 S: Into<String>,
206 {
207 self.card.out_of_scope_uses = uses.into_iter().map(Into::into).collect();
208 self
209 }
210
211 #[must_use]
213 pub fn limitations<I, S>(mut self, limitations: I) -> Self
214 where
215 I: IntoIterator<Item = S>,
216 S: Into<String>,
217 {
218 self.card.limitations = limitations.into_iter().map(Into::into).collect();
219 self
220 }
221
222 #[must_use]
224 pub fn ethical_considerations<I, S>(mut self, considerations: I) -> Self
225 where
226 I: IntoIterator<Item = S>,
227 S: Into<String>,
228 {
229 self.card.ethical_considerations = considerations.into_iter().map(Into::into).collect();
230 self
231 }
232
233 #[must_use]
235 pub fn parent_model(mut self, parent: ModelReference) -> Self {
236 self.card.parent_model = Some(parent);
237 self
238 }
239
240 #[must_use]
242 pub fn derived_from(mut self, models: Vec<ModelReference>) -> Self {
243 self.card.derived_from = models;
244 self
245 }
246
247 #[must_use]
249 pub fn build(self) -> ModelCard {
250 self.card
251 }
252}
253
#[cfg(test)]
mod tests {
    //! Unit tests for `ModelCard`, its builder, and its serde contract.

    use super::*;
    use crate::data::DatasetVersion;
    use crate::model::ModelVersion;
    use crate::recipe::RecipeVersion;

    #[test]
    fn test_model_card_new() {
        let card = ModelCard::new("A fraud detection model");
        assert_eq!(card.description, "A fraud detection model");
        assert!(card.metrics.is_empty());
    }

    #[test]
    fn test_default_is_empty() {
        let card = ModelCard::default();
        assert!(card.description.is_empty());
        assert!(card.primary_uses.is_empty());
        assert!(card.parent_model.is_none());
        assert!(card.training_duration().is_none());
    }

    #[test]
    fn test_model_card_builder() {
        let card = ModelCard::builder()
            .description("Fraud detector v1")
            .metrics([("auc", 0.95), ("f1", 0.88)])
            .primary_uses(["Fraud detection in payment transactions"])
            .limitations(["May have reduced accuracy on international transactions"])
            .build();

        assert_eq!(card.description, "Fraud detector v1");
        assert_eq!(card.metrics.get("auc"), Some(&0.95));
        assert_eq!(card.metrics.get("f1"), Some(&0.88));
        assert_eq!(card.primary_uses.len(), 1);
        assert_eq!(card.limitations.len(), 1);
    }

    #[test]
    fn test_builder_metrics_merge_across_calls() {
        // `metrics` merges into existing entries; a repeated name is overwritten.
        let card = ModelCard::builder()
            .metrics([("auc", 0.90)])
            .metrics([("auc", 0.95), ("f1", 0.88)])
            .build();

        assert_eq!(card.metrics.len(), 2);
        assert_eq!(card.metrics.get("auc"), Some(&0.95));
        assert_eq!(card.metrics.get("f1"), Some(&0.88));
    }

    #[test]
    fn test_builder_list_setters_replace() {
        // List setters replace rather than append.
        let card = ModelCard::builder()
            .out_of_scope_uses(["stale entry"])
            .out_of_scope_uses(["Credit scoring", "Identity verification"])
            .ethical_considerations(["Possible demographic bias"])
            .build();

        assert_eq!(card.out_of_scope_uses, vec!["Credit scoring", "Identity verification"]);
        assert_eq!(card.ethical_considerations, vec!["Possible demographic bias"]);
    }

    #[test]
    fn test_model_card_with_references() {
        let dataset_ref = DatasetReference::new("transactions", DatasetVersion::new(1, 0, 0));
        let recipe_ref = RecipeReference::new("fraud-training", RecipeVersion::new(1, 0, 0));
        let parent_ref = ModelReference::new("base-classifier", ModelVersion::new(1, 0, 0));

        let card = ModelCard::builder()
            .description("Fine-tuned fraud detector")
            .training_data(dataset_ref.clone())
            .training_recipe(recipe_ref.clone())
            .parent_model(parent_ref.clone())
            .build();

        assert_eq!(card.training_data.unwrap().name, "transactions");
        assert_eq!(card.training_recipe.unwrap().name, "fraud-training");
        assert_eq!(card.parent_model.unwrap().name, "base-classifier");
    }

    #[test]
    fn test_builder_derived_from() {
        let card = ModelCard::builder()
            .derived_from(vec![
                ModelReference::new("teacher-a", ModelVersion::new(1, 0, 0)),
                ModelReference::new("teacher-b", ModelVersion::new(2, 1, 0)),
            ])
            .build();

        assert_eq!(card.derived_from.len(), 2);
        assert_eq!(card.derived_from[1].name, "teacher-b");
    }

    #[test]
    fn test_model_card_add_methods() {
        let mut card = ModelCard::new("Test model");
        card.add_metric("accuracy", 0.92);
        card.add_primary_use("Classification");
        card.add_limitation("Requires normalized inputs");

        assert_eq!(card.metrics.get("accuracy"), Some(&0.92));
        assert_eq!(card.primary_uses, vec!["Classification"]);
        assert_eq!(card.limitations, vec!["Requires normalized inputs"]);
    }

    #[test]
    fn test_model_card_serialization() {
        let card =
            ModelCard::builder().description("Test model").metrics([("accuracy", 0.95)]).build();

        let json = serde_json::to_string(&card).unwrap();
        let deserialized: ModelCard = serde_json::from_str(&json).unwrap();

        assert_eq!(card.description, deserialized.description);
        assert_eq!(card.metrics, deserialized.metrics);
    }

    #[test]
    fn test_serialization_omits_unset_options() {
        let json = serde_json::to_string(&ModelCard::new("Sparse")).unwrap();

        assert!(!json.contains("\"training_data\""));
        assert!(!json.contains("\"training_duration_secs\""));
        assert!(!json.contains("\"parent_model\""));
        // Collection fields are always written, even when empty.
        assert!(json.contains("\"metrics\":{}"));
    }

    #[test]
    fn test_deserialize_minimal_card() {
        // Only `description` is required; everything else falls back to empty/None.
        let card: ModelCard = serde_json::from_str(r#"{"description":"Minimal"}"#).unwrap();

        assert_eq!(card.description, "Minimal");
        assert!(card.training_data.is_none());
        assert!(card.metrics.is_empty());
        assert!(card.derived_from.is_empty());
        assert!(card.extra.is_empty());
    }

    #[test]
    fn test_training_duration() {
        let card =
            ModelCard::builder().description("Model").training_duration(Duration::hours(2)).build();

        let duration = card.training_duration().unwrap();
        assert_eq!(duration.num_hours(), 2);
    }

    #[test]
    fn test_training_duration_truncates_to_whole_seconds() {
        let card = ModelCard::builder().training_duration(Duration::milliseconds(2_500)).build();

        assert_eq!(card.training_duration_secs, Some(2));
        assert_eq!(card.training_duration(), Some(Duration::seconds(2)));
    }
}