pub mod formats;
pub mod onnx;
pub mod pytorch;
use crate::error::RusTorchError;
use crate::nn::Module;
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::path::Path;
/// Result type returned by the import functions in this module.
///
/// This is a plain alias for the crate-wide `RusTorchResult<T>`.
pub type ImportResult<T> = crate::error::RusTorchResult<T>;
/// A model as returned by the importers in this module: descriptive
/// metadata, a map of named `f32` weight tensors, and an architecture
/// description.
#[derive(Debug, Clone)]
pub struct ImportedModel {
/// Descriptive information about the model (name, framework, ...).
pub metadata: ModelMetadata,
/// Weight tensors keyed by name.
pub weights: HashMap<String, Tensor<f32>>,
/// Input/output specifications and per-layer information.
pub architecture: ModelArchitecture,
}
/// Descriptive, free-form metadata attached to an [`ImportedModel`].
///
/// All values are stored as strings; this module does not interpret or
/// validate them.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
/// Model name.
pub name: String,
/// Model version string.
pub version: String,
/// Name of the framework the model originates from.
pub framework: String,
/// Name of the serialization format.
pub format: String,
/// Optional human-readable description.
pub description: Option<String>,
/// Optional author.
pub author: Option<String>,
/// Optional license identifier or text.
pub license: Option<String>,
/// Optional creation time (string; format not specified here).
pub created: Option<String>,
/// Any additional key/value metadata.
pub extra: HashMap<String, String>,
}
/// Summary of a model's input shape, output shape, and parameter count.
///
/// NOTE(review): nothing in the visible part of this module uses this
/// type; it is presumably consumed elsewhere in the crate — confirm.
#[derive(Debug, Clone)]
pub struct ModelStructure {
/// Shape of the model input.
pub input_shape: Vec<usize>,
/// Shape of the model output.
pub output_shape: Vec<usize>,
/// Number of parameters in the model.
pub num_parameters: usize,
}
/// Architecture description of an [`ImportedModel`].
///
/// `ModelImporter::to_module` reads only the first entry of `inputs` and
/// `outputs`; the other fields are not consulted in this file.
#[derive(Debug, Clone)]
pub struct ModelArchitecture {
/// Specifications of the model's input tensors.
pub inputs: Vec<TensorSpec>,
/// Specifications of the model's output tensors.
pub outputs: Vec<TensorSpec>,
/// Per-layer information.
pub layers: Vec<LayerInfo>,
/// Number of parameters in the model.
pub parameter_count: usize,
/// Size of the model (presumably in bytes — TODO confirm with the importers).
pub model_size: usize,
}
/// Name, shape and element type of a model input or output tensor.
#[derive(Debug, Clone)]
pub struct TensorSpec {
/// Tensor name.
pub name: String,
/// One entry per dimension. A `None` entry is treated as "size unknown"
/// by `ModelImporter::to_module`, which then falls back to a default
/// (presumably `None` marks a dynamic dimension — confirm with the importers).
pub shape: Vec<Option<usize>>,
/// Element data type.
pub dtype: crate::dtype::DType,
/// Optional human-readable description.
pub description: Option<String>,
}
/// Description of a single layer inside a [`ModelArchitecture`].
#[derive(Debug, Clone)]
pub struct LayerInfo {
/// Layer name.
pub name: String,
/// Layer type, stored as a free-form string.
pub layer_type: String,
/// Input shape, one entry per dimension (`None` presumably means the
/// size is not known statically — confirm).
pub input_shape: Vec<Option<usize>>,
/// Output shape, with the same conventions as `input_shape`.
pub output_shape: Vec<Option<usize>>,
/// Parameter count for this layer (presumably number of scalars — confirm).
pub params: usize,
/// Additional layer attributes as string key/value pairs.
pub attributes: HashMap<String, String>,
}
/// Builder-style entry point for importing models from files, URLs, or a
/// small table of named pretrained models.
///
/// Configure it with `with_cache_dir` / `with_progress_callback`, then call
/// one of the `import_*` methods.
pub struct ModelImporter {
// Directory where fetched files are stored; when `None`, the system
// temporary directory is used instead.
cache_dir: Option<std::path::PathBuf>,
// Optional callback invoked with two `u64` values after a file has been
// fetched (presumably `(done, total)` — confirm with callers).
progress_callback: Option<Box<dyn Fn(u64, u64) + Send + Sync>>,
}
impl ModelImporter {
    /// Creates an importer with no cache directory and no progress callback.
    pub fn new() -> Self {
        Self {
            cache_dir: None,
            progress_callback: None,
        }
    }

    /// Sets the directory in which fetched model files are stored.
    ///
    /// Without a cache directory, `import_from_url` uses the system
    /// temporary directory.
    pub fn with_cache_dir<P: AsRef<Path>>(mut self, cache_dir: P) -> Self {
        self.cache_dir = Some(cache_dir.as_ref().to_path_buf());
        self
    }

