use crate::data::DatasetReference;
use crate::recipe::RecipeReference;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::ModelReference;
/// Descriptive metadata recorded alongside a model.
///
/// Serialization notes, as configured by the field attributes below:
/// - `Option` fields are left out of the serialized output when they are `None`.
/// - Fields marked `#[serde(default)]` fall back to an empty collection when
///   they are missing from the input being deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCard {
/// Description text for the model. This is the only field without a serde
/// default, so it must be present when deserializing.
pub description: String,
/// Dataset reference recorded as the training data, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub training_data: Option<DatasetReference>,
/// Recipe reference recorded for training, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub training_recipe: Option<RecipeReference>,
/// UTC timestamp recorded for training, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub training_date: Option<DateTime<Utc>>,
/// Training duration in whole seconds, if any. See
/// [`ModelCard::training_duration`] for the `Duration` view of this value.
#[serde(skip_serializing_if = "Option::is_none")]
pub training_duration_secs: Option<i64>,
/// Named numeric metrics, keyed by metric name.
#[serde(default)]
pub metrics: HashMap<String, f64>,
/// Dataset reference recorded as the evaluation data, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub evaluation_data: Option<DatasetReference>,
/// Free-text entries listing primary uses.
#[serde(default)]
pub primary_uses: Vec<String>,
/// Free-text entries listing out-of-scope uses.
#[serde(default)]
pub out_of_scope_uses: Vec<String>,
/// Free-text entries listing limitations.
#[serde(default)]
pub limitations: Vec<String>,
/// Free-text entries listing ethical considerations.
#[serde(default)]
pub ethical_considerations: Vec<String>,
/// Reference to a parent model, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub parent_model: Option<ModelReference>,
/// References to models this one is derived from.
// NOTE(review): how this differs from `parent_model` is not visible here — confirm.
#[serde(default)]
pub derived_from: Vec<ModelReference>,
/// Arbitrary JSON values keyed by string.
// NOTE(review): presumably a slot for metadata not modelled by the fields above — confirm.
#[serde(default)]
pub extra: HashMap<String, serde_json::Value>,
}
impl ModelCard {
    /// Returns a [`ModelCardBuilder`] for assembling a card with chained setters.
    #[must_use]
    pub fn builder() -> ModelCardBuilder {
        ModelCardBuilder::new()
    }

