//! oxide_rs/model/loader.rs — GGUF model loading and architecture dispatch.

use std::fs::File;
use std::io::{Cursor, Seek};
use std::path::PathBuf;

use anyhow::{Context, Result};
use candle_core::quantized::gguf_file;
use candle_core::{Device, Tensor};
use candle_transformers::models::quantized_lfm2::ModelWeights as Lfm2Model;
use candle_transformers::models::quantized_llama::ModelWeights as LlamaModel;
use candle_transformers::models::quantized_qwen2::ModelWeights as Qwen2Model;
use candle_transformers::models::quantized_qwen3::ModelWeights as Qwen3Model;
use memmap2::Mmap;

use crate::model::quantized_qwen35::ModelWeights as Qwen35Model;

/// Model metadata extracted from a GGUF file's key/value header.
#[derive(Debug, Clone)]
pub struct GgufMetadata {
    // Human-readable model name (`general.name`), falling back to the filename.
    pub name: String,
    // Architecture id (`general.architecture`), e.g. "llama", "qwen3"; defaults to "llama".
    pub architecture: String,
    // Number of transformer blocks (`<arch>.block_count`).
    pub n_layer: usize,
    // Embedding (hidden) dimension (`<arch>.embedding_length`).
    pub n_embd: usize,
    // Vocabulary size; read from metadata, or derived from the tokenizer token list.
    pub vocab_size: usize,
    // Maximum context length (`<arch>.context_length`); defaults to 4096 when absent.
    pub context_length: usize,
    // Size of the GGUF file on disk, in bytes.
    pub file_size: u64,
    // Jinja chat template (`tokenizer.chat_template`), when embedded in the file.
    pub chat_template: Option<String>,
    // Quantization label from metadata keys, or sniffed from the filename.
    pub quantization: Option<String>,
}

/// Architecture-specific quantized weight containers.
///
/// The variant is chosen from the GGUF `general.architecture` key at load
/// time; unrecognized architectures fall back to [`ModelInner::Llama`].
pub enum ModelInner {
    Llama(LlamaModel),
    Lfm2(Lfm2Model),
    Qwen2(Qwen2Model),
    Qwen3(Qwen3Model),
    Qwen35(Qwen35Model),
}

/// A loaded quantized model together with the metadata parsed from its GGUF header.
pub struct Model {
    inner: ModelInner,
    metadata: GgufMetadata,
}

/// Bundles a loaded [`Model`] with the memory map backing its GGUF file, so
/// callers can keep the mapping alive alongside the model.
pub struct ModelWithMmap {
    pub model: Model,
    pub mmap: Mmap,
}