    /// Registers a callback that is invoked after a model file has been
    /// fetched.
    ///
    /// The callback currently receives `(100, 100)` exactly once per fetch
    /// (presumably `(done, total)` — confirm once real downloads exist).
    pub fn with_progress_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(u64, u64) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Box::new(callback));
        self
    }

    /// Imports a model from a local file, choosing the importer from the
    /// file extension.
    ///
    /// # Errors
    ///
    /// * a model I/O error if `path` does not exist;
    /// * an unsupported-format error if the path has no (UTF-8) extension,
    ///   the extension is not recognised, or the format is recognised but
    ///   has no importer here (`tensorflow`, `keras`);
    /// * whatever error the selected importer returns.
    pub fn import_from_file<P: AsRef<Path>>(&self, path: P) -> ImportResult<ImportedModel> {
        let path = path.as_ref();
        if !path.exists() {
            return Err(RusTorchError::model_io(format!(
                "File not found: {}",
                path.display()
            )));
        }
        let format = self.detect_format(path)?;
        match format.as_str() {
            "onnx" => onnx::import_onnx_model(path),
            "pytorch" | "pth" | "pt" => pytorch::import_pytorch_model(path),
            _ => Err(RusTorchError::unsupported_format(format)),
        }
    }

    /// Fetches the file behind `url` into the cache directory (or the system
    /// temporary directory when none is configured) and imports it.
    ///
    /// # Errors
    ///
    /// Returns a model I/O error if no usable file name can be derived from
    /// `url` or the file cannot be stored, plus any error of
    /// [`import_from_file`](Self::import_from_file).
    pub fn import_from_url(&self, url: &str) -> ImportResult<ImportedModel> {
        let local_path = match &self.cache_dir {
            Some(cache_dir) => self.download_model(url, cache_dir)?,
            None => self.download_model(url, &std::env::temp_dir())?,
        };
        self.import_from_file(local_path)
    }

    /// Imports one of the built-in pretrained models by name.
    ///
    /// # Errors
    ///
    /// Returns a model I/O error for an unknown `model_name`, plus any error
    /// of [`import_from_url`](Self::import_from_url).
    pub fn import_pretrained(&self, model_name: &str) -> ImportResult<ImportedModel> {
        match self.get_pretrained_url(model_name) {
            Some(url) => self.import_from_url(&url),
            None => Err(RusTorchError::model_io(format!(
                "Unknown pretrained model: {}",
                model_name
            ))),
        }
    }

    /// Builds a module for `model`.
    ///
    /// This is a placeholder: it always returns a single `Linear` layer whose
    /// sizes are the last dimension of the first input and first output
    /// spec. The imported weights and the layer list are not used. When a
    /// size is missing or `None`, it falls back to 784 inputs / 10 outputs.
    pub fn to_module(&self, model: &ImportedModel) -> ImportResult<Box<dyn Module<f32>>> {
        // Fallback sizes used when the architecture does not pin a dimension.
        const DEFAULT_INPUT_SIZE: usize = 784;
        const DEFAULT_OUTPUT_SIZE: usize = 10;

        let input_size = model
            .architecture
            .inputs
            .first()
            .and_then(|spec| spec.shape.last())
            .and_then(|&size| size)
            .unwrap_or(DEFAULT_INPUT_SIZE);
        let output_size = model
            .architecture
            .outputs
            .first()
            .and_then(|spec| spec.shape.last())
            .and_then(|&size| size)
            .unwrap_or(DEFAULT_OUTPUT_SIZE);
        Ok(Box::new(crate::nn::Linear::new(input_size, output_size)))
    }

    /// Maps the (case-insensitive) file extension of `path` to a format name.
    ///
    /// Returns `"onnx"`, `"pytorch"` (for `.pth`/`.pt`), `"tensorflow"`
    /// (`.pb`) or `"keras"` (`.h5`); anything else, including a missing or
    /// non-UTF-8 extension, is an unsupported-format error.
    fn detect_format(&self, path: &Path) -> ImportResult<String> {
        let extension = path
            .extension()
            .and_then(|ext| ext.to_str())
            .ok_or_else(|| RusTorchError::unsupported_format("No file extension"))?;
        match extension.to_lowercase().as_str() {
            "onnx" => Ok("onnx".to_string()),
            "pth" | "pt" => Ok("pytorch".to_string()),
            "pb" => Ok("tensorflow".to_string()),
            "h5" => Ok("keras".to_string()),
            _ => Err(RusTorchError::unsupported_format(extension)),
        }
    }

    /// Extracts the file name a URL should be stored under.
    ///
    /// The query string and fragment are dropped, then the last `/`-separated
    /// segment is taken. `None` is returned unless that segment is a single
    /// ordinary path component, which rejects an empty segment (URL ending in
    /// `/`), `.`/`..`, and anything the platform would treat as containing a
    /// path separator or prefix.
    fn file_name_from_url(url: &str) -> Option<&str> {
        let end = url
            .find(|c: char| matches!(c, '?' | '#'))
            .unwrap_or(url.len());
        let candidate = url[..end].rsplit('/').next()?;
        // `Path::file_name` is `None` for "", "." and "..", and differs from
        // the input when the input contains a separator or prefix.
        let is_plain_file_name = Path::new(candidate)
            .file_name()
            .and_then(|name| name.to_str())
            == Some(candidate);
        if is_plain_file_name {
            Some(candidate)
        } else {
            None
        }
    }

    /// Ensures a local copy of `url` exists inside `cache_dir` and returns
    /// its path.
    ///
    /// An already existing file with the derived name is reused as-is.
    /// Otherwise the file is created. No network transfer is implemented
    /// yet: the file is filled with placeholder bytes, and the progress
    /// callback (if any) is invoked once with `(100, 100)`.
    ///
    /// # Errors
    ///
    /// Returns a model I/O error with message `"Invalid URL"` when no usable
    /// file name can be derived from `url`, or with the underlying I/O error
    /// text when the directory or file cannot be created.
    fn download_model(&self, url: &str, cache_dir: &Path) -> ImportResult<std::path::PathBuf> {
        // Validate the URL before touching the filesystem, so a bad URL
        // neither creates directories nor resolves to the cache directory.
        let filename = Self::file_name_from_url(url)
            .ok_or_else(|| RusTorchError::model_io("Invalid URL"))?;
        std::fs::create_dir_all(cache_dir).map_err(|e| RusTorchError::model_io(e.to_string()))?;
        let local_path = cache_dir.join(filename);
        if local_path.exists() {
            return Ok(local_path);
        }
        // Placeholder for the real download.
        std::fs::write(&local_path, b"mock model data")
            .map_err(|e| RusTorchError::model_io(e.to_string()))?;
        if let Some(callback) = &self.progress_callback {
            callback(100, 100);
        }
        Ok(local_path)
    }

    /// Looks up the download URL of a built-in pretrained model.
    ///
    /// NOTE(review): the `bert-base-uncased` URL ends in `.bin`, an extension
    /// `detect_format` does not recognise, so importing that entry currently
    /// ends in an unsupported-format error.
    fn get_pretrained_url(&self, model_name: &str) -> Option<String> {
        match model_name {
            "resnet18" => {
                Some("https://download.pytorch.org/models/resnet18-5c106cde.pth".to_string())
            }
            "resnet50" => {
                Some("https://download.pytorch.org/models/resnet50-19c8e357.pth".to_string())
            }
            "mobilenet_v2" => {
                Some("https://download.pytorch.org/models/mobilenet_v2-b0353104.pth".to_string())
            }
            "bert-base-uncased" => Some(
                "https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin"
                    .to_string(),
            ),
            _ => None,
        }
    }
}
impl Default for ModelImporter {
fn default() -> Self {
Self::new()
}
}
/// Imports a model from a local file using a default-configured
/// [`ModelImporter`].
///
/// Shorthand for `ModelImporter::new().import_from_file(path)`; the same
/// errors are returned.
pub fn import_model<P>(path: P) -> ImportResult<ImportedModel>
where
    P: AsRef<Path>,
{
    let importer = ModelImporter::new();
    importer.import_from_file(path)
}
/// Imports a built-in pretrained model by name using a default-configured
/// [`ModelImporter`].
///
/// Shorthand for `ModelImporter::new().import_pretrained(model_name)`; the
/// same errors are returned.
pub fn import_pretrained(model_name: &str) -> ImportResult<ImportedModel> {
    let importer = ModelImporter::new();
    importer.import_pretrained(model_name)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built importer carries no cache directory and no callback.
    #[test]
    fn test_model_importer_creation() {
        let fresh = ModelImporter::new();
        assert!(fresh.cache_dir.is_none());
        assert!(fresh.progress_callback.is_none());
    }

    /// File extensions are mapped to the expected format names.
    #[test]
    fn test_format_detection() {
        let importer = ModelImporter::new();
        let cases = [("model.onnx", "onnx"), ("model.pth", "pytorch")];
        for &(file_name, expected) in cases.iter() {
            let detected = importer
                .detect_format(std::path::Path::new(file_name))
                .unwrap();
            assert_eq!(detected, expected);
        }
    }

    /// Known model names resolve to a URL; unknown names do not.
    #[test]
    fn test_pretrained_url_lookup() {
        let importer = ModelImporter::new();
        let known = importer.get_pretrained_url("resnet18");
        let unknown = importer.get_pretrained_url("unknown_model");
        assert!(known.is_some());
        assert!(unknown.is_none());
    }
}