use crate::error::Result;
use std::path::Path;
/// Selects the encoding used by the serialization traits in this module.
///
/// A value of this enum is passed to [`ModelSerialize`] and
/// [`ModelDeserialize`] methods to pick the format to write or read.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelFormat {
    /// The JSON format.
    Json,
    /// The SafeTensors format.
    SafeTensors,
    /// The CBOR format.
    Cbor,
    /// The MessagePack format.
    MessagePack,
}
/// Interface for types that can be written out in a [`ModelFormat`].
///
/// Implementors supply the two output methods and an architecture name;
/// `model_version` has a default implementation.
pub trait ModelSerialize {
    /// Saves `self` to `path` using the requested `format`.
    ///
    /// Failure conditions are defined by the implementor and reported
    /// through the crate's `Result` type.
    fn save(&self, path: &Path, format: ModelFormat) -> Result<()>;

    /// Produces the encoded form of `self` in `format` as an owned byte vector.
    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>>;

    /// Returns the name of the model architecture.
    fn architecture_name(&self) -> &str;

    /// Returns the model version string.
    ///
    /// The default implementation returns `"0.1.0"`; implementors may
    /// override it.
    fn model_version(&self) -> String {
        String::from("0.1.0")
    }
}
/// Interface for types that can be reconstructed from a [`ModelFormat`]
/// encoding.
///
/// Both constructors return an owned `Self`, hence the `Sized` bound.
pub trait ModelDeserialize: Sized {
    /// Builds a value from the file at `path`, interpreted as `format`.
    ///
    /// Failure conditions are defined by the implementor and reported
    /// through the crate's `Result` type.
    fn load(path: &Path, format: ModelFormat) -> Result<Self>;

    /// Builds a value from an in-memory byte slice, interpreted as `format`.
    fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self>;
}
/// Descriptive information about a model, (de)serializable through serde.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelMetadata {
    /// Name of the model architecture.
    pub architecture: String,
    /// Version string; [`ModelMetadata::new`] initialises it to `"0.1.0"`.
    pub version: String,
    /// Version of the crate this code was compiled in, taken from
    /// `CARGO_PKG_VERSION` at build time by [`ModelMetadata::new`].
    pub framework_version: String,
    /// Number of parameters, as supplied by the caller.
    pub num_parameters: usize,
    /// Data type label, as supplied by the caller.
    pub dtype: String,
    /// Free-form string key/value pairs.
    pub extra: std::collections::HashMap<String, String>,
}

impl ModelMetadata {
    /// Creates metadata for the given architecture, dtype and parameter count.
    ///
    /// `version` starts as `"0.1.0"`, `framework_version` is the compile-time
    /// package version, and `extra` starts empty.
    pub fn new(architecture: &str, dtype: &str, num_parameters: usize) -> Self {
        let extra = std::collections::HashMap::new();
        let framework_version = String::from(env!("CARGO_PKG_VERSION"));
        Self {
            architecture: architecture.to_owned(),
            version: String::from("0.1.0"),
            framework_version,
            num_parameters,
            dtype: dtype.to_owned(),
            extra,
        }
    }

    /// Builder-style helper that stores `value` under `key` in `extra` and
    /// hands the metadata back.
    ///
    /// An existing entry for the same key is replaced.
    pub fn with_extra(mut self, key: &str, value: &str) -> Self {
        let (owned_key, owned_value) = (key.to_owned(), value.to_owned());
        self.extra.insert(owned_key, owned_value);
        self
    }
}
/// Description of a single tensor, (de)serializable through serde.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TensorInfo {
    /// Tensor name.
    pub name: String,
    /// Data type label, as supplied by the caller.
    pub dtype: String,
    /// Size of each dimension.
    pub shape: Vec<usize>,
    /// Offset of the tensor's data (presumably in bytes within a shared
    /// buffer; confirm against the code that consumes this struct).
    pub data_offset: usize,
    /// Length of the tensor's data in bytes.
    pub byte_length: usize,
}

impl TensorInfo {
    /// Creates a `TensorInfo` from its parts, copying `name` and `dtype`
    /// into owned strings. No consistency checks are performed.
    pub fn new(
        name: &str,
        dtype: &str,
        shape: Vec<usize>,
        data_offset: usize,
        byte_length: usize,
    ) -> Self {
        let name = name.to_owned();
        let dtype = dtype.to_owned();
        Self {
            name,
            dtype,
            shape,
            data_offset,
            byte_length,
        }
    }

