use crate::formats::pytorch::{PyTorchModel, StateDict};
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
/// Name, type label and input/output shapes describing a single layer.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type; presumably other modules use it — confirm before changing it.
#[derive(Debug, Clone)]
pub struct LayerDescription {
/// Layer name.
pub name: String,
/// Layer type label, stored as a free-form string.
pub layer_type: String,
/// Dimensions of the layer input (axis meaning is not defined in this file).
pub input_shape: Vec<usize>,
/// Dimensions of the layer output (axis meaning is not defined in this file).
pub output_shape: Vec<usize>,
}
/// Errors reported by the simplified PyTorch conversion and shape simulation.
///
/// Every variant carries a free-form string that is appended to a fixed prefix
/// when the error is formatted with [`fmt::Display`].
#[derive(Debug)]
pub enum SimpleConversionError {
    /// Rendered as `Unsupported layer: <payload>`.
    UnsupportedLayer(String),
    /// Rendered as `Missing parameter: <payload>`.
    MissingParameter(String),
    /// Rendered as `Invalid parameter: <payload>`.
    InvalidParameter(String),
}

impl fmt::Display for SimpleConversionError {
    /// Writes the variant-specific prefix followed by the carried message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnsupportedLayer(layer) => write!(f, "Unsupported layer: {}", layer),
            Self::MissingParameter(param) => write!(f, "Missing parameter: {}", param),
            Self::InvalidParameter(msg) => write!(f, "Invalid parameter: {}", msg),
        }
    }
}

// No underlying source error is tracked, so the default `Error` methods suffice.
impl Error for SimpleConversionError {}
/// One converted layer: its inferred type plus every parameter that shared the
/// layer's name prefix in the source state dict.
#[derive(Debug, Clone)]
pub struct SimpleLayerDescription {
/// Layer name, i.e. the parameter key with its final `.`-separated segment removed.
pub name: String,
/// Type label produced by the converter's heuristic (e.g. `"Linear"`, `"Conv2d"`,
/// `"BatchNorm2d"`, `"Unknown"`).
pub layer_type: String,
/// Shape of each parameter, keyed by the final segment of its name (e.g. `"weight"`).
pub parameter_shapes: HashMap<String, Vec<usize>>,
/// Sum over all parameters of the product of their shape dimensions.
pub num_parameters: usize,
/// Parameter values converted to `f32` tensors, keyed like `parameter_shapes`.
pub tensors: HashMap<String, Tensor<f32>>,
}
/// Result of converting a PyTorch state dict: the layers found, an ordering over
/// them, and the overall parameter count.
#[derive(Debug)]
pub struct SimplifiedPyTorchModel {
/// Converted layers keyed by layer name.
pub layers: HashMap<String, SimpleLayerDescription>,
/// Layer names in the order they are visited. The converter fills this with the
/// lexicographically sorted layer names, not an order derived from a model graph.
pub execution_order: Vec<String>,
/// Sum of `num_parameters` over all layers.
pub total_parameters: usize,
}
/// Stateless namespace for the state-dict conversion routines; every function in
/// its `impl` is an associated function that takes no `self`.
pub struct SimplePyTorchConverter;
impl SimplePyTorchConverter {
    /// Converts a PyTorch model's state dict into a [`SimplifiedPyTorchModel`].
    ///
    /// Parameters are grouped into layers by the prefix of their dotted name
    /// (`"features.0.weight"` belongs to layer `"features.0"`), each layer gets a
    /// heuristically inferred type, and parameter counts are accumulated.
    ///
    /// `execution_order` is the lexicographically sorted list of layer names; the
    /// state dict carries no graph information, so this is only an approximation
    /// of the real forward order (e.g. `"layer10"` sorts before `"layer2"`).
    ///
    /// # Errors
    ///
    /// Returns [`SimpleConversionError::InvalidParameter`] when a parameter name
    /// contains no `.` separator.
    pub fn convert(
        pytorch_model: &PyTorchModel,
    ) -> Result<SimplifiedPyTorchModel, SimpleConversionError> {
        let layer_params = Self::group_parameters_by_layer(&pytorch_model.state_dict)?;

        let mut layers = HashMap::with_capacity(layer_params.len());
        let mut total_parameters = 0;
        for (layer_name, params) in layer_params {
            let layer_info = Self::convert_layer(&layer_name, &params)?;
            total_parameters += layer_info.num_parameters;
            // `layer_name` is owned and no longer needed, so move it into the map.
            layers.insert(layer_name, layer_info);
        }

        let mut execution_order: Vec<String> = layers.keys().cloned().collect();
        execution_order.sort();

        Ok(SimplifiedPyTorchModel {
            layers,
            execution_order,
            total_parameters,
        })
    }

