1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
/// Classify a raw error message into a specific `ImportError` variant (GH-129)
///
/// The message is lower-cased and searched for known substrings. Categories
/// are tried in a fixed order and the first match wins:
///
/// 1. not found (`404`, "not found", "does not exist", "no such")
///    → `ImportError::NotFound` with `status: 404`
/// 2. authentication (`401`, `403`, "unauthorized", "forbidden", "gated",
///    "access denied") → `ImportError::AuthRequired`
/// 3. rate limiting (`429`, "rate limit", "too many requests")
///    → `ImportError::RateLimited`
/// 4. anything else → `ImportError::DownloadFailed` carrying the original
///    (not lower-cased) message as `reason`
///
/// For the rate-limit case, `retry_after` is the first whitespace-separated
/// token at or after the first occurrence of "retry" that parses as a `u64`;
/// it is `None` when "retry" is absent or no such token exists.
#[cfg(feature = "hf-hub-integration")]
pub fn parse_import_error(error_msg: &str, resource: &str) -> ImportError {
    const NOT_FOUND_MARKERS: &[&str] = &["404", "not found", "does not exist", "no such"];
    const AUTH_MARKERS: &[&str] = &[
        "401",
        "403",
        "unauthorized",
        "forbidden",
        "gated",
        "access denied",
    ];
    const RATE_LIMIT_MARKERS: &[&str] = &["429", "rate limit", "too many requests"];

    let lowered = error_msg.to_lowercase();
    // True when at least one of the given substrings occurs in the message.
    let mentions_any = |markers: &[&str]| markers.iter().any(|marker| lowered.contains(*marker));

    if mentions_any(NOT_FOUND_MARKERS) {
        ImportError::NotFound {
            resource: resource.to_string(),
            status: 404,
        }
    } else if mentions_any(AUTH_MARKERS) {
        ImportError::AuthRequired {
            resource: resource.to_string(),
        }
    } else if mentions_any(RATE_LIMIT_MARKERS) {
        // Best-effort extraction of a retry delay from the text following "retry".
        let retry_after = lowered.find("retry").and_then(|at| {
            lowered[at..]
                .split_whitespace()
                .find_map(|token| token.parse::<u64>().ok())
        });
        ImportError::RateLimited { retry_after }
    } else {
        // Nothing recognisable: report a generic download failure.
        ImportError::DownloadFailed {
            source: resource.to_string(),
            reason: error_msg.to_string(),
        }
    }
}
// ============================================================================
// GH-127: Sharded Model Support
// ============================================================================
/// Parsed sharded model index (model.safetensors.index.json)
///
/// `HuggingFace` uses this format for large models split across multiple shards.
/// Example: Llama-2-7b has 2 shards, Llama-2-70b has 15 shards.
#[derive(Debug, Clone)]
pub struct ShardedIndex {
    /// Map of tensor name → shard filename
    weight_map: std::collections::HashMap<String, String>,
    /// Optional total size in bytes
    total_size: Option<u64>,
}

