use std::collections::HashMap;
/// Identifies a field in a model's metadata.
///
/// Each well-known variant corresponds to a fixed snake_case key (for example
/// `CodeRepository` has the key `"code_repository"`). [`MetadataField::Custom`]
/// carries an arbitrary key which is used verbatim.
///
/// Note that because equality and hashing are derived, `Custom("commit")` is a
/// different key from `Commit`, even though both report the name `"commit"`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MetadataField {
    /// Key `code_repository` (e.g. a URL of the code repository).
    CodeRepository,
    /// Key `commit`.
    Commit,
    /// Key `description`.
    Description,
    /// Key `license`.
    License,
    /// Key `model_repository` (e.g. a URL of the model repository).
    ModelRepository,
    /// Key `onnx_hash` (presumably a hash of the source ONNX model — TODO confirm).
    OnnxHash,
    /// Key `run_id`.
    RunId,
    /// Key `run_url`.
    RunUrl,
    /// Key `producer_name`.
    ProducerName,
    /// Key `producer_version`.
    ProducerVersion,
    /// A field without a dedicated variant; the wrapped string is its key.
    Custom(String),
}
impl MetadataField {
fn name(&self) -> &str {
match self {
Self::CodeRepository => "code_repository",
Self::Commit => "commit",
Self::Description => "description",
Self::License => "license",
Self::ModelRepository => "model_repository",
Self::OnnxHash => "onnx_hash",
Self::RunId => "run_id",
Self::RunUrl => "run_url",
Self::ProducerName => "producer_name",
Self::ProducerVersion => "producer_version",
Self::Custom(value) => value,
}
}
}
impl std::fmt::Display for MetadataField {
    /// Formats the field as its string key.
    ///
    /// Uses [`std::fmt::Formatter::pad`] rather than `write_str` so that width,
    /// alignment, fill and precision flags (e.g. `{:<20}`) are honored, matching
    /// how `str` itself is displayed. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.name())
    }
}
/// Metadata associated with a model, stored as a map from field to value.
///
/// Well-known fields have dedicated accessors such as
/// [`ModelMetadata::license`]. Any field, custom or well-known, can be looked
/// up by its string name with [`ModelMetadata::get`].
#[derive(Debug, Default)]
pub struct ModelMetadata {
    /// Field values keyed by field identifier.
    fields: HashMap<MetadataField, String>,
}

impl ModelMetadata {
    /// Creates metadata from `(field, value)` pairs.
    ///
    /// If the same field appears more than once, the last value wins.
    pub(crate) fn from_fields(fields: impl IntoIterator<Item = (MetadataField, String)>) -> Self {
        Self {
            fields: fields.into_iter().collect(),
        }
    }

    /// Returns the value of the `onnx_hash` field, if present.
    pub fn onnx_hash(&self) -> Option<&str> {
        self.field(&MetadataField::OnnxHash)
    }

    /// Returns the value of the `description` field, if present.
    pub fn description(&self) -> Option<&str> {
        self.field(&MetadataField::Description)
    }

    /// Returns the value of the `license` field, if present.
    pub fn license(&self) -> Option<&str> {
        self.field(&MetadataField::License)
    }

    /// Returns the value of the `commit` field, if present.
    pub fn commit(&self) -> Option<&str> {
        self.field(&MetadataField::Commit)
    }

    /// Returns the value of the `code_repository` field, if present.
    pub fn code_repository(&self) -> Option<&str> {
        self.field(&MetadataField::CodeRepository)
    }

    /// Returns the value of the `model_repository` field, if present.
    pub fn model_repository(&self) -> Option<&str> {
        self.field(&MetadataField::ModelRepository)
    }

    /// Returns the value of the `run_id` field, if present.
    pub fn run_id(&self) -> Option<&str> {
        self.field(&MetadataField::RunId)
    }

    /// Returns the value of the `run_url` field, if present.
    pub fn run_url(&self) -> Option<&str> {
        self.field(&MetadataField::RunUrl)
    }

    /// Returns the value of the `producer_name` field, if present.
    pub fn producer_name(&self) -> Option<&str> {
        self.field(&MetadataField::ProducerName)
    }

    /// Returns the value of the `producer_version` field, if present.
    pub fn producer_version(&self) -> Option<&str> {
        self.field(&MetadataField::ProducerVersion)
    }

    /// Returns the value of the field called `name`, if present.
    ///
    /// `name` is matched against the same names reported by
    /// [`ModelMetadata::fields`]. This means both custom fields and well-known
    /// fields (e.g. `"description"`) can be retrieved. If a custom field has the
    /// same name as a well-known field, the custom field's value is returned.
    pub fn get(&self, name: &str) -> Option<&str> {
        // Custom fields are keyed by their name, so a direct hash lookup finds them.
        let custom = MetadataField::Custom(name.to_string());
        if let Some(value) = self.field(&custom) {
            return Some(value);
        }

        // Well-known fields are keyed by enum variant rather than by string, so
        // resolve them by comparing names. The custom key with this name was
        // ruled out above and well-known names are all distinct, so at most one
        // entry can match and the result does not depend on iteration order.
        self.fields
            .iter()
            .find(|(field, _)| field.name() == name)
            .map(|(_, value)| value.as_str())
    }