impl Model {
48    pub fn load(path: &PathBuf) -> Result<Self> {
49        let (_, model) = Self::load_with_mmap(path)?;
50        Ok(model)
51    }
52
    /// Opens and memory-maps the GGUF file at `path`, parses its metadata, and
    /// loads the architecture-appropriate quantized weights onto the CPU.
    ///
    /// Returns the `Mmap` together with the model so the caller can keep the
    /// mapping alive for the model's lifetime.
    ///
    /// # Errors
    /// Fails if the file cannot be stat'ed, opened, or mapped, if the GGUF
    /// header is invalid, if required metadata keys are missing, or if the
    /// selected architecture's weight loader rejects the tensor data.
    pub fn load_with_mmap(path: &PathBuf) -> Result<(Mmap, Self)> {
        let file_size = std::fs::metadata(path)?.len();
        // The bare filename doubles as a fallback model name and as input to
        // quantization-label sniffing in `extract_metadata`.
        let filename = path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("unknown");

        let device = Device::Cpu;

        let file =
            File::open(path).with_context(|| format!("Failed to open model file: {:?}", path))?;

        tracing::info!("Memory-mapping GGUF file ({} MB)...", file_size / 1_000_000);

        // SAFETY: mapping a file is unsafe because the underlying file could
        // be truncated or modified by another process while mapped; we assume
        // the model file is not mutated while loading.
        let mmap = unsafe { Mmap::map(&file)? };

        // Apply madvise hints BEFORE reading tensor data so the kernel begins
        // async read-ahead while candle's sequential seek+read_exact calls follow.
        // Calling these after from_gguf() would be useless — data already read.
        {
            let ptr = mmap.as_ptr() as *mut std::ffi::c_void;
            let size = mmap.len();
            // SAFETY: `ptr`/`size` describe exactly the live mapping above;
            // madvise is advisory only, so failures are harmless and the
            // return values are deliberately ignored (best-effort).
            unsafe {
                libc::madvise(ptr, size, libc::MADV_SEQUENTIAL);
                // MADV_HUGEPAGE is a Linux-only constant, hence the cfg gate.
                #[cfg(target_os = "linux")]
                libc::madvise(ptr, size, libc::MADV_HUGEPAGE);
                libc::madvise(ptr, size, libc::MADV_WILLNEED);
            }
            tracing::info!(
                "madvise hints applied ({} MB): SEQUENTIAL + HUGEPAGE + WILLNEED",
                size / 1_000_000
            );
        }

        // Parse the GGUF header directly out of the mapping.
        let mut cursor = Cursor::new(&mmap);

        let content = gguf_file::Content::read(&mut cursor)
            .with_context(|| format!("Failed to read GGUF file: {:?}", path))?;

        let metadata = Self::extract_metadata(&content, filename, file_size)?;

        let arch = metadata.architecture.as_str();
        tracing::info!(
            "Loading model: {} ({} layers, {} embedding dim, {} vocab, arch: {})",
            metadata.name,
            metadata.n_layer,
            metadata.n_embd,
            metadata.vocab_size,
            arch
        );

        // Rewind so the weight loaders re-read from the start of the file;
        // they seek to each tensor's offset from there.
        cursor.seek(std::io::SeekFrom::Start(0))?;

        // Dispatch on architecture; anything unrecognized is treated as LLaMA.
        let inner = if arch == "lfm2" {
            let weights = Lfm2Model::from_gguf(content, &mut cursor, &device)
                .with_context(|| "Failed to load LFM2 model weights from GGUF")?;
            ModelInner::Lfm2(weights)
        } else if arch == "qwen2" {
            let weights = Qwen2Model::from_gguf(content, &mut cursor, &device)
                .with_context(|| "Failed to load Qwen2 model weights from GGUF")?;
            ModelInner::Qwen2(weights)
        } else if arch == "qwen3" {
            let weights = Qwen3Model::from_gguf(content, &mut cursor, &device)
                .with_context(|| "Failed to load Qwen3 model weights from GGUF")?;
            ModelInner::Qwen3(weights)
        } else if arch == "qwen35" {
            let weights = Qwen35Model::from_gguf(content, &mut cursor, &device)
                .with_context(|| "Failed to load Qwen3.5 model weights from GGUF")?;
            ModelInner::Qwen35(weights)
        } else {
            let weights = LlamaModel::from_gguf(content, &mut cursor, &device)
                .with_context(|| "Failed to load LLaMA model weights from GGUF")?;
            ModelInner::Llama(weights)
        };

        tracing::info!("Model loaded successfully");

        let model = Self { inner, metadata };
        Ok((mmap, model))
    }

134    /// No-op. madvise hints are now applied inside `load_with_mmap()` immediately
135    /// after the mmap is created and before tensor data is read, which is the only
136    /// point where they have effect. Calling this after load returns is useless.
137    #[allow(unused_variables)]
138    pub fn prefetch_mmap(_mmap: &Mmap) {}
139
140    fn extract_metadata(
141        content: &gguf_file::Content,
142        filename: &str,
143        file_size: u64,
144    ) -> Result<GgufMetadata> {
145        let md = &content.metadata;
146
147        let arch: String = match md.get("general.architecture") {
148            Some(v) => v
149                .to_string()
150                .cloned()
151                .unwrap_or_else(|_| "llama".to_string()),
152            None => "llama".to_string(),
153        };
154
155        let model_name: String = md
156            .get("general.name")
157            .and_then(|v| v.to_string().ok().cloned())
158            .unwrap_or_else(|| filename.to_string());
159
160        let find_key = |key_suffix: &str| -> Option<usize> {
161            let as_usize = |v: &gguf_file::Value| v.to_u64().ok().map(|n| n as usize);
162
163            if let Some(v) = md.get(&format!("{}.{}", arch, key_suffix)) {
164                return as_usize(v);
165            }
166
167            for (k, v) in md.iter() {
168                if k.ends_with(&format!(".{}", key_suffix)) {
169                    if let Some(val) = as_usize(v) {
170                        return Some(val);
171                    }
172                }
173            }
174            None
175        };
176
177        let get_required = |key_suffix: &str| -> Result<usize> {
178            find_key(key_suffix)
179                .ok_or_else(|| anyhow::anyhow!("Missing metadata key: {}", key_suffix))
180        };
181
182        let get_optional =
183            |key_suffix: &str, default: usize| -> usize { find_key(key_suffix).unwrap_or(default) };
184
185        let chat_template = md
186            .get("tokenizer.chat_template")
187            .and_then(|v| v.to_string().ok().cloned());
188
189        let quantization: Option<String> = md
190            .get("general.quantization")
191            .and_then(|v| v.to_string().ok().map(|s| s.to_string()))
192            .or_else(|| {
193                // Try to find quantization in various model metadata keys
194                let quant_keys = [
195                    "quantization_version",
196                    "quantization",
197                    "quantization_format",
198                ];
199                for key in quant_keys {
200                    if let Some(v) = md.get(&format!("{}.{}", arch, key)) {
201                        if let Ok(s) = v.to_string() {
202                            if !s.is_empty() {
203                                return Some(s.to_string());
204                            }
205                        }
206                    }
207                }
208                None
209            })
210            .or_else(|| {
211                filename
212                    .split('.')
213                    .filter(|s| !s.eq_ignore_ascii_case("gguf") && !s.eq_ignore_ascii_case("bin"))
214                    .last()
215                    .map(|s| s.to_string())
216                    .filter(|s| {
217                        s.len() >= 2
218                            && (s.starts_with("q") || s.starts_with("Q"))
219                            && s.chars()
220                                .skip(1)
221                                .all(|c| c.is_ascii_digit() || c == '_' || c == '-')
222                    })
223            });
224
225        Ok(GgufMetadata {
226            name: model_name,
227            architecture: arch.clone(),
228            n_layer: get_required("block_count")?,
229            n_embd: get_required("embedding_length")?,
230            vocab_size: find_key("vocab_size")
231                .or_else(|| {
232                    // Fallback: derive from tokenizer token list length.
233                    // Newer GGUF files (e.g. qwen35) omit the explicit vocab_size key.
234                    md.get("tokenizer.ggml.tokens")
235                        .and_then(|v| v.to_vec().ok())
236                        .map(|arr| arr.len())
237                })
238                .ok_or_else(|| anyhow::anyhow!("Missing metadata key: vocab_size"))?,
239            context_length: get_optional("context_length", 4096),
240            file_size,
241            chat_template,
242            quantization,
243        })
244    }
245
    /// Returns the metadata parsed from the GGUF header at load time.
    pub fn metadata(&self) -> &GgufMetadata {
        &self.metadata
    }

