// diffai_core/parsers/pytorch.rs
use anyhow::Result;
use serde_json::Value;
use std::fs::File;
use std::io::Read;
use std::path::Path;

7pub fn parse_pytorch_model(file_path: &Path) -> Result<Value> {
9 let file = File::open(file_path)?;
11 let mut reader = std::io::BufReader::new(file);
12 let mut buffer = Vec::new();
13 reader.read_to_end(&mut buffer)?;
14
15 let mut result = serde_json::Map::new();
18 result.insert(
19 "model_type".to_string(),
20 Value::String("pytorch".to_string()),
21 );
22 result.insert("file_size".to_string(), Value::Number(buffer.len().into()));
23 result.insert("format".to_string(), Value::String("pickle".to_string()));
24
25 let model_info = extract_pytorch_model_info(&buffer);
27 for (key, value) in model_info {
28 result.insert(key, value);
29 }
30
31 Ok(Value::Object(result))
32}
33
34fn extract_pytorch_model_info(buffer: &[u8]) -> serde_json::Map<String, Value> {
36 let mut info = serde_json::Map::new();
37
38 let searchable_content = String::from_utf8_lossy(buffer);
43
44 let weight_count = searchable_content.matches("weight").count();
46 let bias_count = searchable_content.matches("bias").count();
47
48 let conv_count = searchable_content.matches("conv").count();
50 let linear_count =
51 searchable_content.matches("linear").count() + searchable_content.matches("fc.").count();
52 let bn_count =
53 searchable_content.matches("bn").count() + searchable_content.matches("batch_norm").count();
54
55 let mut detected_layers = Vec::new();
57 if conv_count > 0 {
58 detected_layers.push(format!("convolution: {conv_count}"));
59 }
60 if linear_count > 0 {
61 detected_layers.push(format!("linear: {linear_count}"));
62 }
63 if bn_count > 0 {
64 detected_layers.push(format!("batch_norm: {bn_count}"));
65 }
66 if weight_count > 0 {
67 detected_layers.push(format!("weight_params: {weight_count}"));
68 }
69 if bias_count > 0 {
70 detected_layers.push(format!("bias_params: {bias_count}"));
71 }
72
73 if !detected_layers.is_empty() {
74 info.insert(
75 "detected_components".to_string(),
76 Value::String(detected_layers.join(", ")),
77 );
78 }
79
80 let layer_count = weight_count.max(bias_count / 2); if layer_count > 0 {
83 info.insert(
84 "estimated_layers".to_string(),
85 Value::Number(layer_count.into()),
86 );
87 }
88
89 let architectures = [
91 ("resnet", "ResNet"),
92 ("vgg", "VGG"),
93 ("densenet", "DenseNet"),
94 ("mobilenet", "MobileNet"),
95 ("efficientnet", "EfficientNet"),
96 ("transformer", "Transformer"),
97 ("bert", "BERT"),
98 ("gpt", "GPT"),
99 ];
100
101 for (pattern, arch_name) in &architectures {
102 if searchable_content.to_lowercase().contains(pattern) {
103 info.insert(
104 "detected_architecture".to_string(),
105 Value::String(arch_name.to_string()),
106 );
107 break;
108 }
109 }
110
111 if searchable_content.contains("optimizer") {
113 info.insert("has_optimizer_state".to_string(), Value::Bool(true));
114 }
115 if searchable_content.contains("epoch") {
116 info.insert("has_training_metadata".to_string(), Value::Bool(true));
117 }
118 if searchable_content.contains("lr") || searchable_content.contains("learning_rate") {
119 info.insert("has_learning_rate".to_string(), Value::Bool(true));
120 }
121
122 info.insert(
124 "binary_size".to_string(),
125 Value::Number(buffer.len().into()),
126 );
127
128 if buffer.len() > 2 {
130 let protocol_byte = buffer[1];
131 if protocol_byte <= 5 {
132 info.insert(
133 "pickle_protocol".to_string(),
134 Value::Number(protocol_byte.into()),
135 );
136 }
137 }
138
139 let structure_hash = calculate_simple_hash(&searchable_content);
141 info.insert(
142 "structure_fingerprint".to_string(),
143 Value::String(format!("{structure_hash:x}")),
144 );
145
146 info
147}
148
/// Produces a deterministic fingerprint of the checkpoint's textual
/// structure by hashing the first 1000 identifier-like characters
/// (alphanumerics and '.') found in `content`.
fn calculate_simple_hash(content: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // Each `matches` hit with a char predicate is a one-character &str
    // slice; cap at 1000 so huge checkpoints don't dominate hashing time.
    let is_structural = |c: char| c.is_alphanumeric() || c == '.';
    let tokens: Vec<&str> = content.matches(is_structural).take(1000).collect();

    let mut state = DefaultHasher::new();
    tokens.hash(&mut state);
    state.finish()
}