    /// Buckets every tensor of `state_dict` under its layer name.
    ///
    /// The outer key is the layer name, the inner key is the final segment of the
    /// parameter name (e.g. `"weight"`, `"bias"`). Tensor data is borrowed, not
    /// copied.
    ///
    /// # Errors
    ///
    /// Propagates the error of [`Self::parse_parameter_name`].
    fn group_parameters_by_layer(
        state_dict: &StateDict,
    ) -> Result<
        HashMap<String, HashMap<String, &crate::formats::pytorch::TensorData>>,
        SimpleConversionError,
    > {
        let mut layer_params = HashMap::new();
        for (param_name, tensor_data) in &state_dict.tensors {
            let (layer_name, param_type) = Self::parse_parameter_name(param_name)?;
            layer_params
                .entry(layer_name)
                .or_insert_with(HashMap::new)
                .insert(param_type, tensor_data);
        }
        Ok(layer_params)
    }

    /// Splits a dotted parameter name at its last `.` into
    /// `(layer_name, param_type)`.
    ///
    /// `"features.0.weight"` yields `("features.0", "weight")`.
    ///
    /// # Errors
    ///
    /// Returns [`SimpleConversionError::InvalidParameter`] when `param_name`
    /// contains no `.`.
    fn parse_parameter_name(param_name: &str) -> Result<(String, String), SimpleConversionError> {
        match param_name.rsplit_once('.') {
            Some((layer_name, param_type)) => {
                Ok((layer_name.to_string(), param_type.to_string()))
            }
            None => Err(SimpleConversionError::InvalidParameter(format!(
                "Invalid parameter name: {}",
                param_name
            ))),
        }
    }

    /// Builds the description of one layer from its grouped parameters.
    ///
    /// Every parameter is converted to an `f32` tensor, its shape is recorded, and
    /// the product of its dimensions is added to the layer's parameter count.
    ///
    /// This currently always returns `Ok`; the `Result` return type is kept so the
    /// signature stays stable for callers.
    fn convert_layer(
        layer_name: &str,
        params: &HashMap<String, &crate::formats::pytorch::TensorData>,
    ) -> Result<SimpleLayerDescription, SimpleConversionError> {
        let layer_type = Self::infer_layer_type(layer_name, params);

        let mut tensors = HashMap::with_capacity(params.len());
        let mut parameter_shapes = HashMap::with_capacity(params.len());
        let mut num_parameters = 0;
        for (param_name, tensor_data) in params {
            num_parameters += tensor_data.shape.iter().product::<usize>();
            parameter_shapes.insert(param_name.clone(), tensor_data.shape.clone());
            tensors.insert(param_name.clone(), Self::convert_tensor_data(tensor_data));
        }

        Ok(SimpleLayerDescription {
            name: layer_name.to_string(),
            layer_type,
            parameter_shapes,
            num_parameters,
            tensors,
        })
    }

