use crate::python::error::to_py_err;
use crate::python::training::PyModel;
use pyo3::prelude::*;
use std::collections::HashMap;
use std::path::Path;
/// Stateless helper exposed to Python for saving, loading, inspecting and
/// exporting models. Every operation is a static method; the instance itself
/// carries no data.
#[pyclass]
pub struct PyModelSerializer {}

#[pymethods]
impl PyModelSerializer {
    /// Creates an (empty) serializer instance.
    #[new]
    pub fn new() -> Self {
        PyModelSerializer {}
    }

    /// Writes a plain-text description of `model` to `path`.
    ///
    /// The file holds three lines: `model_name: ...`, `layers: ...` (the
    /// `Debug` rendering of the layer list) and `compiled: ...`. The parent
    /// directory of `path` is created first if it does not exist. Progress
    /// messages are printed to stdout.
    ///
    /// # Errors
    /// Raises `IOError` when the directory cannot be created or the file
    /// cannot be written.
    #[staticmethod]
    pub fn save(model: &PyModel, path: &str) -> PyResult<()> {
        println!("Saving model '{}' to: {}", model.name, path);

        if let Some(parent) = Path::new(path).parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                pyo3::exceptions::PyIOError::new_err(format!("Failed to create directory: {}", e))
            })?;
        }

        let metadata = format!(
            "model_name: {}\nlayers: {:?}\ncompiled: {}",
            model.name, model.layers, model.compiled
        );
        std::fs::write(path, metadata).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to save model: {}", e))
        })?;

        println!("Model saved successfully");
        Ok(())
    }

    /// Reads a file written by [`save`](Self::save) and returns a new model.
    ///
    /// The model's name is recovered from the `model_name: ...` line so that
    /// a save/load round trip keeps it. When no such line is present the
    /// placeholder name `LoadedModel` is used. Layers and the compiled flag
    /// are NOT restored; only the name is taken from the file. The file
    /// content is echoed to stdout.
    ///
    /// # Errors
    /// Raises `IOError` when the file cannot be read.
    #[staticmethod]
    pub fn load(path: &str) -> PyResult<PyModel> {
        let content = std::fs::read_to_string(path).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to load model: {}", e))
        })?;
        println!("Loading model from: {}", path);
        println!("Model content:\n{}", content);

        // `save` emits the name on the first `model_name: <name>` line; take
        // the first match and fall back to the historical placeholder.
        let name = content
            .lines()
            .find_map(|line| line.strip_prefix("model_name: "))
            .unwrap_or("LoadedModel")
            .to_string();

        Ok(PyModel::new(Some(name)))
    }

    /// Returns basic facts about the file at `path` as a string map.
    ///
    /// For a missing path the map only contains `exists = "false"`.
    /// Otherwise it contains `path`, `size` (bytes, from the file metadata),
    /// `exists = "true"` and, when the file can be read as UTF-8 text,
    /// `preview`: its first three lines joined with `" | "`.
    ///
    /// # Errors
    /// Raises `IOError` when the path exists but its metadata cannot be read.
    #[staticmethod]
    pub fn get_model_info(path: &str) -> PyResult<HashMap<String, String>> {
        let mut info = HashMap::new();

        if !Path::new(path).exists() {
            info.insert("exists".to_string(), "false".to_string());
            return Ok(info);
        }

        let metadata = std::fs::metadata(path).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to get file info: {}", e))
        })?;
        info.insert("path".to_string(), path.to_string());
        info.insert("size".to_string(), metadata.len().to_string());
        info.insert("exists".to_string(), "true".to_string());

        // The preview is best-effort: unreadable or non-UTF-8 files simply
        // get no `preview` entry.
        if let Ok(content) = std::fs::read_to_string(path) {
            let preview = content.lines().take(3).collect::<Vec<_>>().join(" | ");
            info.insert("preview".to_string(), preview);
        }

        Ok(info)
    }

    /// Exports `model` to `path` in the requested `format`.
    ///
    /// `format` defaults to `"rustorch"`, which delegates to
    /// [`save`](Self::save). `"onnx"` and `"pytorch"` are recognised but not
    /// implemented.
    ///
    /// # Errors
    /// * `NotImplementedError` for `"onnx"` and `"pytorch"`.
    /// * `ValueError` for any other unknown format.
    /// * Whatever `save` raises for `"rustorch"`.
    #[staticmethod]
    pub fn export(model: &PyModel, path: &str, format: Option<String>) -> PyResult<()> {
        let format = format.as_deref().unwrap_or("rustorch");
        match format {
            "rustorch" => Self::save(model, path),
            "onnx" => {
                println!("Exporting to ONNX format: {}", path);
                Err(pyo3::exceptions::PyNotImplementedError::new_err(
                    "ONNX export not implemented",
                ))
            }
            "pytorch" => {
                println!("Exporting to PyTorch format: {}", path);
                Err(pyo3::exceptions::PyNotImplementedError::new_err(
                    "PyTorch export not implemented",
                ))
            }
            _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
                "Unsupported export format: {}",
                format
            ))),
        }
    }

    /// Python `repr()` of the serializer.
    pub fn __repr__(&self) -> String {
        "ModelSerializer()".to_string()
    }
}
/// Stateless helper exposed to Python that summarises and compares models.
#[pyclass]
pub struct PyModelComparator {}