    /// Creates a card that carries only `description`.
    ///
    /// Every optional field starts as `None` and every collection starts empty.
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            training_data: None,
            training_recipe: None,
            training_date: None,
            training_duration_secs: None,
            metrics: HashMap::new(),
            evaluation_data: None,
            primary_uses: Vec::new(),
            out_of_scope_uses: Vec::new(),
            limitations: Vec::new(),
            ethical_considerations: Vec::new(),
            parent_model: None,
            derived_from: Vec::new(),
            extra: HashMap::new(),
        }
    }

    /// Returns the training duration as a [`Duration`].
    ///
    /// Yields `None` when no duration is recorded, or when the recorded number
    /// of seconds is too large in magnitude for a `Duration` to represent.
    #[must_use]
    pub fn training_duration(&self) -> Option<Duration> {
        // `Duration::seconds` panics outside +/-(i64::MAX / 1_000) seconds.
        // `training_duration_secs` is a public, deserializable field and can
        // therefore hold any `i64`, so the range is checked before converting.
        const MAX_SECS: i64 = i64::MAX / 1_000;

        self.training_duration_secs
            .filter(|secs| (-MAX_SECS..=MAX_SECS).contains(secs))
            .map(Duration::seconds)
    }

    /// Records `value` under `name`, replacing any value already stored for
    /// that name.
    pub fn add_metric(&mut self, name: impl Into<String>, value: f64) {
        self.metrics.insert(name.into(), value);
    }

    /// Appends `use_case` to the list of primary uses.
    pub fn add_primary_use(&mut self, use_case: impl Into<String>) {
        self.primary_uses.push(use_case.into());
    }

    /// Appends `limitation` to the list of limitations.
    pub fn add_limitation(&mut self, limitation: impl Into<String>) {
        self.limitations.push(limitation.into());
    }
}
/// The default card has an empty description and is otherwise identical to
/// what [`ModelCard::new`] produces.
impl Default for ModelCard {
    fn default() -> Self {
        Self::new(String::new())
    }
}
/// Builder that assembles a [`ModelCard`] through chained setter calls.
///
/// Each setter takes the builder by value and returns it; `build` hands back
/// the finished card.
#[derive(Debug, Default)]
pub struct ModelCardBuilder {
// Card under construction; setters write into it and `build` returns it.
card: ModelCard,
}
impl ModelCardBuilder {
#[must_use]
pub fn new() -> Self {
Self { card: ModelCard::default() }
}
#[must_use]
pub fn description(mut self, description: impl Into<String>) -> Self {
self.card.description = description.into();
self
}
#[must_use]
pub fn training_data(mut self, data: DatasetReference) -> Self {
self.card.training_data = Some(data);
self
}
#[must_use]
pub fn training_recipe(mut self, recipe: RecipeReference) -> Self {
self.card.training_recipe = Some(recipe);
self
}
#[must_use]
pub fn training_date(mut self, date: DateTime<Utc>) -> Self {
self.card.training_date = Some(date);
self
}
#[must_use]
pub fn training_duration(mut self, duration: Duration) -> Self {
self.card.training_duration_secs = Some(duration.num_seconds());
self
}
#[must_use]
pub fn metrics<I, K>(mut self, metrics: I) -> Self
where
I: IntoIterator<Item = (K, f64)>,
K: Into<String>,
{
for (k, v) in metrics {
self.card.metrics.insert(k.into(), v);
}
self
}
#[must_use]
pub fn evaluation_data(mut self, data: DatasetReference) -> Self {
self.card.evaluation_data = Some(data);
self
}
#[must_use]
pub fn primary_uses<I, S>(mut self, uses: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.card.primary_uses = uses.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn out_of_scope_uses<I, S>(mut self, uses: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.card.out_of_scope_uses = uses.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn limitations<I, S>(mut self, limitations: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.card.limitations = limitations.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn ethical_considerations<I, S>(mut self, considerations: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.card.ethical_considerations = considerations.into_iter().map(Into::into).collect();
self
}
#[must_use]
pub fn parent_model(mut self, parent: ModelReference) -> Self {
self.card.parent_model = Some(parent);
self
}
#[must_use]
pub fn derived_from(mut self, models: Vec<ModelReference>) -> Self {
self.card.derived_from = models;
self
}
#[must_use]
pub fn build(self) -> ModelCard {
self.card
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::DatasetVersion;
    use crate::model::ModelVersion;
    use crate::recipe::RecipeVersion;

    /// `ModelCard::new` keeps the description and starts without metrics.
    #[test]
    fn test_model_card_new() {
        let fresh = ModelCard::new("A fraud detection model");

        assert_eq!(fresh.description, "A fraud detection model");
        assert!(fresh.metrics.is_empty());
    }

    /// Values passed to the builder's setters show up on the built card.
    #[test]
    fn test_model_card_builder() {
        let builder = ModelCard::builder()
            .description("Fraud detector v1")
            .metrics([("auc", 0.95), ("f1", 0.88)])
            .primary_uses(["Fraud detection in payment transactions"])
            .limitations(["May have reduced accuracy on international transactions"]);
        let built = builder.build();

        assert_eq!(built.description, "Fraud detector v1");
        assert_eq!(built.metrics.get("auc"), Some(&0.95));
        assert_eq!(built.metrics.get("f1"), Some(&0.88));
        assert_eq!(built.primary_uses.len(), 1);
        assert_eq!(built.limitations.len(), 1);
    }

    /// Dataset, recipe and parent-model references are stored on the card.
    #[test]
    fn test_model_card_with_references() {
        let built = ModelCard::builder()
            .description("Fine-tuned fraud detector")
            .training_data(DatasetReference::new("transactions", DatasetVersion::new(1, 0, 0)))
            .training_recipe(RecipeReference::new("fraud-training", RecipeVersion::new(1, 0, 0)))
            .parent_model(ModelReference::new("base-classifier", ModelVersion::new(1, 0, 0)))
            .build();

        let ModelCard { training_data, training_recipe, parent_model, .. } = built;
        assert_eq!(training_data.unwrap().name, "transactions");
        assert_eq!(training_recipe.unwrap().name, "fraud-training");
        assert_eq!(parent_model.unwrap().name, "base-classifier");
    }

    /// The in-place `add_*` helpers mutate an existing card.
    #[test]
    fn test_model_card_add_methods() {
        let mut editable = ModelCard::new("Test model");

        editable.add_metric("accuracy", 0.92);
        editable.add_primary_use("Classification");
        editable.add_limitation("Requires normalized inputs");

        assert_eq!(editable.metrics.get("accuracy"), Some(&0.92));
        assert_eq!(editable.primary_uses, ["Classification"]);
        assert_eq!(editable.limitations, ["Requires normalized inputs"]);
    }

    /// A card survives a JSON round trip with description and metrics intact.
    #[test]
    fn test_model_card_serialization() {
        let original =
            ModelCard::builder().description("Test model").metrics([("accuracy", 0.95)]).build();

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ModelCard = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.description, decoded.description);
        assert_eq!(original.metrics, decoded.metrics);
    }

    /// A duration set through the builder is readable via `training_duration`.
    #[test]
    fn test_training_duration() {
        let two_hours = Duration::hours(2);
        let built = ModelCard::builder().description("Model").training_duration(two_hours).build();

        let recorded = built.training_duration().unwrap();
        assert_eq!(recorded.num_hours(), 2);
    }
}