    /// Guesses a layer type label from the layer name, falling back to the rank of
    /// its `weight` tensor.
    ///
    /// Name rules are checked first, in this order:
    /// 1. contains `linear`, `fc` or `classifier` -> `"Linear"`
    /// 2. contains `conv` but not `transpose` -> `"Conv2d"`
    /// 3. contains `bn` or `batch_norm` -> `"BatchNorm2d"`
    ///
    /// If no name rule matches, the `weight` rank decides: 2 -> `"Linear"`,
    /// 4 -> `"Conv2d"`, 1 -> `"BatchNorm2d"`, anything else -> `"Unknown_<rank>D"`.
    /// Without a `weight` entry the result is `"Unknown"`.
    fn infer_layer_type(
        layer_name: &str,
        params: &HashMap<String, &crate::formats::pytorch::TensorData>,
    ) -> String {
        let name_has = |needle: &str| layer_name.contains(needle);

        if name_has("linear") || name_has("fc") || name_has("classifier") {
            return "Linear".to_string();
        }
        // Names mentioning `transpose` skip this rule and fall through to the
        // shape-based guess below.
        if name_has("conv") && !name_has("transpose") {
            return "Conv2d".to_string();
        }
        if name_has("bn") || name_has("batch_norm") {
            return "BatchNorm2d".to_string();
        }

        match params.get("weight").map(|weight| weight.shape.len()) {
            Some(2) => "Linear".to_string(),
            Some(4) => "Conv2d".to_string(),
            Some(1) => "BatchNorm2d".to_string(),
            Some(rank) => format!("Unknown_{}D", rank),
            None => "Unknown".to_string(),
        }
    }

    /// Copies `tensor_data` into a `Tensor<f32>`, casting each element with `as f32`
    /// and reusing the source shape.
    fn convert_tensor_data(tensor_data: &crate::formats::pytorch::TensorData) -> Tensor<f32> {
        let data: Vec<f32> = tensor_data.data.iter().map(|&x| x as f32).collect();
        Tensor::from_vec(data, tensor_data.shape.clone())
    }
}
impl SimplifiedPyTorchModel {
    /// Prints a human-readable summary of the model to standard output.
    ///
    /// Layers are listed in `execution_order`; names in that list with no matching
    /// entry in `layers` are skipped. Within a layer, parameters are listed sorted
    /// by name so the output is identical from run to run.
    pub fn print_summary(&self) {
        println!("🤖 Simplified PyTorch Model Summary");
        println!("==================================");
        println!("Total layers: {}", self.layers.len());
        println!("Total parameters: {}", self.total_parameters);
        println!();
        println!("📋 Layer Details:");
        for layer_name in &self.execution_order {
            if let Some(layer) = self.layers.get(layer_name) {
                println!(" 📦 {}: {}", layer_name, layer.layer_type);
                println!(" Parameters: {}", layer.num_parameters);

                // `HashMap` iteration order is unspecified; sort for stable output.
                let mut params: Vec<_> = layer.parameter_shapes.iter().collect();
                params.sort_unstable_by(|a, b| a.0.cmp(b.0));
                for (param_name, shape) in params {
                    println!(" - {}: {:?}", param_name, shape);
                }
                println!();
            }
        }
    }

    /// Looks up a layer by name.
    pub fn get_layer(&self, name: &str) -> Option<&SimpleLayerDescription> {
        self.layers.get(name)
    }

    /// Returns the layer names in `execution_order`.
    pub fn layer_names(&self) -> Vec<&String> {
        self.execution_order.iter().collect()
    }

    /// Propagates `input_shape` through every layer in `execution_order` and
    /// returns the final shape.
    ///
    /// The intermediate shape after each layer is printed to standard output.
    /// Names in `execution_order` with no matching entry in `layers` are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a layer; see
    /// [`Self::simulate_layer_forward`].
    pub fn simulate_forward(
        &self,
        input_shape: Vec<usize>,
    ) -> Result<Vec<usize>, SimpleConversionError> {
        let mut current_shape = input_shape;
        for layer_name in &self.execution_order {
            if let Some(layer) = self.layers.get(layer_name) {
                current_shape = self.simulate_layer_forward(layer, current_shape)?;
                println!("After {}: {:?}", layer_name, current_shape);
            }
        }
        Ok(current_shape)
    }