#[pymethods]
impl PyModelComparator {
    /// Creates an (empty) comparator instance.
    #[new]
    pub fn new() -> Self {
        PyModelComparator {}
    }

    /// Builds a side-by-side report for two models.
    ///
    /// The returned map holds, for each model, its name, layer count and
    /// compiled flag (`model1_*` / `model2_*` keys), plus
    /// `layers_identical` and `compilation_identical`, which record whether
    /// the layer lists and the compiled flags are equal.
    #[staticmethod]
    pub fn compare(model1: &PyModel, model2: &PyModel) -> HashMap<String, String> {
        let same_layers = model1.layers == model2.layers;
        let same_compiled = model1.compiled == model2.compiled;

        vec![
            ("model1_name", model1.name.clone()),
            ("model2_name", model2.name.clone()),
            ("model1_layers", model1.layers.len().to_string()),
            ("model2_layers", model2.layers.len().to_string()),
            ("model1_compiled", model1.compiled.to_string()),
            ("model2_compiled", model2.compiled.to_string()),
            ("layers_identical", same_layers.to_string()),
            ("compilation_identical", same_compiled.to_string()),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect()
    }

    /// Returns summary statistics for a single model.
    ///
    /// Always present: `name`, `num_layers`, `compiled`. In addition, each
    /// layer is bucketed by substring match on its description (`Dense`,
    /// then `Conv`, then `Dropout`, otherwise "Other") and one
    /// `<bucket>_layers` count is emitted per bucket that occurs, e.g.
    /// `dense_layers`, `convolutional_layers`.
    #[staticmethod]
    pub fn get_stats(model: &PyModel) -> HashMap<String, String> {
        let mut summary: HashMap<String, String> = HashMap::new();
        summary.insert("name".to_string(), model.name.clone());
        summary.insert("num_layers".to_string(), model.layers.len().to_string());
        summary.insert("compiled".to_string(), model.compiled.to_string());

        // Tally layers per bucket; the first matching substring wins.
        let mut tally: HashMap<&'static str, usize> = HashMap::new();
        for layer in &model.layers {
            let bucket = match () {
                _ if layer.contains("Dense") => "Dense",
                _ if layer.contains("Conv") => "Convolutional",
                _ if layer.contains("Dropout") => "Dropout",
                _ => "Other",
            };
            *tally.entry(bucket).or_default() += 1;
        }

        summary.extend(tally.into_iter().map(|(bucket, occurrences)| {
            (
                format!("{}_layers", bucket.to_lowercase()),
                occurrences.to_string(),
            )
        }));
        summary
    }
}
/// String-keyed configuration store exposed to Python.
///
/// A fresh instance starts with the defaults `device=cpu`, `dtype=float32`,
/// `backend=native`, `num_threads=4` and `memory_limit=1024`.
#[pyclass]
pub struct PyConfig {
    pub(crate) settings: HashMap<String, String>,
}

#[pymethods]
impl PyConfig {
    /// Creates a configuration populated with the default settings.
    #[new]
    pub fn new() -> Self {
        let mut settings = HashMap::new();
        settings.insert("device".to_string(), "cpu".to_string());
        settings.insert("dtype".to_string(), "float32".to_string());
        settings.insert("backend".to_string(), "native".to_string());
        settings.insert("num_threads".to_string(), "4".to_string());
        settings.insert("memory_limit".to_string(), "1024".to_string());
        PyConfig { settings }
    }

