1use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::Backend;
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27use crate::config::{QuantConfig, QuantMethod};
28use crate::dense::DenseLinear;
29use crate::gptq::GptqLinear;
30use crate::loader::WeightLoader;
31use crate::traits::Linear;
32
33struct Shard {
35 mmap: Mmap,
36 names: Vec<String>,
40}
41
42impl Shard {
43 fn open(path: &Path) -> Result<Self> {
44 let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
45 let mmap = unsafe {
46 Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
47 };
48 let st = SafeTensors::deserialize(&mmap)
51 .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
52 let names = st.names().iter().map(|s| s.to_string()).collect();
53 Ok(Self { mmap, names })
54 }
55
56 fn get<'a>(&'a self, name: &str) -> Result<safetensors::tensor::TensorView<'a>> {
57 let st = SafeTensors::deserialize(&self.mmap)
58 .map_err(|e| FerrumError::model(format!("reparse: {e}")))?;
59 st.tensor(name)
60 .map_err(|e| FerrumError::model(format!("tensor '{name}': {e}")))
61 }
62}
63
64pub struct NativeSafetensorsLoader<B: Backend> {
67 shards: Vec<Shard>,
69 index: HashMap<String, usize>,
71 quant_config: Option<QuantConfig>,
73 _m: std::marker::PhantomData<B>,
74}
75
76impl<B: Backend> NativeSafetensorsLoader<B> {
77 pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
79 let dir = model_dir.as_ref();
80
81 let shard_paths = if dir.join("model.safetensors").exists() {
82 vec![dir.join("model.safetensors")]
83 } else if dir.join("model.safetensors.index.json").exists() {
84 Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
85 .into_iter()
86 .map(|name| dir.join(name))
87 .collect()
88 } else {
89 return Err(FerrumError::model(format!(
90 "no safetensors files in {dir:?}"
91 )));
92 };
93
94 let mut shards = Vec::with_capacity(shard_paths.len());
95 let mut index: HashMap<String, usize> = HashMap::new();
96 for (i, p) in shard_paths.iter().enumerate() {
97 let shard = Shard::open(p)?;
98 for name in &shard.names {
99 index.insert(name.clone(), i);
100 }
101 shards.push(shard);
102 }
103
104 let quant_config = load_quantize_config(dir)?;
105
106 Ok(Self {
107 shards,
108 index,
109 quant_config,
110 _m: std::marker::PhantomData,
111 })
112 }
113
114 fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
115 let data = std::fs::read_to_string(index_path)
116 .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
117 let json: serde_json::Value = serde_json::from_str(&data)
118 .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
119 let weight_map = json
120 .get("weight_map")
121 .and_then(|v| v.as_object())
122 .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
123 let mut files: Vec<String> = weight_map
124 .values()
125 .filter_map(|v| v.as_str().map(|s| s.to_string()))
126 .collect();
127 files.sort();
128 files.dedup();
129 Ok(files)
130 }
131
132 fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
134 let shard_idx = *self
135 .index
136 .get(name)
137 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
138 let view = self.shards[shard_idx].get(name)?;
139 let shape = view.shape().to_vec();
140 let data = dtype_to_f32(view.dtype(), view.data())?;
141 Ok((data, shape))
142 }
143
144 fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
146 let shard_idx = *self
147 .index
148 .get(name)
149 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
150 let view = self.shards[shard_idx].get(name)?;
151 let shape = view.shape().to_vec();
152 if view.dtype() != Dtype::I32 {
153 return Err(FerrumError::model(format!(
154 "'{name}': expected I32, got {:?}",
155 view.dtype()
156 )));
157 }
158 let bytes = view.data();
159 debug_assert_eq!(bytes.len() % 4, 0);
160 let mut out = vec![0i32; bytes.len() / 4];
161 out.as_mut_slice()
162 .iter_mut()
163 .zip(bytes.chunks_exact(4))
164 .for_each(|(d, chunk)| {
165 *d = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
166 });
167 Ok((out, shape))
168 }
169
170 fn has(&self, name: &str) -> bool {
171 self.index.contains_key(name)
172 }
173}
174
175impl<B: Backend> WeightLoader<B> for NativeSafetensorsLoader<B> {
176 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
177 let (data, _) = self.read_f32(name)?;
178 Ok(B::from_slice(&data))
179 }
180
181 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
182 let qw_key = format!("{name}.qweight");
184 if self.has(&qw_key) {
185 return self.load_gptq_linear(name);
186 }
187 if name.ends_with("qkv_proj") {
191 let prefix = &name[..name.len() - "qkv_proj".len()];
192 let parts = [
193 format!("{prefix}q_proj"),
194 format!("{prefix}k_proj"),
195 format!("{prefix}v_proj"),
196 ];
197 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
198 return self.load_gptq_linear_fused(&parts);
199 }
200 }
201 if name.ends_with("gate_up_proj") {
202 let prefix = &name[..name.len() - "gate_up_proj".len()];
203 let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
204 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
205 return self.load_gptq_linear_fused(&parts);
206 }
207 }
208
209 let direct = format!("{name}.weight");
211 if self.has(&direct) {
212 let (data, shape) = self.read_f32(&direct)?;
213 if shape.len() != 2 {
214 return Err(FerrumError::model(format!(
215 "linear '{name}': expected 2D weight, got {shape:?}"
216 )));
217 }
218 return Ok(Box::new(DenseLinear::<B>::from_rows(
219 &data, shape[0], shape[1],
220 )));
221 }
222
223 if name.ends_with("qkv_proj") {
226 let prefix = &name[..name.len() - "qkv_proj".len()];
227 let parts = [
228 format!("{prefix}q_proj.weight"),
229 format!("{prefix}k_proj.weight"),
230 format!("{prefix}v_proj.weight"),
231 ];
232 if parts.iter().all(|p| self.has(p)) {
233 let (rows, cols, data) = self.cat_rows(&parts)?;
234 return Ok(Box::new(DenseLinear::<B>::from_rows(&data, rows, cols)));
235 }
236 }
237 if name.ends_with("gate_up_proj") {
238 let prefix = &name[..name.len() - "gate_up_proj".len()];
239 let parts = [
240 format!("{prefix}gate_proj.weight"),
241 format!("{prefix}up_proj.weight"),
242 ];
243 if parts.iter().all(|p| self.has(p)) {
244 let (rows, cols, data) = self.cat_rows(&parts)?;
245 return Ok(Box::new(DenseLinear::<B>::from_rows(&data, rows, cols)));
246 }
247 }
248
249 Err(FerrumError::model(format!(
250 "could not load linear '{name}' — no direct `.weight`, no split components"
251 )))
252 }
253
254 fn has_tensor(&self, name: &str) -> bool {
255 self.has(name)
256 }
257
258 fn quant_config(&self) -> Option<&QuantConfig> {
259 self.quant_config.as_ref()
260 }
261}
262
263impl<B: Backend> NativeSafetensorsLoader<B> {
264 fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
269 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
270 FerrumError::model(format!(
271 "'{name}.qweight' present but no quantize_config.json — \
272 can't determine bits/group_size"
273 ))
274 })?;
275 if qcfg.method != QuantMethod::Gptq {
276 return Err(FerrumError::model(format!(
277 "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
278 qcfg.method
279 )));
280 }
281
282 let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
283 let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
284 let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
285 let g_idx = if self.has(&format!("{name}.g_idx")) {
286 Some(self.read_i32(&format!("{name}.g_idx"))?.0)
287 } else {
288 None
289 };
290
291 if qw_shape.len() != 2 {
294 return Err(FerrumError::model(format!(
295 "'{name}.qweight' expected 2D, got {qw_shape:?}"
296 )));
297 }
298 let in_features = qw_shape[0] * 8;
299 let out_features = qw_shape[1];
300 if sc_shape.len() != 2 || sc_shape[1] != out_features {
301 return Err(FerrumError::model(format!(
302 "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
303 )));
304 }
305
306 let mut linear = GptqLinear::<B>::from_raw(
307 &qweight,
308 &scales_f32,
309 &qzeros,
310 g_idx.as_deref(),
311 qcfg.bits,
312 qcfg.group_size,
313 in_features,
314 out_features,
315 )?;
316
317 let bias_key = format!("{name}.bias");
319 if self.has(&bias_key) {
320 let (bias, bias_shape) = self.read_f32(&bias_key)?;
321 if bias_shape != [out_features] {
322 return Err(FerrumError::model(format!(
323 "'{bias_key}' {bias_shape:?} != [{out_features}]"
324 )));
325 }
326 linear = linear.with_bias(&bias);
327 }
328 Ok(Box::new(linear))
329 }
330
331 fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
344 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
345 FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
346 })?;
347 if qcfg.method != QuantMethod::Gptq {
348 return Err(FerrumError::model(format!(
349 "GPTQ fusion but quant_method={:?}",
350 qcfg.method
351 )));
352 }
353
354 let mut qw_acc: Vec<i32> = Vec::new();
355 let mut sc_acc: Vec<f32> = Vec::new();
356 let mut qz_acc: Vec<i32> = Vec::new();
357 let mut qw_rows = 0usize;
358 let mut sc_rows = 0usize;
359 let mut qz_rows = 0usize;
360 let mut total_n = 0usize;
361 let mut total_n_scales = 0usize;
362 let mut total_n_zeros = 0usize;
363 let mut g_idx: Option<Vec<i32>> = None;
364 let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
367 let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
368
369 for p in parts {
370 let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
371 let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
372 let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
373 if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
374 return Err(FerrumError::model(format!(
375 "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
376 )));
377 }
378 if qw_rows == 0 {
379 qw_rows = qw_sh[0];
380 sc_rows = sc_sh[0];
381 qz_rows = qz_sh[0];
382 } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
383 return Err(FerrumError::model(format!(
384 "GPTQ fusion row mismatch on '{p}'"
385 )));
386 }
387 total_n += qw_sh[1];
388 total_n_scales += sc_sh[1];
389 total_n_zeros += qz_sh[1];
390 qw_parts.push((qw, qw_sh[0], qw_sh[1]));
391 sc_parts.push((sc, sc_sh[0], sc_sh[1]));
392 qz_parts.push((qz, qz_sh[0], qz_sh[1]));
393
394 if g_idx.is_none() {
396 if self.has(&format!("{p}.g_idx")) {
397 g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
398 }
399 }
400 }
401
402 qw_acc.reserve(qw_rows * total_n);
404 for r in 0..qw_rows {
405 for (part, _rows, cols) in &qw_parts {
406 qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
407 }
408 }
409 sc_acc.reserve(sc_rows * total_n_scales);
410 for r in 0..sc_rows {
411 for (part, _rows, cols) in &sc_parts {
412 sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
413 }
414 }
415 qz_acc.reserve(qz_rows * total_n_zeros);
416 for r in 0..qz_rows {
417 for (part, _rows, cols) in &qz_parts {
418 qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
419 }
420 }
421
422 let in_features = qw_rows * 8;
423 let out_features = total_n;
424 let mut linear = GptqLinear::<B>::from_raw(
425 &qw_acc,
426 &sc_acc,
427 &qz_acc,
428 g_idx.as_deref(),
429 qcfg.bits,
430 qcfg.group_size,
431 in_features,
432 out_features,
433 )?;
434
435 let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
438 let any = bias_keys.iter().any(|k| self.has(k));
439 let all = bias_keys.iter().all(|k| self.has(k));
440 if any && !all {
441 return Err(FerrumError::model(
442 "GPTQ fusion: inconsistent bias presence across parts".to_string(),
443 ));
444 }
445 if all {
446 let mut fused: Vec<f32> = Vec::with_capacity(out_features);
447 for k in &bias_keys {
448 let (b, _) = self.read_f32(k)?;
449 fused.extend_from_slice(&b);
450 }
451 if fused.len() != out_features {
452 return Err(FerrumError::model(format!(
453 "GPTQ fusion bias length {} != out_features {out_features}",
454 fused.len()
455 )));
456 }
457 linear = linear.with_bias(&fused);
458 }
459 Ok(Box::new(linear))
460 }
461
462 fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
464 let mut total_rows = 0usize;
465 let mut cols = 0usize;
466 let mut out: Vec<f32> = Vec::new();
467 for n in names {
468 let (data, shape) = self.read_f32(n)?;
469 if shape.len() != 2 {
470 return Err(FerrumError::model(format!(
471 "cat_rows: '{n}' is {shape:?}, need 2D"
472 )));
473 }
474 if cols == 0 {
475 cols = shape[1];
476 } else if cols != shape[1] {
477 return Err(FerrumError::model(format!(
478 "cat_rows: col mismatch {cols} vs {}",
479 shape[1]
480 )));
481 }
482 total_rows += shape[0];
483 out.extend_from_slice(&data);
484 }
485 Ok((total_rows, cols, out))
486 }
487}
488
489fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
490 match dtype {
491 Dtype::F32 => {
492 debug_assert_eq!(raw.len() % 4, 0);
493 let n = raw.len() / 4;
494 let mut out = vec![0.0f32; n];
495 for i in 0..n {
496 let bytes = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
497 out[i] = f32::from_le_bytes(bytes);
498 }
499 Ok(out)
500 }
501 Dtype::F16 => {
502 debug_assert_eq!(raw.len() % 2, 0);
503 let n = raw.len() / 2;
504 let mut out = vec![0.0f32; n];
505 for i in 0..n {
506 let bytes = [raw[i * 2], raw[i * 2 + 1]];
507 out[i] = f16::from_le_bytes(bytes).to_f32();
508 }
509 Ok(out)
510 }
511 Dtype::BF16 => {
512 debug_assert_eq!(raw.len() % 2, 0);
513 let n = raw.len() / 2;
514 let mut out = vec![0.0f32; n];
515 for i in 0..n {
516 let bytes = [raw[i * 2], raw[i * 2 + 1]];
517 out[i] = bf16::from_le_bytes(bytes).to_f32();
518 }
519 Ok(out)
520 }
521 other => Err(FerrumError::model(format!(
522 "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
523 use a format-specific loader (GPTQ / AWQ / GGUF)",
524 ))),
525 }
526}
527
528fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
529 let p = dir.join("quantize_config.json");
531 if p.exists() {
532 let data =
533 std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
534 let qc: QuantConfig = serde_json::from_str(&data)
535 .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
536 return Ok(Some(qc));
537 }
538 let cfg = dir.join("config.json");
541 if cfg.exists() {
542 let data = std::fs::read_to_string(&cfg)
543 .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
544 let root: serde_json::Value = serde_json::from_str(&data)
545 .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
546 if let Some(qc_val) = root.get("quantization_config") {
547 let method = qc_val
549 .get("quant_method")
550 .and_then(|v| v.as_str())
551 .unwrap_or("none");
552 let method = match method.to_lowercase().as_str() {
553 "gptq" => QuantMethod::Gptq,
554 "awq" => QuantMethod::Awq,
555 "gguf" => QuantMethod::Gguf,
556 _ => QuantMethod::None,
557 };
558 let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
559 let group_size = qc_val
560 .get("group_size")
561 .and_then(|v| v.as_i64())
562 .unwrap_or(128)
563 .max(0) as usize;
564 let desc_act = qc_val
565 .get("desc_act")
566 .and_then(|v| v.as_bool())
567 .unwrap_or(false);
568 let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
569 if method != QuantMethod::None {
570 return Ok(Some(QuantConfig {
571 method,
572 bits,
573 group_size,
574 desc_act,
575 sym,
576 }));
577 }
578 }
579 }
580 Ok(None)
581}