use crate::formats::pytorch::TensorData;
use crate::convert::SimpleLayerInfo;
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use std::collections::HashMap;
/// Transformer building blocks grouped by `ComplexArchitectureParser::parse_transformer`.
#[derive(Clone, Debug)]
pub struct TransformerComponents {
    /// Multi-head attention layers that were recognised and parsed.
    pub attention_layers: Vec<MultiheadAttentionInfo>,
    /// Feed-forward (MLP) layers that were recognised and parsed.
    pub ffn_layers: Vec<FeedForwardInfo>,
    /// Layer-normalisation layers that were recognised and parsed.
    pub layer_norms: Vec<LayerNormInfo>,
    /// Position embedding table, if a matching layer was found.
    pub position_embeddings: Option<PositionEmbeddingInfo>,
    /// Token embedding table, if a matching layer was found.
    pub token_embeddings: Option<TokenEmbeddingInfo>,
}
/// Weights and (partly inferred) hyper-parameters of one multi-head attention layer.
#[derive(Clone, Debug)]
pub struct MultiheadAttentionInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Number of attention heads; inferred heuristically, not stored in the checkpoint.
    pub num_heads: usize,
    /// Model dimension, taken from the second axis of the query projection weight.
    pub embed_dim: usize,
    /// Size of each head.
    pub head_dim: usize,
    /// Dropout probability; cannot be recovered from weights, so a default is used.
    pub dropout: f32,
    /// Whether the projections are treated as having a bias term.
    pub bias: bool,
    /// Whether inputs are treated as batch-first.
    pub batch_first: bool,
    /// Query projection weight.
    pub query_weights: Tensor<f32>,
    /// Key projection weight.
    pub key_weights: Tensor<f32>,
    /// Value projection weight.
    pub value_weights: Tensor<f32>,
    /// Output projection weight.
    pub output_weights: Tensor<f32>,
    /// Projection bias tensors, when available.
    pub biases: Option<Vec<Tensor<f32>>>,
}
/// Weights of a two-linear-layer feed-forward block.
#[derive(Clone, Debug)]
pub struct FeedForwardInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Input width, taken from the second axis of the first linear weight.
    pub input_dim: usize,
    /// Hidden width, taken from the first axis of the first linear weight.
    pub hidden_dim: usize,
    /// Weight of the first (expanding) linear layer.
    pub linear1_weights: Tensor<f32>,
    /// Weight of the second (projecting) linear layer.
    pub linear2_weights: Tensor<f32>,
    /// Bias tensors of the two linear layers, when available.
    pub biases: Option<Vec<Tensor<f32>>>,
    /// Activation function name (for example `"relu"`).
    pub activation: String,
}
/// Parameters of one layer-normalisation layer.
#[derive(Clone, Debug)]
pub struct LayerNormInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Normalised shape, equal to the shape of `weight`.
    pub normalized_shape: Vec<usize>,
    /// Scale parameter.
    pub weight: Tensor<f32>,
    /// Optional shift parameter.
    pub bias: Option<Tensor<f32>>,
    /// Numerical-stability epsilon; not stored in weights, so a default is used.
    pub eps: f32,
}
/// Position embedding table.
#[derive(Clone, Debug)]
pub struct PositionEmbeddingInfo {
    /// Number of positions (first axis of `weights`).
    pub max_length: usize,
    /// Embedding width (second axis of `weights`).
    pub embed_dim: usize,
    /// The embedding table itself.
    pub weights: Tensor<f32>,
}
/// Token (vocabulary) embedding table.
#[derive(Clone, Debug)]
pub struct TokenEmbeddingInfo {
    /// Vocabulary size (first axis of `weights`).
    pub vocab_size: usize,
    /// Embedding width (second axis of `weights`).
    pub embed_dim: usize,
    /// The embedding table itself.
    pub weights: Tensor<f32>,
}
/// Convolutional-network building blocks grouped by `ComplexArchitectureParser::parse_cnn`.
#[derive(Clone, Debug)]
pub struct CNNComponents {
    /// Convolution layers that were recognised and parsed.
    pub conv_layers: Vec<ConvLayerInfo>,
    /// Pooling layers that were recognised and parsed.
    pub pool_layers: Vec<PoolLayerInfo>,
    /// Batch-normalisation layers that were recognised and parsed.
    pub batch_norms: Vec<BatchNormInfo>,
    /// Final linear classifier, if a matching layer was found.
    pub classifier: Option<ClassifierInfo>,
}
/// Weights and geometry of one 2-D convolution layer.
///
/// NOTE(review): the channel and kernel sizes are read from the weight shape, presumably
/// laid out as `[out_channels, in_channels / groups, kernel_h, kernel_w]` (PyTorch
/// convention). Stride, padding, dilation and groups are not stored in the weights and
/// hold defaults.
#[derive(Clone, Debug)]
pub struct ConvLayerInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Input channel count as read from the weight shape.
    pub in_channels: usize,
    /// Output channel count as read from the weight shape.
    pub out_channels: usize,
    /// Kernel height and width.
    pub kernel_size: (usize, usize),
    /// Stride (height, width).
    pub stride: (usize, usize),
    /// Zero padding (height, width).
    pub padding: (usize, usize),
    /// Dilation (height, width).
    pub dilation: (usize, usize),
    /// Number of channel groups.
    pub groups: usize,
    /// Convolution kernel tensor.
    pub weights: Tensor<f32>,
    /// Optional per-output-channel bias.
    pub bias: Option<Tensor<f32>>,
}
/// Configuration of one pooling layer (pooling layers carry no weights).
#[derive(Clone, Debug)]
pub struct PoolLayerInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Which pooling operation this is.
    pub pool_type: PoolType,
    /// Pooling window (height, width).
    pub kernel_size: (usize, usize),
    /// Optional stride (height, width).
    pub stride: Option<(usize, usize)>,
    /// Optional zero padding (height, width).
    pub padding: Option<(usize, usize)>,
}
/// Kind of pooling operation recorded in a `PoolLayerInfo`.
///
/// Derives `PartialEq`/`Eq` so pool kinds can be compared directly (e.g. in assertions);
/// all payloads are plain `Vec<usize>` data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PoolType {
    /// Fixed-window max pooling.
    MaxPool,
    /// Fixed-window average pooling.
    AvgPool,
    /// Adaptive max pooling; the payload is presumably the target output size
    /// (the parser uses `[1, 1]`) — confirm against consumers.
    AdaptiveMaxPool(Vec<usize>),
    /// Adaptive average pooling; the payload is presumably the target output size
    /// (the parser uses `[1, 1]`) — confirm against consumers.
    AdaptiveAvgPool(Vec<usize>),
}
/// Parameters and running statistics of one batch-normalisation layer.
#[derive(Clone, Debug)]
pub struct BatchNormInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Number of normalised features (first axis of `weight`).
    pub num_features: usize,
    /// Scale parameter.
    pub weight: Tensor<f32>,
    /// Shift parameter.
    pub bias: Tensor<f32>,
    /// Running mean buffer.
    pub running_mean: Tensor<f32>,
    /// Running variance buffer.
    pub running_var: Tensor<f32>,
    /// Running-statistics momentum; not stored in weights, so a default is used.
    pub momentum: f32,
    /// Numerical-stability epsilon; not stored in weights, so a default is used.
    pub eps: f32,
}
/// Final linear classification layer of a CNN.
#[derive(Clone, Debug)]
pub struct ClassifierInfo {
    /// Name of the layer this information was extracted from.
    pub name: String,
    /// Input feature count (second axis of `weights`).
    pub in_features: usize,
    /// Number of output classes (first axis of `weights`).
    pub num_classes: usize,
    /// Linear weight matrix.
    pub weights: Tensor<f32>,
    /// Optional per-class bias.
    pub bias: Option<Tensor<f32>>,
}
pub struct ComplexArchitectureParser;
impl ComplexArchitectureParser {
pub fn parse_transformer(
layers: &HashMap<String, SimpleLayerInfo>
) -> Result<TransformerComponents, RusTorchError> {
let mut attention_layers = Vec::new();
let mut ffn_layers = Vec::new();
let mut layer_norms = Vec::new();
let mut position_embeddings = None;
let mut token_embeddings = None;
for (layer_name, layer_info) in layers {
if Self::is_attention_layer(layer_name) {
if let Ok(attention) = Self::parse_attention_layer(layer_name, layer_info) {
attention_layers.push(attention);
}
} else if Self::is_ffn_layer(layer_name) {
if let Ok(ffn) = Self::parse_ffn_layer(layer_name, layer_info) {
ffn_layers.push(ffn);
}
} else if Self::is_layer_norm(layer_name) {
if let Ok(ln) = Self::parse_layer_norm(layer_name, layer_info) {
layer_norms.push(ln);
}
} else if Self::is_position_embedding(layer_name) {
position_embeddings = Some(Self::parse_position_embedding(layer_name, layer_info)?);
} else if Self::is_token_embedding(layer_name) {
token_embeddings = Some(Self::parse_token_embedding(layer_name, layer_info)?);
}
}
Ok(TransformerComponents {
attention_layers,
ffn_layers,
layer_norms,
position_embeddings,
token_embeddings,
})
}
pub fn parse_cnn(
layers: &HashMap<String, SimpleLayerInfo>
) -> Result<CNNComponents, RusTorchError> {
let mut conv_layers = Vec::new();
let mut pool_layers = Vec::new();
let mut batch_norms = Vec::new();
let mut classifier = None;
for (layer_name, layer_info) in layers {
if Self::is_conv_layer(layer_name, layer_info) {
if let Ok(conv) = Self::parse_conv_layer(layer_name, layer_info) {
conv_layers.push(conv);
}
} else if Self::is_pool_layer(layer_name) {
if let Ok(pool) = Self::parse_pool_layer(layer_name, layer_info) {
pool_layers.push(pool);
}
} else if Self::is_batch_norm(layer_name, layer_info) {
if let Ok(bn) = Self::parse_batch_norm(layer_name, layer_info) {
batch_norms.push(bn);
}
} else if Self::is_classifier(layer_name, layer_info) {
classifier = Some(Self::parse_classifier(layer_name, layer_info)?);
}
}
Ok(CNNComponents {
conv_layers,
pool_layers,
batch_norms,
classifier,
})
}
fn is_attention_layer(layer_name: &str) -> bool {
layer_name.contains("attention") ||
layer_name.contains("attn") ||
layer_name.contains("self_attn") ||
layer_name.contains("multi_head")
}
fn parse_attention_layer(
layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<MultiheadAttentionInfo, RusTorchError> {
let q_proj = layer_info.tensors.get("q_proj.weight")
.or_else(|| layer_info.tensors.get("query.weight"))
.or_else(|| layer_info.tensors.get("wq.weight"))
.ok_or_else(|| RusTorchError::import_error("query projection".to_string()))?;
let k_proj = layer_info.tensors.get("k_proj.weight")
.or_else(|| layer_info.tensors.get("key.weight"))
.or_else(|| layer_info.tensors.get("wk.weight"))
.ok_or_else(|| RusTorchError::import_error("key projection".to_string()))?;
let v_proj = layer_info.tensors.get("v_proj.weight")
.or_else(|| layer_info.tensors.get("value.weight"))
.or_else(|| layer_info.tensors.get("wv.weight"))
.ok_or_else(|| RusTorchError::import_error("value projection".to_string()))?;
let out_proj = layer_info.tensors.get("out_proj.weight")
.or_else(|| layer_info.tensors.get("output.weight"))
.or_else(|| layer_info.tensors.get("wo.weight"))
.ok_or_else(|| RusTorchError::import_error("output projection".to_string()))?;
let embed_dim = q_proj.shape()[1];
let total_head_dim = q_proj.shape()[0];
let num_heads = Self::infer_num_heads(embed_dim, total_head_dim);
let head_dim = total_head_dim / num_heads;
Ok(MultiheadAttentionInfo {
name: layer_name.to_string(),
num_heads,
embed_dim,
head_dim,
dropout: 0.1, bias: true, batch_first: true, query_weights: q_proj.clone(),
key_weights: k_proj.clone(),
value_weights: v_proj.clone(),
output_weights: out_proj.clone(),
biases: None, })
}
fn infer_num_heads(embed_dim: usize, total_head_dim: usize) -> usize {
if total_head_dim % embed_dim == 0 {
total_head_dim / embed_dim
} else if embed_dim % 64 == 0 && total_head_dim % 64 == 0 {
embed_dim / 64 } else if embed_dim % 32 == 0 && total_head_dim % 32 == 0 {
embed_dim / 32 } else {
8 }
}
fn is_ffn_layer(layer_name: &str) -> bool {
layer_name.contains("ffn") ||
layer_name.contains("feed_forward") ||
layer_name.contains("mlp") ||
(layer_name.contains("linear") && (layer_name.contains("1") || layer_name.contains("2")))
}
fn parse_ffn_layer(
layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<FeedForwardInfo, RusTorchError> {
let linear1 = layer_info.tensors.get("linear1.weight")
.or_else(|| layer_info.tensors.get("fc1.weight"))
.or_else(|| layer_info.tensors.get("w1.weight"))
.ok_or_else(|| RusTorchError::import_error("first linear layer".to_string()))?;
let linear2 = layer_info.tensors.get("linear2.weight")
.or_else(|| layer_info.tensors.get("fc2.weight"))
.or_else(|| layer_info.tensors.get("w2.weight"))
.ok_or_else(|| RusTorchError::import_error("second linear layer".to_string()))?;
let input_dim = linear1.shape()[1];
let hidden_dim = linear1.shape()[0];
Ok(FeedForwardInfo {
name: layer_name.to_string(),
input_dim,
hidden_dim,
linear1_weights: linear1.clone(),
linear2_weights: linear2.clone(),
biases: None,
activation: "relu".to_string(), })
}
fn is_layer_norm(layer_name: &str) -> bool {
layer_name.contains("layer_norm") ||
layer_name.contains("ln") ||
layer_name.contains("norm")
}
fn parse_layer_norm(
layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<LayerNormInfo, RusTorchError> {
let weight = layer_info.tensors.get("weight")
.ok_or_else(|| RusTorchError::import_error("layer norm weight".to_string()))?;
let bias = layer_info.tensors.get("bias");
let normalized_shape = weight.shape().to_vec();
Ok(LayerNormInfo {
name: layer_name.to_string(),
normalized_shape,
weight: weight.clone(),
bias: bias.cloned(),
eps: 1e-5, })
}
fn is_position_embedding(layer_name: &str) -> bool {
layer_name.contains("pos_emb") ||
layer_name.contains("position") ||
layer_name.contains("positional")
}
fn parse_position_embedding(
_layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<PositionEmbeddingInfo, RusTorchError> {
let weight = layer_info.tensors.get("weight")
.ok_or_else(|| RusTorchError::import_error("position embedding weight".to_string()))?;
let shape = weight.shape();
let max_length = shape[0];
let embed_dim = shape[1];
Ok(PositionEmbeddingInfo {
max_length,
embed_dim,
weights: weight.clone(),
})
}
fn is_token_embedding(layer_name: &str) -> bool {
layer_name.contains("token_emb") ||
layer_name.contains("word_emb") ||
layer_name.contains("embedding") ||
layer_name == "embeddings"
}
fn parse_token_embedding(
_layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<TokenEmbeddingInfo, RusTorchError> {
let weight = layer_info.tensors.get("weight")
.ok_or_else(|| RusTorchError::import_error("token embedding weight".to_string()))?;
let shape = weight.shape();
let vocab_size = shape[0];
let embed_dim = shape[1];
Ok(TokenEmbeddingInfo {
vocab_size,
embed_dim,
weights: weight.clone(),
})
}
fn is_conv_layer(layer_name: &str, layer_info: &SimpleLayerInfo) -> bool {
layer_name.contains("conv") && layer_info.layer_type == "Conv2d"
}
fn parse_conv_layer(
layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<ConvLayerInfo, RusTorchError> {
let weight = layer_info.tensors.get("weight")
.ok_or_else(|| RusTorchError::import_error("conv weight".to_string()))?;
let shape = weight.shape();
let out_channels = shape[0];
let in_channels = shape[1];
let kernel_h = shape[2];
let kernel_w = shape[3];
let bias = layer_info.tensors.get("bias");
Ok(ConvLayerInfo {
name: layer_name.to_string(),
in_channels,
out_channels,
kernel_size: (kernel_h, kernel_w),
stride: (1, 1), padding: (0, 0), dilation: (1, 1), groups: 1, weights: weight.clone(),
bias: bias.cloned(),
})
}
fn is_pool_layer(layer_name: &str) -> bool {
layer_name.contains("pool") || layer_name.contains("avgpool") || layer_name.contains("maxpool")
}
fn parse_pool_layer(
layer_name: &str,
_layer_info: &SimpleLayerInfo
) -> Result<PoolLayerInfo, RusTorchError> {
let pool_type = if layer_name.contains("max") {
PoolType::MaxPool
} else if layer_name.contains("avg") {
PoolType::AvgPool
} else if layer_name.contains("adaptive") {
if layer_name.contains("max") {
PoolType::AdaptiveMaxPool(vec![1, 1]) } else {
PoolType::AdaptiveAvgPool(vec![1, 1]) }
} else {
PoolType::MaxPool };
Ok(PoolLayerInfo {
name: layer_name.to_string(),
pool_type,
kernel_size: (2, 2), stride: Some((2, 2)), padding: Some((0, 0)), })
}
fn is_batch_norm(layer_name: &str, layer_info: &SimpleLayerInfo) -> bool {
(layer_name.contains("bn") || layer_name.contains("batch_norm")) &&
layer_info.layer_type == "BatchNorm2d"
}
fn parse_batch_norm(
layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<BatchNormInfo, RusTorchError> {
let weight = layer_info.tensors.get("weight")
.ok_or_else(|| RusTorchError::import_error("batch norm weight".to_string()))?;
let bias = layer_info.tensors.get("bias")
.ok_or_else(|| RusTorchError::import_error("batch norm bias".to_string()))?;
let running_mean = layer_info.tensors.get("running_mean")
.ok_or_else(|| RusTorchError::import_error("batch norm running mean".to_string()))?;
let running_var = layer_info.tensors.get("running_var")
.ok_or_else(|| RusTorchError::import_error("batch norm running var".to_string()))?;
let num_features = weight.shape()[0];
Ok(BatchNormInfo {
name: layer_name.to_string(),
num_features,
weight: weight.clone(),
bias: bias.clone(),
running_mean: running_mean.clone(),
running_var: running_var.clone(),
momentum: 0.1, eps: 1e-5, })
}
fn is_classifier(layer_name: &str, layer_info: &SimpleLayerInfo) -> bool {
(layer_name.contains("classifier") || layer_name.contains("fc") || layer_name == "head") &&
layer_info.layer_type == "Linear"
}
fn parse_classifier(
layer_name: &str,
layer_info: &SimpleLayerInfo
) -> Result<ClassifierInfo, RusTorchError> {
let weight = layer_info.tensors.get("weight")
.ok_or_else(|| RusTorchError::import_error("classifier weight".to_string()))?;
let shape = weight.shape();
let num_classes = shape[0];
let in_features = shape[1];
let bias = layer_info.tensors.get("bias");
Ok(ClassifierInfo {
name: layer_name.to_string(),
in_features,
num_classes,
weights: weight.clone(),
bias: bias.cloned(),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    type Parser = ComplexArchitectureParser;

    /// Builds a tensor-less layer description with the given module type.
    fn layer_of_type(layer_type: &str) -> SimpleLayerInfo {
        SimpleLayerInfo {
            name: "features.0".to_string(),
            layer_type: layer_type.to_string(),
            parameter_shapes: HashMap::new(),
            num_parameters: 0,
            tensors: HashMap::new(),
        }
    }

    #[test]
    fn test_attention_layer_detection() {
        for name in ["encoder.layer.0.attention", "transformer.h.0.attn", "layers.0.self_attn"] {
            assert!(Parser::is_attention_layer(name), "expected attention layer: {}", name);
        }
        assert!(!Parser::is_attention_layer("layers.0.linear"));
    }

    #[test]
    fn test_ffn_layer_detection() {
        for name in ["encoder.layer.0.ffn", "transformer.h.0.mlp", "layers.0.feed_forward"] {
            assert!(Parser::is_ffn_layer(name), "expected feed-forward layer: {}", name);
        }
        assert!(!Parser::is_ffn_layer("layers.0.attention"));
    }

    #[test]
    fn test_conv_layer_detection() {
        let conv = layer_of_type("Conv2d");
        assert!(Parser::is_conv_layer("features.0.conv", &conv));

        let linear = layer_of_type("Linear");
        assert!(!Parser::is_conv_layer("features.0.conv", &linear));
    }

    #[test]
    fn test_num_heads_inference() {
        // (embed_dim, total_head_dim, expected heads) for standard 64-wide heads.
        let cases = [(512, 512, 8), (768, 768, 12), (1024, 1024, 16)];
        for &(embed_dim, total_head_dim, expected) in &cases {
            assert_eq!(
                Parser::infer_num_heads(embed_dim, total_head_dim),
                expected,
                "embed_dim={} total_head_dim={}",
                embed_dim,
                total_head_dim
            );
        }
    }
}