    /// Returns the value stored under `key`, or `None` if it is not set.
    pub fn get(&self, key: &str) -> Option<String> {
        self.settings.get(key).cloned()
    }

    /// Stores `value` under `key`, replacing any previous value.
    pub fn set(&mut self, key: String, value: String) {
        self.settings.insert(key, value);
    }

    /// Returns a copy of all settings.
    pub fn all(&self) -> HashMap<String, String> {
        self.settings.clone()
    }

    /// Merges `key=value` pairs from the file at `path` into this config.
    ///
    /// Blank lines and lines starting with `#` are skipped. Each remaining
    /// line is split at the first `=`; key and value are trimmed. Lines
    /// without `=` are ignored. Existing keys are overwritten, other keys
    /// are left untouched.
    ///
    /// # Errors
    /// * `FileNotFoundError` when `path` does not exist.
    /// * `IOError` for any other read failure.
    pub fn load_from_file(&mut self, path: &str) -> PyResult<()> {
        // Read directly and classify the error instead of checking
        // `exists()` first: one syscall fewer and no window in which the
        // file can disappear between the check and the read.
        let content = std::fs::read_to_string(path).map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                pyo3::exceptions::PyFileNotFoundError::new_err(format!(
                    "Configuration file not found: {}",
                    path
                ))
            } else {
                pyo3::exceptions::PyIOError::new_err(format!("Failed to read config: {}", e))
            }
        })?;

        for line in content.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            if let Some((key, value)) = line.split_once('=') {
                self.settings
                    .insert(key.trim().to_string(), value.trim().to_string());
            }
        }
        Ok(())
    }

    /// Writes all settings to `path` as `key=value` lines below a two-line
    /// comment header.
    ///
    /// Entries are written in ascending key order so that the output is
    /// deterministic (plain `HashMap` iteration order varies between runs).
    ///
    /// # Errors
    /// Raises `IOError` when the file cannot be written.
    pub fn save_to_file(&self, path: &str) -> PyResult<()> {
        let mut entries: Vec<(&String, &String)> = self.settings.iter().collect();
        entries.sort_unstable();

        let mut content =
            String::from("# RusTorch Configuration\n# Auto-generated configuration file\n\n");
        for (key, value) in entries {
            content.push_str(key);
            content.push('=');
            content.push_str(value);
            content.push('\n');
        }

        std::fs::write(path, content).map_err(|e| {
            pyo3::exceptions::PyIOError::new_err(format!("Failed to save config: {}", e))
        })?;
        Ok(())
    }

    /// Discards all settings and restores the defaults from [`new`](Self::new).
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Python `repr()`: reports the number of stored settings.
    pub fn __repr__(&self) -> String {
        format!("Config(settings={})", self.settings.len())
    }
}
/// Collects named duration samples and reports aggregate statistics.
///
/// Samples are only stored while the profiler is enabled; a new profiler
/// starts disabled and empty.
#[pyclass]
pub struct PyProfiler {
    pub(crate) enabled: bool,
    pub(crate) timings: HashMap<String, Vec<f64>>,
}