    /// Looks up the value stored for `field`.
    fn field(&self, field: &MetadataField) -> Option<&str> {
        self.fields.get(field).map(|x| x.as_str())
    }

    /// Returns an iterator over `(name, value)` pairs for every field.
    ///
    /// The iteration order is unspecified; it follows the underlying `HashMap`.
    pub fn fields(&self) -> impl Iterator<Item = (&str, &str)> {
        self.fields
            .iter()
            .map(|(field, val)| (field.name(), val.as_str()))
    }
}
#[cfg(test)]
mod tests {
    use super::{MetadataField, ModelMetadata};

    #[test]
    fn test_model_metadata() {
        let model_metadata = ModelMetadata::from_fields([
            (MetadataField::OnnxHash, "abc".to_string()),
            (MetadataField::Description, "A simple model".to_string()),
            (MetadataField::License, "BSD-2-Clause".to_string()),
            (MetadataField::Commit, "def".to_string()),
            (
                MetadataField::CodeRepository,
                "https://github.com/robertknight/rten".to_string(),
            ),
            (
                MetadataField::ModelRepository,
                "https://huggingface.co/robertknight/rten".to_string(),
            ),
            (MetadataField::RunId, "1234".to_string()),
            (
                MetadataField::RunUrl,
                "https://wandb.ai/robertknight/text-detection/runs/1234".to_string(),
            ),
            (MetadataField::ProducerName, "rten-convert".to_string()),
            (MetadataField::ProducerVersion, "0.1.0".to_string()),
        ]);
        assert_eq!(model_metadata.onnx_hash(), Some("abc"));
        assert_eq!(model_metadata.description(), Some("A simple model"));
        assert_eq!(model_metadata.license(), Some("BSD-2-Clause"));
        assert_eq!(model_metadata.commit(), Some("def"));
        assert_eq!(
            model_metadata.code_repository(),
            Some("https://github.com/robertknight/rten")
        );
        assert_eq!(
            model_metadata.model_repository(),
            Some("https://huggingface.co/robertknight/rten")
        );
        assert_eq!(model_metadata.run_id(), Some("1234"));
        assert_eq!(
            model_metadata.run_url(),
            Some("https://wandb.ai/robertknight/text-detection/runs/1234")
        );
        assert_eq!(model_metadata.producer_name(), Some("rten-convert"));
        assert_eq!(model_metadata.producer_version(), Some("0.1.0"));
    }

    #[test]
    fn test_empty_model_metadata() {
        let model_metadata = ModelMetadata::default();
        assert_eq!(model_metadata.onnx_hash(), None);
        assert_eq!(model_metadata.description(), None);
        assert_eq!(model_metadata.license(), None);
        assert_eq!(model_metadata.commit(), None);
        assert_eq!(model_metadata.code_repository(), None);
        assert_eq!(model_metadata.model_repository(), None);
        assert_eq!(model_metadata.run_id(), None);
        assert_eq!(model_metadata.run_url(), None);
        assert_eq!(model_metadata.producer_name(), None);
        assert_eq!(model_metadata.producer_version(), None);
        assert_eq!(model_metadata.get("anything"), None);
        assert_eq!(model_metadata.fields().count(), 0);
    }

    #[test]
    fn test_custom_fields() {
        let model_metadata = ModelMetadata::from_fields([
            (MetadataField::Custom("tokenizer".to_string()), "bpe".to_string()),
            (MetadataField::License, "MIT".to_string()),
        ]);
        assert_eq!(model_metadata.get("tokenizer"), Some("bpe"));
        assert_eq!(model_metadata.get("missing"), None);
        assert_eq!(model_metadata.license(), Some("MIT"));
    }

    #[test]
    fn test_fields_iter() {
        let model_metadata = ModelMetadata::from_fields([
            (MetadataField::Commit, "def".to_string()),
            (MetadataField::License, "MIT".to_string()),
            (MetadataField::Custom("tokenizer".to_string()), "bpe".to_string()),
        ]);
        // `fields()` iterates in unspecified (hash map) order, so sort first.
        let mut fields: Vec<_> = model_metadata.fields().collect();
        fields.sort();
        assert_eq!(
            fields,
            [("commit", "def"), ("license", "MIT"), ("tokenizer", "bpe")]
        );
    }

    #[test]
    fn test_field_display() {
        assert_eq!(MetadataField::OnnxHash.to_string(), "onnx_hash");
        assert_eq!(MetadataField::RunUrl.to_string(), "run_url");
        assert_eq!(MetadataField::ProducerVersion.to_string(), "producer_version");
        assert_eq!(MetadataField::Custom("foo".to_string()).to_string(), "foo");
    }
}