1use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::{Backend, SrcDtype};
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
29 match dtype {
30 Dtype::F32 => Ok(SrcDtype::F32),
31 Dtype::F16 => Ok(SrcDtype::F16),
32 Dtype::BF16 => Ok(SrcDtype::BF16),
33 other => Err(FerrumError::model(format!(
34 "dtype {other:?} not supported; Dense path expects F32/F16/BF16"
35 ))),
36 }
37}
38
39use crate::config::{QuantConfig, QuantMethod};
40use crate::dense::DenseLinear;
41use crate::gptq::GptqLinear;
42use crate::loader::WeightLoader;
43use crate::traits::Linear;
44
45struct Shard {
47 mmap: Mmap,
48 names: Vec<String>,
52}
53
54impl Shard {
55 fn open(path: &Path) -> Result<Self> {
56 let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
57 let mmap = unsafe {
58 Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
59 };
60 let st = SafeTensors::deserialize(&mmap)
63 .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
64 let names = st.names().iter().map(|s| s.to_string()).collect();
65 Ok(Self { mmap, names })
66 }
67
68 fn get<'a>(&'a self, name: &str) -> Result<safetensors::tensor::TensorView<'a>> {
69 let st = SafeTensors::deserialize(&self.mmap)
70 .map_err(|e| FerrumError::model(format!("reparse: {e}")))?;
71 st.tensor(name)
72 .map_err(|e| FerrumError::model(format!("tensor '{name}': {e}")))
73 }
74}
75
76pub struct NativeSafetensorsLoader<B: Backend> {
79 shards: Vec<Shard>,
81 index: HashMap<String, usize>,
83 quant_config: Option<QuantConfig>,
85 _m: std::marker::PhantomData<B>,
86}
87
88impl<B: Backend> NativeSafetensorsLoader<B> {
89 pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
91 let dir = model_dir.as_ref();
92
93 let shard_paths = if dir.join("model.safetensors").exists() {
94 vec![dir.join("model.safetensors")]
95 } else if dir.join("model.safetensors.index.json").exists() {
96 Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
97 .into_iter()
98 .map(|name| dir.join(name))
99 .collect()
100 } else {
101 return Err(FerrumError::model(format!(
102 "no safetensors files in {dir:?}"
103 )));
104 };
105
106 let mut shards = Vec::with_capacity(shard_paths.len());
107 let mut index: HashMap<String, usize> = HashMap::new();
108 for (i, p) in shard_paths.iter().enumerate() {
109 let shard = Shard::open(p)?;
110 for name in &shard.names {
111 index.insert(name.clone(), i);
112 }
113 shards.push(shard);
114 }
115
116 let quant_config = load_quantize_config(dir)?;
117
118 Ok(Self {
119 shards,
120 index,
121 quant_config,
122 _m: std::marker::PhantomData,
123 })
124 }
125
126 fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
127 let data = std::fs::read_to_string(index_path)
128 .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
129 let json: serde_json::Value = serde_json::from_str(&data)
130 .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
131 let weight_map = json
132 .get("weight_map")
133 .and_then(|v| v.as_object())
134 .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
135 let mut files: Vec<String> = weight_map
136 .values()
137 .filter_map(|v| v.as_str().map(|s| s.to_string()))
138 .collect();
139 files.sort();
140 files.dedup();
141 Ok(files)
142 }
143
144 fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
146 let shard_idx = *self
147 .index
148 .get(name)
149 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
150 let view = self.shards[shard_idx].get(name)?;
151 let shape = view.shape().to_vec();
152 let data = dtype_to_f32(view.dtype(), view.data())?;
153 Ok((data, shape))
154 }
155
156 fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
160 let shard_idx = *self
161 .index
162 .get(name)
163 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
164 let view = self.shards[shard_idx].get(name)?;
165 let shape = view.shape().to_vec();
166 let dtype = map_src_dtype(view.dtype())?;
167 Ok((view.data(), dtype, shape))
168 }
169
170 fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
174 let mut total_rows = 0usize;
175 let mut cols = 0usize;
176 let mut dtype: Option<SrcDtype> = None;
177 let mut bytes: Vec<u8> = Vec::new();
178 for n in names {
179 let (raw, d, shape) = self.read_bytes_typed(n)?;
180 if shape.len() != 2 {
181 return Err(FerrumError::model(format!(
182 "cat_rows_bytes: '{n}' is {shape:?}, need 2D"
183 )));
184 }
185 match dtype {
186 Some(prev) if prev != d => {
187 return Err(FerrumError::model(format!(
188 "cat_rows_bytes: dtype mismatch on '{n}'"
189 )))
190 }
191 _ => dtype = Some(d),
192 }
193 if cols == 0 {
194 cols = shape[1];
195 } else if cols != shape[1] {
196 return Err(FerrumError::model(format!(
197 "cat_rows_bytes: col mismatch {cols} vs {}",
198 shape[1]
199 )));
200 }
201 total_rows += shape[0];
202 bytes.extend_from_slice(raw);
203 }
204 Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
205 }
206
207 fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
209 let shard_idx = *self
210 .index
211 .get(name)
212 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
213 let view = self.shards[shard_idx].get(name)?;
214 let shape = view.shape().to_vec();
215 if view.dtype() != Dtype::I32 {
216 return Err(FerrumError::model(format!(
217 "'{name}': expected I32, got {:?}",
218 view.dtype()
219 )));
220 }
221 let bytes = view.data();
222 debug_assert_eq!(bytes.len() % 4, 0);
223 let mut out = vec![0i32; bytes.len() / 4];
224 out.as_mut_slice()
225 .iter_mut()
226 .zip(bytes.chunks_exact(4))
227 .for_each(|(d, chunk)| {
228 *d = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
229 });
230 Ok((out, shape))
231 }
232
233 fn has(&self, name: &str) -> bool {
234 self.index.contains_key(name)
235 }
236}
237
238impl<B: Backend> WeightLoader<B> for NativeSafetensorsLoader<B> {
239 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
240 let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
245 Ok(B::from_weight_bytes(raw, src_dtype))
246 }
247
248 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
249 let qw_key = format!("{name}.qweight");
251 if self.has(&qw_key) {
252 return self.load_gptq_linear(name);
253 }
254 if let Some(prefix) = name.strip_suffix("qkv_proj") {
258 let parts = [
259 format!("{prefix}q_proj"),
260 format!("{prefix}k_proj"),
261 format!("{prefix}v_proj"),
262 ];
263 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
264 return self.load_gptq_linear_fused(&parts);
265 }
266 }
267 if let Some(prefix) = name.strip_suffix("gate_up_proj") {
268 let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
269 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
270 return self.load_gptq_linear_fused(&parts);
271 }
272 }
273
274 let direct = format!("{name}.weight");
277 if self.has(&direct) {
278 let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
279 if shape.len() != 2 {
280 return Err(FerrumError::model(format!(
281 "linear '{name}': expected 2D weight, got {shape:?}"
282 )));
283 }
284 let weight = B::from_weight_bytes(raw, src_dtype);
285 return Ok(Box::new(DenseLinear::<B>::from_buffer(
286 weight, shape[0], shape[1],
287 )));
288 }
289
290 if let Some(prefix) = name.strip_suffix("qkv_proj") {
295 let parts = [
296 format!("{prefix}q_proj.weight"),
297 format!("{prefix}k_proj.weight"),
298 format!("{prefix}v_proj.weight"),
299 ];
300 if parts.iter().all(|p| self.has(p)) {
301 let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
302 let weight = B::from_weight_bytes(&bytes, dtype);
303 return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
304 }
305 }
306 if let Some(prefix) = name.strip_suffix("gate_up_proj") {
307 let parts = [
308 format!("{prefix}gate_proj.weight"),
309 format!("{prefix}up_proj.weight"),
310 ];
311 if parts.iter().all(|p| self.has(p)) {
312 let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
313 let weight = B::from_weight_bytes(&bytes, dtype);
314 return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
315 }
316 }
317
318 Err(FerrumError::model(format!(
319 "could not load linear '{name}' — no direct `.weight`, no split components"
320 )))
321 }
322
323 fn has_tensor(&self, name: &str) -> bool {
324 self.has(name)
325 }
326
327 fn quant_config(&self) -> Option<&QuantConfig> {
328 self.quant_config.as_ref()
329 }
330}
331
332impl<B: Backend> NativeSafetensorsLoader<B> {
333 fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
338 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
339 FerrumError::model(format!(
340 "'{name}.qweight' present but no quantize_config.json — \
341 can't determine bits/group_size"
342 ))
343 })?;
344 if qcfg.method != QuantMethod::Gptq {
345 return Err(FerrumError::model(format!(
346 "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
347 qcfg.method
348 )));
349 }
350
351 let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
352 let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
353 let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
354 let g_idx = if self.has(&format!("{name}.g_idx")) {
355 Some(self.read_i32(&format!("{name}.g_idx"))?.0)
356 } else {
357 None
358 };
359
360 if qw_shape.len() != 2 {
363 return Err(FerrumError::model(format!(
364 "'{name}.qweight' expected 2D, got {qw_shape:?}"
365 )));
366 }
367 let in_features = qw_shape[0] * 8;
368 let out_features = qw_shape[1];
369 if sc_shape.len() != 2 || sc_shape[1] != out_features {
370 return Err(FerrumError::model(format!(
371 "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
372 )));
373 }
374
375 let mut linear = GptqLinear::<B>::from_raw(
376 &qweight,
377 &scales_f32,
378 &qzeros,
379 g_idx.as_deref(),
380 qcfg.bits,
381 qcfg.group_size,
382 in_features,
383 out_features,
384 )?;
385
386 let bias_key = format!("{name}.bias");
388 if self.has(&bias_key) {
389 let (bias, bias_shape) = self.read_f32(&bias_key)?;
390 if bias_shape != [out_features] {
391 return Err(FerrumError::model(format!(
392 "'{bias_key}' {bias_shape:?} != [{out_features}]"
393 )));
394 }
395 linear = linear.with_bias(&bias);
396 }
397 Ok(Box::new(linear))
398 }
399
400 fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
413 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
414 FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
415 })?;
416 if qcfg.method != QuantMethod::Gptq {
417 return Err(FerrumError::model(format!(
418 "GPTQ fusion but quant_method={:?}",
419 qcfg.method
420 )));
421 }
422
423 let mut qw_acc: Vec<i32> = Vec::new();
424 let mut sc_acc: Vec<f32> = Vec::new();
425 let mut qz_acc: Vec<i32> = Vec::new();
426 let mut qw_rows = 0usize;
427 let mut sc_rows = 0usize;
428 let mut qz_rows = 0usize;
429 let mut total_n = 0usize;
430 let mut total_n_scales = 0usize;
431 let mut total_n_zeros = 0usize;
432 let mut g_idx: Option<Vec<i32>> = None;
433 let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
436 let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
437
438 for p in parts {
439 let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
440 let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
441 let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
442 if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
443 return Err(FerrumError::model(format!(
444 "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
445 )));
446 }
447 if qw_rows == 0 {
448 qw_rows = qw_sh[0];
449 sc_rows = sc_sh[0];
450 qz_rows = qz_sh[0];
451 } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
452 return Err(FerrumError::model(format!(
453 "GPTQ fusion row mismatch on '{p}'"
454 )));
455 }
456 total_n += qw_sh[1];
457 total_n_scales += sc_sh[1];
458 total_n_zeros += qz_sh[1];
459 qw_parts.push((qw, qw_sh[0], qw_sh[1]));
460 sc_parts.push((sc, sc_sh[0], sc_sh[1]));
461 qz_parts.push((qz, qz_sh[0], qz_sh[1]));
462
463 if g_idx.is_none() && self.has(&format!("{p}.g_idx")) {
465 g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
466 }
467 }
468
469 qw_acc.reserve(qw_rows * total_n);
471 for r in 0..qw_rows {
472 for (part, _rows, cols) in &qw_parts {
473 qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
474 }
475 }
476 sc_acc.reserve(sc_rows * total_n_scales);
477 for r in 0..sc_rows {
478 for (part, _rows, cols) in &sc_parts {
479 sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
480 }
481 }
482 qz_acc.reserve(qz_rows * total_n_zeros);
483 for r in 0..qz_rows {
484 for (part, _rows, cols) in &qz_parts {
485 qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
486 }
487 }
488
489 let in_features = qw_rows * 8;
490 let out_features = total_n;
491 let mut linear = GptqLinear::<B>::from_raw(
492 &qw_acc,
493 &sc_acc,
494 &qz_acc,
495 g_idx.as_deref(),
496 qcfg.bits,
497 qcfg.group_size,
498 in_features,
499 out_features,
500 )?;
501
502 let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
505 let any = bias_keys.iter().any(|k| self.has(k));
506 let all = bias_keys.iter().all(|k| self.has(k));
507 if any && !all {
508 return Err(FerrumError::model(
509 "GPTQ fusion: inconsistent bias presence across parts".to_string(),
510 ));
511 }
512 if all {
513 let mut fused: Vec<f32> = Vec::with_capacity(out_features);
514 for k in &bias_keys {
515 let (b, _) = self.read_f32(k)?;
516 fused.extend_from_slice(&b);
517 }
518 if fused.len() != out_features {
519 return Err(FerrumError::model(format!(
520 "GPTQ fusion bias length {} != out_features {out_features}",
521 fused.len()
522 )));
523 }
524 linear = linear.with_bias(&fused);
525 }
526 Ok(Box::new(linear))
527 }
528
529 #[allow(dead_code)]
533 fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
534 let mut total_rows = 0usize;
535 let mut cols = 0usize;
536 let mut out: Vec<f32> = Vec::new();
537 for n in names {
538 let (data, shape) = self.read_f32(n)?;
539 if shape.len() != 2 {
540 return Err(FerrumError::model(format!(
541 "cat_rows: '{n}' is {shape:?}, need 2D"
542 )));
543 }
544 if cols == 0 {
545 cols = shape[1];
546 } else if cols != shape[1] {
547 return Err(FerrumError::model(format!(
548 "cat_rows: col mismatch {cols} vs {}",
549 shape[1]
550 )));
551 }
552 total_rows += shape[0];
553 out.extend_from_slice(&data);
554 }
555 Ok((total_rows, cols, out))
556 }
557}
558
559fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
560 match dtype {
561 Dtype::F32 => {
562 debug_assert_eq!(raw.len() % 4, 0);
563 let n = raw.len() / 4;
564 let mut out = vec![0.0f32; n];
565 for i in 0..n {
566 let bytes = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
567 out[i] = f32::from_le_bytes(bytes);
568 }
569 Ok(out)
570 }
571 Dtype::F16 => {
572 debug_assert_eq!(raw.len() % 2, 0);
573 let n = raw.len() / 2;
574 let mut out = vec![0.0f32; n];
575 for i in 0..n {
576 let bytes = [raw[i * 2], raw[i * 2 + 1]];
577 out[i] = f16::from_le_bytes(bytes).to_f32();
578 }
579 Ok(out)
580 }
581 Dtype::BF16 => {
582 debug_assert_eq!(raw.len() % 2, 0);
583 let n = raw.len() / 2;
584 let mut out = vec![0.0f32; n];
585 for i in 0..n {
586 let bytes = [raw[i * 2], raw[i * 2 + 1]];
587 out[i] = bf16::from_le_bytes(bytes).to_f32();
588 }
589 Ok(out)
590 }
591 other => Err(FerrumError::model(format!(
592 "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
593 use a format-specific loader (GPTQ / AWQ / GGUF)",
594 ))),
595 }
596}
597
598fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
599 let p = dir.join("quantize_config.json");
601 if p.exists() {
602 let data =
603 std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
604 let qc: QuantConfig = serde_json::from_str(&data)
605 .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
606 return Ok(Some(qc));
607 }
608 let cfg = dir.join("config.json");
611 if cfg.exists() {
612 let data = std::fs::read_to_string(&cfg)
613 .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
614 let root: serde_json::Value = serde_json::from_str(&data)
615 .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
616 if let Some(qc_val) = root.get("quantization_config") {
617 let method = qc_val
619 .get("quant_method")
620 .and_then(|v| v.as_str())
621 .unwrap_or("none");
622 let method = match method.to_lowercase().as_str() {
623 "gptq" => QuantMethod::Gptq,
624 "awq" => QuantMethod::Awq,
625 "gguf" => QuantMethod::Gguf,
626 _ => QuantMethod::None,
627 };
628 let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
629 let group_size = qc_val
630 .get("group_size")
631 .and_then(|v| v.as_i64())
632 .unwrap_or(128)
633 .max(0) as usize;
634 let desc_act = qc_val
635 .get("desc_act")
636 .and_then(|v| v.as_bool())
637 .unwrap_or(false);
638 let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
639 if method != QuantMethod::None {
640 return Ok(Some(QuantConfig {
641 method,
642 bits,
643 group_size,
644 desc_act,
645 sym,
646 }));
647 }
648 }
649 }
650 Ok(None)
651}