ferrum_quantization/loader.rs
//! `WeightLoader` trait — unified interface for loading tensor/linear weights
//! into a specific backend.
//!
//! Implementations (landing in Phase B):
//! - `SafeTensorsLoader` — reads `.safetensors` files and returns `DenseLinear`,
//!   unless `quantize_config.json` indicates GPTQ/AWQ, in which case it
//!   returns `GptqLinear` / `AwqLinear`.
//! - `GgufLoader` — reads `.gguf` files and returns `GgufLinear`.
//!
//! The trait is generic over `B: Backend` so the loader can materialise
//! tensors directly into backend-native buffers (zero-copy on Apple Silicon
//! shared memory, explicit host↔device copies (htod/dtoh) on CUDA, etc.).
13
14use ferrum_kernels::backend::Backend;
15use ferrum_types::Result;
16
17use crate::config::QuantConfig;
18use crate::traits::Linear;
19
20pub trait WeightLoader<B: Backend>: Send + Sync {
21 /// Load a single tensor by fully qualified name
22 /// (e.g. `"model.embed_tokens.weight"`).
23 fn load_tensor(&self, name: &str) -> Result<B::Buffer>;
24
25 /// Load a projection as a `Linear<B>`. The concrete implementation
26 /// (DenseLinear / GptqLinear / AwqLinear / GgufLinear) depends on the
27 /// loader's file format and quant config.
28 ///
29 /// `name` is the module path without the `.weight` suffix, e.g.
30 /// `"model.layers.0.self_attn.qkv_proj"`.
31 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>>;
32
33 /// Whether a tensor with this name exists in the source.
34 fn has_tensor(&self, name: &str) -> bool;
35
36 /// Quantization metadata (parsed from `quantize_config.json` or a GGUF header).
37 /// `None` means the source is dense.
38 fn quant_config(&self) -> Option<&QuantConfig>;
39}
40
41/// Adapter that prepends a fixed prefix to every tensor name before
42/// delegating to an underlying loader.
43///
44/// Use case: a single safetensors file contains a sub-model (e.g.
45/// Qwen3-TTS stores the Talker LM under `talker.model.*`) and we want
46/// to reuse a backbone loader like `LlamaFamilyModel::new` that
47/// expects bare `model.*` names. Wrapping with
48/// `PrefixedLoader { inner, prefix: "talker." }` lets the backbone
49/// code stay prefix-agnostic.
50pub struct PrefixedLoader<'a, B: Backend> {
51 inner: &'a dyn WeightLoader<B>,
52 prefix: String,
53}
54
55impl<'a, B: Backend> PrefixedLoader<'a, B> {
56 pub fn new(inner: &'a dyn WeightLoader<B>, prefix: impl Into<String>) -> Self {
57 Self {
58 inner,
59 prefix: prefix.into(),
60 }
61 }
62}
63
64impl<'a, B: Backend> WeightLoader<B> for PrefixedLoader<'a, B> {
65 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
66 self.inner.load_tensor(&format!("{}{}", self.prefix, name))
67 }
68
69 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
70 self.inner.load_linear(&format!("{}{}", self.prefix, name))
71 }
72
73 fn has_tensor(&self, name: &str) -> bool {
74 self.inner.has_tensor(&format!("{}{}", self.prefix, name))
75 }
76
77 fn quant_config(&self) -> Option<&QuantConfig> {
78 self.inner.quant_config()
79 }
80}