impl ShardedIndex {
    /// Parse a sharded index from JSON string
    ///
    /// This is a deliberately minimal parser (no serde dependency). It locates
    /// the `"weight_map"` key, takes the brace-delimited object that follows,
    /// and splits it on `,` and on the first `:` of each entry. Consequently
    /// tensor names or shard filenames containing `,`, `:`, `{` or `}` and
    /// JSON escape sequences are not handled.
    ///
    /// # Example JSON format
    /// ```json
    /// {
    ///   "metadata": {"total_size": 14000000000},
    ///   "weight_map": {
    ///     "model.encoder.weight": "model-00001-of-00002.safetensors",
    ///     "model.decoder.weight": "model-00002-of-00002.safetensors"
    ///   }
    /// }
    /// ```
    ///
    /// # Errors
    /// Returns `AprenderError::FormatError` when the input is not wrapped in
    /// `{ ... }`, when the `"weight_map"` key is missing, when no `{` follows
    /// the key, or when the weight_map object is never closed.
    pub fn parse(json: &str) -> Result<Self> {
        const WEIGHT_MAP_KEY: &str = "\"weight_map\"";

        let json = json.trim();
        if !json.starts_with('{') || !json.ends_with('}') {
            return Err(AprenderError::FormatError {
                message: "Invalid JSON: expected object".to_string(),
            });
        }

        // Find weight_map section
        let key_pos = json
            .find(WEIGHT_MAP_KEY)
            .ok_or_else(|| AprenderError::FormatError {
                message: "Missing 'weight_map' key in index.json".to_string(),
            })?;

        // Locate the opening brace of the weight_map object
        let after_key = &json[key_pos + WEIGHT_MAP_KEY.len()..];
        let obj_start = after_key
            .find('{')
            .ok_or_else(|| AprenderError::FormatError {
                message: "Invalid weight_map: expected object".to_string(),
            })?;
        let obj_content = &after_key[obj_start..];

        // Walk forward to the brace that closes the weight_map object.
        // `obj_content` starts with '{', so `depth` is at least 1 whenever a
        // '}' is seen and the decrement cannot underflow.
        let mut depth = 0usize;
        let mut obj_end = None;
        for (i, c) in obj_content.char_indices() {
            match c {
                '{' => depth += 1,
                '}' => {
                    depth -= 1;
                    if depth == 0 {
                        obj_end = Some(i);
                        break;
                    }
                }
                _ => {}
            }
        }
        // Unbalanced braces used to fall through with `obj_end == 0` and panic
        // on the slice below; report them as a format error instead.
        let obj_end = obj_end.ok_or_else(|| AprenderError::FormatError {
            message: "Invalid weight_map: unterminated object".to_string(),
        })?;
        let inner = &obj_content[1..obj_end];

        // Parse key-value pairs: "tensor_name": "shard_file"
        // Entries without a ':' or with an empty key/value are skipped; on
        // duplicate keys the last entry wins.
        let weight_map: std::collections::HashMap<String, String> = inner
            .split(',')
            .filter_map(|pair| pair.trim().split_once(':'))
            .map(|(key, val)| (key.trim().trim_matches('"'), val.trim().trim_matches('"')))
            .filter(|(key, val)| !key.is_empty() && !val.is_empty())
            .map(|(key, val)| (key.to_string(), val.to_string()))
            .collect();

        Ok(Self {
            weight_map,
            total_size: Self::parse_total_size(json),
        })
    }

    /// Extract the optional `"total_size"` integer from the raw JSON text.
    ///
    /// Returns `None` when the key is absent or is not followed by `:` and a
    /// run of ASCII digits that fits in a `u64`.
    fn parse_total_size(json: &str) -> Option<u64> {
        const TOTAL_SIZE_KEY: &str = "\"total_size\"";

        let pos = json.find(TOTAL_SIZE_KEY)?;
        let after = &json[pos + TOTAL_SIZE_KEY.len()..];
        let colon = after.find(':')?;
        let after_colon = after[colon + 1..].trim_start();
        let end = after_colon
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(after_colon.len());
        after_colon[..end].parse::<u64>().ok()
    }

    /// Number of unique shard files
    #[must_use]
    pub fn shard_count(&self) -> usize {
        let unique: std::collections::HashSet<_> = self.weight_map.values().collect();
        unique.len()
    }

    /// Number of tensors in the index
    #[must_use]
    pub fn tensor_count(&self) -> usize {
        self.weight_map.len()
    }

    /// Total model size in bytes (if available)
    #[must_use]
    pub fn total_size(&self) -> Option<u64> {
        self.total_size
    }

    /// Get the shard file containing a specific tensor
    ///
    /// Returns `None` when the tensor name is not present in the index.
    #[must_use]
    pub fn shard_for_tensor(&self, tensor_name: &str) -> Option<&str> {
        self.weight_map.get(tensor_name).map(String::as_str)
    }

    /// Get all tensor names in a specific shard
    ///
    /// The order of the returned names is unspecified (hash-map iteration order).
    #[must_use]
    pub fn tensors_in_shard(&self, shard_file: &str) -> Vec<&str> {
        self.weight_map
            .iter()
            .filter(|(_, shard)| shard.as_str() == shard_file)
            .map(|(tensor, _)| tensor.as_str())
            .collect()
    }

    /// Get sorted list of shard files
    ///
    /// Each shard filename appears once, in ascending lexicographic order.
    #[must_use]
    pub fn shard_files(&self) -> Vec<&str> {
        // A BTreeSet de-duplicates and yields its elements in sorted order.
        self.weight_map
            .values()
            .map(String::as_str)
            .collect::<std::collections::BTreeSet<_>>()
            .into_iter()
            .collect()
    }
}
/// Locate the shard index file for a model, if one is present
///
/// Builds the path `<dir>/<base_name>.index.json` (for example
/// `model.safetensors.index.json` when `base_name` is `model.safetensors`)
/// and returns it when that path exists on disk; otherwise returns `None`.
#[must_use]
pub fn detect_sharded_model(dir: &std::path::Path, base_name: &str) -> Option<PathBuf> {
    let candidate = dir.join(format!("{base_name}.index.json"));
    Some(candidate).filter(|path| path.exists())
}