    /// Computes the output shape of a single layer for the given input shape.
    ///
    /// * `"Linear"`: the last dimension becomes `weight.shape[0]`.
    /// * `"Conv2d"`: for inputs of rank 4 or more, the third-from-last dimension
    ///   becomes `weight.shape[0]`; lower-rank inputs pass through unchanged.
    ///   Spatial dimensions are never adjusted.
    /// * Any other type, including `"BatchNorm2d"`: the shape passes through
    ///   unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`SimpleConversionError::InvalidParameter`] when a `Linear` or
    /// `Conv2d` layer has no `weight` shape of the expected rank (2 and 4
    /// respectively), or when a `Linear` layer is given an empty input shape.
    fn simulate_layer_forward(
        &self,
        layer: &SimpleLayerDescription,
        input_shape: Vec<usize>,
    ) -> Result<Vec<usize>, SimpleConversionError> {
        match layer.layer_type.as_str() {
            "Linear" => match layer.parameter_shapes.get("weight") {
                Some(weight_shape) if weight_shape.len() == 2 => {
                    let out_features = weight_shape[0];
                    let mut output_shape = input_shape;
                    // An empty shape has no feature dimension to replace; report it
                    // instead of underflowing `len() - 1`.
                    match output_shape.last_mut() {
                        Some(last) => *last = out_features,
                        None => {
                            return Err(SimpleConversionError::InvalidParameter(
                                "Linear layer requires a non-empty input shape".to_string(),
                            ))
                        }
                    }
                    Ok(output_shape)
                }
                _ => Err(SimpleConversionError::InvalidParameter(
                    "Invalid Linear layer".to_string(),
                )),
            },
            "Conv2d" => match layer.parameter_shapes.get("weight") {
                Some(weight_shape) if weight_shape.len() == 4 => {
                    let out_channels = weight_shape[0];
                    let mut output_shape = input_shape;
                    let rank = output_shape.len();
                    if rank >= 4 {
                        output_shape[rank - 3] = out_channels;
                    }
                    Ok(output_shape)
                }
                _ => Err(SimpleConversionError::InvalidParameter(
                    "Invalid Conv2d layer".to_string(),
                )),
            },
            // "BatchNorm2d" and every unrecognised layer type leave the shape as is.
            _ => Ok(input_shape),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formats::pytorch::{StateDict, TensorData};

    /// Builds a model whose state dict holds one `fc` layer: a 10x5 weight
    /// (50 values) and a 10-element bias.
    fn create_simple_test_model() -> PyTorchModel {
        let weight = TensorData {
            shape: vec![10, 5],
            data: vec![0.1; 50],
            dtype: "f32".to_string(),
        };
        let bias = TensorData {
            shape: vec![10],
            data: vec![0.0; 10],
            dtype: "f32".to_string(),
        };

        let mut state_dict = StateDict::new();
        state_dict.tensors.insert("fc.weight".to_string(), weight);
        state_dict.tensors.insert("fc.bias".to_string(), bias);
        crate::formats::pytorch::PyTorchModel::from_state_dict(state_dict)
    }

    /// Both `fc.*` tensors collapse into a single layer with 50 + 10 parameters.
    #[test]
    fn test_simple_conversion() {
        let source_model = create_simple_test_model();
        let simplified = SimplePyTorchConverter::convert(&source_model).unwrap();

        assert_eq!(simplified.layers.len(), 1);
        assert!(simplified.layers.contains_key("fc"));
        assert_eq!(simplified.total_parameters, 60);
    }

    /// With no parameters available, the type is decided by the layer name alone.
    #[test]
    fn test_layer_type_inference() {
        let no_params = HashMap::new();

        let inferred = SimplePyTorchConverter::infer_layer_type("fc", &no_params);
        assert_eq!(inferred, "Linear");

        let inferred = SimplePyTorchConverter::infer_layer_type("conv1", &no_params);
        assert_eq!(inferred, "Conv2d");
    }

    /// Only the last dotted segment is treated as the parameter type.
    #[test]
    fn test_parameter_parsing() {
        let parsed = SimplePyTorchConverter::parse_parameter_name("features.0.weight").unwrap();
        let (layer_name, param_type) = parsed;

        assert_eq!(layer_name, "features.0");
        assert_eq!(param_type, "weight");
    }
}