    /// Returns the product of all dimensions in `shape`.
    ///
    /// An empty shape yields `0`.
    // NOTE(review): the empty product would be 1 (scalar convention); the
    // explicit 0 here is existing behaviour that the module's tests rely on.
    pub fn num_elements(&self) -> usize {
        match self.shape.as_slice() {
            [] => 0,
            dims => dims.iter().product(),
        }
    }
}
/// An ordered list of named parameter tensors.
///
/// Each entry is a `(name, values, shape)` tuple. Entries are kept in
/// insertion order and names are not required to be unique.
#[derive(Debug, Clone, Default)]
pub struct NamedParameters {
    /// The stored `(name, values, shape)` entries, in insertion order.
    pub parameters: Vec<(String, Vec<f64>, Vec<usize>)>,
}

impl NamedParameters {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a new entry; `values` and `shape` are stored as given
    /// without being checked against each other.
    pub fn add(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
        let entry = (name.to_owned(), values, shape);
        self.parameters.push(entry);
    }

    /// Returns the total number of scalar values across all entries.
    pub fn total_parameters(&self) -> usize {
        let mut count = 0;
        for (_, values, _) in &self.parameters {
            count += values.len();
        }
        count
    }

    /// Looks up the first entry whose name equals `name`.
    ///
    /// Returns `None` when no entry matches.
    pub fn get(&self, name: &str) -> Option<&(String, Vec<f64>, Vec<usize>)> {
        self.parameters
            .iter()
            .find(|entry| entry.0.as_str() == name)
    }

    /// Returns the number of entries (not the number of scalar values).
    pub fn len(&self) -> usize {
        self.parameters.len()
    }

    /// Returns `true` when no entries have been added.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Interface for moving parameters between a model and a
/// [`NamedParameters`] collection.
pub trait ExtractParameters {
    /// Returns the implementor's parameters as a new [`NamedParameters`].
    ///
    /// Failure conditions are defined by the implementor and reported
    /// through the crate's `Result` type.
    fn extract_named_parameters(&self) -> Result<NamedParameters>;

    /// Updates `self` from the supplied `params`.
    ///
    /// How names and shapes are matched is up to the implementor.
    fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ModelMetadata::new` stores its arguments in the matching fields.
    #[test]
    fn test_model_metadata_creation() {
        let meta = ModelMetadata::new("ResNet", "f32", 11_000_000);

        assert_eq!(meta.architecture, "ResNet");
        assert_eq!(meta.dtype, "f32");
        assert_eq!(meta.num_parameters, 11_000_000);
    }

    /// Chained `with_extra` calls each add their key/value pair.
    #[test]
    fn test_model_metadata_with_extra() {
        let meta = ModelMetadata::new("BERT", "f32", 110_000_000)
            .with_extra("variant", "base-uncased")
            .with_extra("vocab_size", "30522");

        let variant = meta.extra.get("variant").map(String::as_str);
        let vocab_size = meta.extra.get("vocab_size").map(String::as_str);
        assert_eq!(variant, Some("base-uncased"));
        assert_eq!(vocab_size, Some("30522"));
    }

    /// `num_elements` multiplies the dimensions; `byte_length` is stored as given.
    #[test]
    fn test_tensor_info() {
        let (rows, cols) = (768, 3072);
        let bytes = rows * cols * 4;

        let info = TensorInfo::new("layer1.weight", "F32", vec![rows, cols], 0, bytes);

        assert_eq!(info.num_elements(), rows * cols);
        assert_eq!(info.byte_length, bytes);
    }

    /// An empty shape reports zero elements.
    #[test]
    fn test_tensor_info_empty_shape() {
        let info = TensorInfo::new("empty", "F32", Vec::new(), 0, 0);
        assert_eq!(info.num_elements(), 0);
    }

    /// Covers the empty state, insertion, counting and lookup by name.
    #[test]
    fn test_named_parameters() {
        let mut params = NamedParameters::new();
        assert!(params.is_empty());
        assert_eq!(params.len(), 0);

        params.add("layer1.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        params.add("layer1.bias", vec![0.1, 0.2], vec![2]);
        assert_eq!(params.len(), 2);
        assert!(!params.is_empty());
        assert_eq!(params.total_parameters(), 6);

        let found = params.get("layer1.weight");
        assert!(found.is_some());
        let (name, values, shape) = found.expect("parameter should exist");
        assert_eq!(name.as_str(), "layer1.weight");
        assert_eq!(values.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape.as_slice(), &[2, 2]);

        assert!(params.get("nonexistent").is_none());
    }

    /// Equality works between variants, and every variant can be constructed.
    #[test]
    fn test_model_format_enum() {
        let fmt = ModelFormat::SafeTensors;
        assert_eq!(fmt, ModelFormat::SafeTensors);
        assert_ne!(fmt, ModelFormat::Json);

        // Name every variant so a removed or renamed one fails to compile.
        let _all = [
            ModelFormat::Json,
            ModelFormat::SafeTensors,
            ModelFormat::Cbor,
            ModelFormat::MessagePack,
        ];
    }
}