use std::collections::HashMap;
use std::fs;
use std::path::Path;
use torsh_core::error::{Result, TorshError};
use crate::exporter::{ExportConfig, PackageExporter};
use crate::importer::PackageImporter;
use crate::manifest::{ModuleInfo, PackageManifest};
use crate::resources::{Resource, ResourceType};
use crate::utils::calculate_hash;
use crate::PACKAGE_FORMAT_VERSION;
/// A package: a manifest plus a set of named resources.
///
/// The type derives `serde::Serialize` / `serde::Deserialize`, so both fields
/// are serialized together. Field order is left untouched because positional
/// serialization formats would be sensitive to it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Package {
/// Manifest describing the package (name, version, format version,
/// dependencies, modules, ...).
pub(crate) manifest: PackageManifest,
/// Resources held by the package. The methods in this file insert each
/// resource under its own `name` as the map key.
pub(crate) resources: HashMap<String, Resource>,
}
impl Package {
/// Creates an empty package with the given name and version.
///
/// The manifest is stamped with the crate's `PACKAGE_FORMAT_VERSION` and the
/// current UTC time. Every optional manifest field starts as `None`, and every
/// collection (dependencies, modules, resources, metadata) starts empty, as
/// does the package's own resource map.
pub fn new(name: String, version: String) -> Self {
    Self {
        manifest: PackageManifest {
            name,
            version,
            format_version: PACKAGE_FORMAT_VERSION.to_string(),
            created_at: chrono::Utc::now(),
            author: None,
            description: None,
            license: None,
            dependencies: HashMap::new(),
            modules: Vec::new(),
            resources: Vec::new(),
            metadata: HashMap::new(),
            signature: None,
        },
        resources: HashMap::new(),
    }
}
/// Returns the package name recorded in the manifest.
pub fn name(&self) -> &str {
    self.manifest.name.as_str()
}
/// Returns the package version string recorded in the manifest.
pub fn get_version(&self) -> &str {
    self.manifest.version.as_str()
}
/// Loads a package from `path` using a `PackageImporter` built from the
/// default `ImportConfig`.
///
/// # Errors
///
/// Propagates whatever error `PackageImporter::import_package` returns.
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
    let config = crate::importer::ImportConfig::default();
    PackageImporter::new(config).import_package(path)
}
/// Writes this package to `path` using a `PackageExporter` built from the
/// default `ExportConfig`.
///
/// # Errors
///
/// Propagates whatever error `PackageExporter::export_package` returns.
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
    let config = ExportConfig::default();
    PackageExporter::new(config).export_package(self, path)
}
/// Registers a neural-network module with the package.
///
/// What gets recorded:
///
/// * A `ResourceType::Model` resource named `<name>.pth`. Its payload is a
///   JSON array of strings, one per entry yielded by `module.parameters()`,
///   each formatted as `<param>:shape=<shape>,numel=<n>,requires_grad=<bool>`.
///   Only this descriptive metadata is written here; this method does not
///   serialize parameter values.
/// * A `ModuleInfo` entry appended to the manifest, with `class_name` set to
///   `name` and a fixed version of `"1.0.0"`.
/// * A `"torsh-nn"` entry in the manifest dependencies.
/// * When `include_source` is `true`, a `ResourceType::Source` resource named
///   `<name>.rs` holding a placeholder comment block (no real source is
///   extracted).
///
/// Existing resources with the same names are overwritten by the map insert.
///
/// # Errors
///
/// Returns `TorshError::SerializationError` if the parameter metadata cannot
/// be encoded as JSON.
#[cfg(feature = "with-nn")]
pub fn add_module<M: torsh_nn::Module>(
    &mut self,
    name: &str,
    module: &M,
    include_source: bool,
) -> Result<()> {
    let parameters = module.parameters();
    let mut param_metadata = Vec::new();
    for (param_name, param) in parameters {
        // A parameter whose shape/element count cannot be queried is
        // described with the default shape and a count of 0 rather than
        // aborting the whole export.
        let shape = param.shape().unwrap_or_default();
        let numel = param.numel().unwrap_or(0);
        let metadata = format!(
            "{}:shape={:?},numel={},requires_grad={}",
            param_name,
            shape,
            numel,
            param.requires_grad()
        );
        param_metadata.push(metadata);
    }

    // Serialize the collected descriptions by reference.
    let param_data = serde_json::to_vec(&param_metadata)
        .map_err(|e| TorshError::SerializationError(e.to_string()))?;

    let resource = Resource {
        name: format!("{}.pth", name),
        resource_type: ResourceType::Model,
        data: param_data,
        metadata: {
            let mut meta = HashMap::new();
            meta.insert("type".to_string(), "module".to_string());
            meta.insert("name".to_string(), name.to_string());
            meta
        },
    };
    self.resources.insert(resource.name.clone(), resource);

    let module_info = ModuleInfo {
        name: name.to_string(),
        class_name: name.to_string(),
        version: "1.0.0".to_string(),
        dependencies: Vec::new(),
        has_source: include_source,
    };
    self.manifest.modules.push(module_info);

    // NOTE(review): `CARGO_PKG_VERSION` is the version of the crate being
    // compiled (this one), not necessarily that of `torsh-nn` — confirm the
    // two are kept in lock-step.
    self.manifest.dependencies.insert(
        "torsh-nn".to_string(),
        env!("CARGO_PKG_VERSION").to_string(),
    );

    if include_source {
        let source_placeholder = format!(
            "// Source code for module: {}\n\
             // In a production implementation, this would contain:\n\
             // - Module definition and implementation\n\
             // - Parameter initialization code\n\
             // - Forward pass implementation\n\
             // \n\
             // Note: Automatic source code extraction requires additional\n\
             // infrastructure like proc macros or reflection capabilities.\n\
             // \n\
             // Module class: {}\n",
            name, name
        );
        let source_resource = Resource {
            name: format!("{}.rs", name),
            resource_type: ResourceType::Source,
            data: source_placeholder.as_bytes().to_vec(),
            metadata: {
                let mut meta = HashMap::new();
                meta.insert("type".to_string(), "source".to_string());
                meta.insert("module".to_string(), name.to_string());
                meta.insert("language".to_string(), "rust".to_string());
                meta
            },
        };
        self.resources
            .insert(source_resource.name.clone(), source_resource);
    }

    Ok(())
}
pub fn get_module(&self, name: &str) -> Result<Vec<u8>> {
let module_path = format!("{}.pth", name);
self.resources
.get(&module_path)
.map(|resource| resource.data.clone())
.ok_or_else(|| {
TorshError::General(torsh_core::error::GeneralError::InvalidArgument(format!(
"Module '{}' not found",
name
)))
})
}
/// Reads the file at `path` and stores its bytes as a `ResourceType::Data`
/// resource under `name`, with empty resource metadata.
///
/// An existing resource with the same name is replaced.
///
/// # Errors
///
/// Returns `TorshError::IoError` if the file cannot be read.
pub fn add_data_file<P: AsRef<Path>>(&mut self, name: &str, path: P) -> Result<()> {
    let bytes = match fs::read(&path) {
        Ok(contents) => contents,
        Err(e) => {
            return Err(TorshError::IoError(format!(
                "Failed to read file: {}",
                e
            )))
        }
    };
    let key = name.to_string();
    let entry = Resource {
        name: key.clone(),
        resource_type: ResourceType::Data,
        data: bytes,
        metadata: HashMap::new(),
    };
    self.resources.insert(key, entry);
    Ok(())
}
/// Stores `source` as a `ResourceType::Source` resource named `<name>.rs`,
/// with empty resource metadata.
///
/// An existing resource with the same name is replaced. This method always
/// returns `Ok(())`.
pub fn add_source_file(&mut self, name: &str, source: &str) -> Result<()> {
    let key = format!("{}.rs", name);
    let entry = Resource {
        name: key.clone(),
        resource_type: ResourceType::Source,
        data: source.as_bytes().to_vec(),
        metadata: HashMap::new(),
    };
    self.resources.insert(key, entry);
    Ok(())
}
/// Returns references to every `ModuleInfo` in the manifest, in manifest order.
pub fn list_modules(&self) -> Vec<&ModuleInfo> {
    let modules = &self.manifest.modules;
    let mut listed = Vec::with_capacity(modules.len());
    listed.extend(modules.iter());
    listed
}
pub fn metadata(&self) -> &PackageManifest {
&self.manifest
}
pub fn resources(&self) -> &std::collections::HashMap<String, Resource> {
&self.resources
}
/// Inserts `resource` into the package under its own `name`, replacing any
/// resource previously stored under that name.
pub fn add_resource(&mut self, resource: Resource) {
    let key = resource.name.clone();
    self.resources.insert(key, resource);
}
pub fn resources_mut(&mut self) -> &mut std::collections::HashMap<String, Resource> {
&mut self.resources
}
pub fn manifest_mut(&mut self) -> &mut PackageManifest {
&mut self.manifest
}
/// Records a dependency `name` at `version` in the manifest, overwriting any
/// version previously recorded under the same name.
pub fn add_dependency(&mut self, name: &str, version: &str) {
    let dependencies = &mut self.manifest.dependencies;
    dependencies.insert(String::from(name), String::from(version));
}
pub fn verify(&self) -> Result<bool> {
if self.manifest.name.is_empty() {
return Ok(false);
}
if self.manifest.version.is_empty() {
return Ok(false);
}
let format_version =
semver::Version::parse(&self.manifest.format_version).map_err(|e| {
TorshError::General(torsh_core::error::GeneralError::InvalidArgument(
e.to_string(),
))
})?;
let current_format = semver::Version::parse(PACKAGE_FORMAT_VERSION).map_err(|e| {
TorshError::General(torsh_core::error::GeneralError::ConfigError(e.to_string()))
})?;
if format_version.major != current_format.major {
return Ok(false);
}
for resource in self.resources.values() {
if let Some(expected_hash) = resource.metadata.get("sha256") {
let actual_hash = calculate_hash(&resource.data);
if &actual_hash != expected_hash {
return Ok(false);
}
}
}
Ok(true)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new package keeps its name and version and stores an added source file.
    #[test]
    fn test_package_creation() {
        let mut pkg = Package::new(String::from("test_package"), String::from("1.0.0"));
        pkg.add_source_file("test", "fn main() {}").unwrap();

        assert_eq!(pkg.manifest.name, "test_package");
        assert_eq!(pkg.manifest.version, "1.0.0");
        assert_eq!(pkg.resources.len(), 1);
    }

    /// `verify` accepts a well-formed package and rejects an empty name or version.
    #[test]
    fn test_package_verification() {
        let valid = Package::new(String::from("test"), String::from("1.0.0"));
        assert!(valid.verify().unwrap());

        // Empty name is rejected.
        let mut broken = Package::new(String::new(), String::from("1.0.0"));
        assert!(!broken.verify().unwrap());

        // Name restored, but an empty version is rejected as well.
        broken.manifest.name = String::from("test");
        broken.manifest.version = String::new();
        assert!(!broken.verify().unwrap());
    }

    /// `add_dependency` records the name/version pair in the manifest.
    #[test]
    fn test_add_dependency() {
        let mut pkg = Package::new(String::from("test"), String::from("1.0.0"));
        pkg.add_dependency("serde", "1.0");

        let recorded = pkg.manifest.dependencies.get("serde");
        assert_eq!(recorded, Some(&String::from("1.0")));
    }

    /// `add_data_file` reads a file from disk and stores it as a `Data` resource.
    #[test]
    fn test_add_data_file() {
        let scratch = tempfile::TempDir::new().unwrap();
        let data_path = scratch.path().join("test.txt");
        std::fs::write(&data_path, b"test data").unwrap();

        let mut pkg = Package::new(String::from("test"), String::from("1.0.0"));
        pkg.add_data_file("test.txt", &data_path).unwrap();

        let stored = pkg.resources.get("test.txt").unwrap();
        assert_eq!(stored.data, b"test data");
        assert_eq!(stored.resource_type, ResourceType::Data);
    }
}