    /// Resets the key/value cache for architectures that expose one.
    ///
    /// Llama, LFM2, and Qwen2 are deliberate no-ops here — their candle weight
    /// types are not given a cache-clear call in this codepath (presumably
    /// they manage cache state via the position argument; confirm against the
    /// candle implementations). The match is exhaustive on purpose so adding a
    /// `ModelInner` variant forces a decision here.
    pub fn clear_kv_cache(&mut self) {
        match &mut self.inner {
            ModelInner::Qwen3(m) => m.clear_kv_cache(),
            ModelInner::Qwen35(m) => m.clear_kv_cache(),
            ModelInner::Llama(_) | ModelInner::Lfm2(_) | ModelInner::Qwen2(_) => {}
        }
    }

258    pub fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Tensor> {
259        if pos == 0 {
260            self.clear_kv_cache();
261        }
262        let input = Tensor::new(tokens, &Device::Cpu)?.unsqueeze(0)?;
263
264        let logits = match &mut self.inner {
265            ModelInner::Llama(m) => m.forward(&input, pos)?,
266            ModelInner::Lfm2(m) => m.forward(&input, pos)?,
267            ModelInner::Qwen2(m) => m.forward(&input, pos)?,
268            ModelInner::Qwen3(m) => m.forward(&input, pos)?,
269            ModelInner::Qwen35(m) => m.forward(&input, pos)?,
270        };
271        Ok(logits)
272    }
}