#[pymethods]
impl PyProfiler {
    /// Creates a disabled profiler with no recorded samples.
    #[new]
    pub fn new() -> Self {
        Self {
            enabled: false,
            timings: HashMap::new(),
        }
    }

    /// Turns recording on and drops every previously recorded sample.
    pub fn enable(&mut self) {
        self.timings.clear();
        self.enabled = true;
    }

    /// Turns recording off; already recorded samples are kept.
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Appends `duration` to the samples of operation `name`.
    ///
    /// Does nothing while the profiler is disabled.
    pub fn record(&mut self, name: String, duration: f64) {
        if !self.enabled {
            return;
        }
        self.timings.entry(name).or_default().push(duration);
    }

    /// Returns per-operation statistics.
    ///
    /// For every operation with at least one sample the inner map contains
    /// `count`, `total`, `mean`, `min` and `max`. An operation whose sample
    /// list is empty maps to an empty inner map.
    pub fn get_stats(&self) -> HashMap<String, HashMap<String, f64>> {
        self.timings
            .iter()
            .map(|(operation, samples)| {
                let mut summary: HashMap<String, f64> = HashMap::new();
                if !samples.is_empty() {
                    let total: f64 = samples.iter().sum();
                    let n = samples.len() as f64;
                    let smallest = samples.iter().fold(f64::INFINITY, |acc, &t| acc.min(t));
                    let largest = samples
                        .iter()
                        .fold(f64::NEG_INFINITY, |acc, &t| acc.max(t));

                    summary.insert("count".to_string(), n);
                    summary.insert("total".to_string(), total);
                    summary.insert("mean".to_string(), total / n);
                    summary.insert("min".to_string(), smallest);
                    summary.insert("max".to_string(), largest);
                }
                (operation.clone(), summary)
            })
            .collect()
    }

    /// Drops every recorded sample without changing the enabled flag.
    pub fn clear(&mut self) {
        self.timings.clear();
    }

    /// Python `repr()`: shows the enabled flag and the number of distinct
    /// operations with recorded samples.
    pub fn __repr__(&self) -> String {
        let operations = self.timings.len();
        format!(
            "Profiler(enabled={}, operations={})",
            self.enabled, operations
        )
    }
}
/// Returns a string map describing the build and host.
///
/// Keys: `rust_version`, `rustorch_version`, `target_os`, `target_arch`,
/// `cpu_count` (from `num_cpus::get()`) and `available_memory` (always
/// `"unknown"`).
///
/// NOTE(review): `rust_version` is filled from `CARGO_PKG_VERSION`, i.e. the
/// same crate version as `rustorch_version`, not the compiler version.
/// Confirm whether that is intended.
#[pyfunction]
pub fn get_system_info() -> HashMap<String, String> {
    let crate_version = env!("CARGO_PKG_VERSION");

    vec![
        ("rust_version", crate_version.to_string()),
        ("rustorch_version", crate_version.to_string()),
        ("target_os", std::env::consts::OS.to_string()),
        ("target_arch", std::env::consts::ARCH.to_string()),
        ("cpu_count", num_cpus::get().to_string()),
        ("available_memory", "unknown".to_string()),
    ]
    .into_iter()
    .map(|(key, value)| (key.to_string(), value))
    .collect()
}
/// Announces the requested random seed on stdout.
///
/// NOTE(review): the body only prints `seed`; no random number generator is
/// seeded here.
#[pyfunction]
pub fn set_seed(seed: u64) {
    println!("Setting random seed: {seed}");
}
/// Reports whether CUDA can be used. This implementation always returns
/// `false`.
#[pyfunction]
pub fn cuda_is_available() -> bool {
    false
}
/// Reports whether the build targets macOS on `aarch64`.
///
/// This is a compile-target check only; no runtime probing of the Metal
/// framework is performed.
#[pyfunction]
pub fn metal_is_available() -> bool {
    cfg!(all(target_os = "macos", target_arch = "aarch64"))
}