mistralrs_core/pipeline/loaders/auto_device_map.rs

use std::fmt::{self, Display};

use crate::paged_attention::{
    calculate_cache_config, MemoryGpuConfig, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
};
use crate::utils::debug::DeviceRepr;
use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
use anyhow::{Context, Result};
use candle_core::{DType, Device};
use itertools::Itertools;
use tracing::{info, warn};

use super::DeviceMappedModelLoader;

const GPU_RESERVE_FRACTION: f64 = 0.02;
const GPU_MIN_RESERVE_BYTES: usize = 512 * 1024 * 1024; // 512 MiB safety buffer

/// Usable device capacity after subtracting a small safety reserve for GPUs.
/// CPU devices return `avail_bytes` unchanged.
#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
fn device_cap(avail_bytes: usize, dev: &Device) -> usize {
    if dev.is_cpu() {
        avail_bytes
    } else {
        // Reserve the larger of the fractional and fixed buffers, clamped so
        // the reserve never exceeds what is actually available.
        let fractional_reserve = (avail_bytes as f64 * GPU_RESERVE_FRACTION) as usize;
        let reserve = fractional_reserve.max(GPU_MIN_RESERVE_BYTES).min(avail_bytes);
        avail_bytes.saturating_sub(reserve)
    }
}
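// Worked example (illustrative numbers): a GPU reporting 24 GiB available has a
// 2% fractional reserve of ~491 MiB, below the 512 MiB floor, so the floor
// applies and the usable capacity is 24 GiB minus 512 MiB. On an 80 GiB device
// the 2% reserve (~1.6 GiB) exceeds the floor and is used instead.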

#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    Vision,
    Audio,
}

impl Display for NonMappedSubModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NonMappedSubModel::Vision => write!(f, "vision"),
            NonMappedSubModel::Audio => write!(f, "audio"),
        }
    }
}

#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    Text {
        max_seq_len: usize,
        max_batch_size: usize,
    },
    Multimodal {
        max_seq_len: usize,
        max_batch_size: usize,
        max_image_shape: (usize, usize),
        max_num_images: usize,
    },
}

impl AutoDeviceMapParams {
    pub fn maybe_promote_to_multimodal(&self) -> Self {
        match *self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => Self::Multimodal {
                max_seq_len,
                max_batch_size,
                max_image_shape: (
                    Self::DEFAULT_MAX_IMAGE_LENGTH,
                    Self::DEFAULT_MAX_IMAGE_LENGTH,
                ),
                max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            },
            // Already multimodal: return an unchanged copy.
            Self::Multimodal { .. } => self.clone(),
        }
    }

    pub fn max_seq_len(&self) -> usize {
        match self {
            Self::Text { max_seq_len, .. } | Self::Multimodal { max_seq_len, .. } => *max_seq_len,
        }
    }

    pub fn max_batch_size(&self) -> usize {
        match self {
            Self::Text { max_batch_size, .. } | Self::Multimodal { max_batch_size, .. } => {
                *max_batch_size
            }
        }
    }
}

impl Display for AutoDeviceMapParams {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Text {
                max_seq_len,
                max_batch_size,
            } => write!(
                f,
                "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
            ),
            Self::Multimodal {
                max_seq_len,
                max_batch_size,
                max_image_shape,
                max_num_images,
            } => write!(
                f,
                "multimodal[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
            ),
        }
    }
}
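// For example (derived from the defaults below), `AutoDeviceMapParams::default_text()`
// displays as `text[max_seq_len: 4096, max_batch_size: 1]`.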

impl AutoDeviceMapParams {
    // Default max sequence length for memory estimation when not specified.
    pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
    pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
    pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
    pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;

    pub fn default_text() -> Self {
        Self::Text {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
        }
    }

    pub fn default_multimodal() -> Self {
        Self::Multimodal {
            max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
            max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
            max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
            max_image_shape: (
                Self::DEFAULT_MAX_IMAGE_LENGTH,
                Self::DEFAULT_MAX_IMAGE_LENGTH,
            ),
        }
    }
}

fn calculate_key_block_shape(
    model_config: &dyn ModelConfigLike,
    dtype: DType,
    block_size: usize,
) -> (usize, usize, usize, usize) {
    let element_size = dtype.size_in_bytes();
    // Number of head-dim elements packed into a 16-byte group, matching the
    // paged-attention key-cache layout.
    let x = 16 / element_size;
    (
        model_config.num_kv_heads(),
        model_config.k_head_dim() / x,
        block_size,
        x,
    )
}

fn calculate_value_block_shape(
    model_config: &dyn ModelConfigLike,
    block_size: usize,
) -> (usize, usize, usize) {
    (
        model_config.num_kv_heads(),
        model_config.v_head_dim(),
        block_size,
    )
}
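// Illustrative shapes: with `DType::F16` (2 bytes per element), x = 16 / 2 = 8,
// so a model with 8 KV heads, a head dim of 128, and a block size of 32 gets a
// key block of (8, 16, 32, 8) and a value block of (8, 128, 32); both hold
// 8 * 128 * 32 = 32768 elements per block.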

macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}

#[allow(
    clippy::too_many_arguments,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss
)]
/// Core logic for automatic device mapping: greedily fill each device, in
/// order, with as many layers as fit (weights, activations, and KV cache),
/// spilling the remainder onto later devices and finally the CPU.
pub fn get_device_layers(
    loader: &dyn DeviceMappedModelLoader,
    config: &str,
    num_layers: usize,
    mut layer_sizes_in_bytes: Vec<usize>,
    non_mapped_size_in_bytes: usize,
    total_model_size_in_bytes: usize,
    devices: &[Device],
    dtype: DType,
    params: &AutoDeviceMapParams,
    paged_attn_config: Option<&PagedAttentionConfig>,
) -> Result<DeviceMapMetadata> {
    let mapped_max = loader.mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
    let non_mapped_max =
        loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();

    let mut layer_sizes_backup = if paged_attn_config.is_some() {
        Some(layer_sizes_in_bytes.clone())
    } else {
        None
    };

    let mut remaining = total_model_size_in_bytes;
    let max_seq_len = params.max_seq_len();
    let max_batch_size = params.max_batch_size();

    let model_cfg = loader.model_config(config)?;
    let kv_cache_elems = match paged_attn_config {
        Some(cfg) => {
            // For MbAmount, clamp to available memory so the capacity check
            // below stays consistent. Utilization and ContextSize pass through
            // to calculate_cache_config, which handles model weight subtraction.
            let effective_mem_gpu = match cfg.mem_gpu {
                MemoryGpuConfig::MbAmount(user_mb) => {
                    // Clamp the user's KV budget to available memory.
                    let primary_dev = &devices[0];
                    let avail_bytes = MemoryUsage.get_memory_available(primary_dev)?;
                    let cap = device_cap(avail_bytes, primary_dev);
                    let act_overhead = non_mapped_max.max(mapped_max);
                    let budget_mb = cap.saturating_sub(act_overhead) / (1024 * 1024);
                    MemoryGpuConfig::MbAmount(budget_mb.min(user_mb))
                }
                MemoryGpuConfig::Utilization(f) => {
                    // Prevent overallocation when total_memory > available_memory
                    // (e.g., unified memory systems, other GPU processes using VRAM).
                    // Cap the KV budget so model + activations + KV fits within
                    // the device capacity derived from *available* memory.
                    let primary_dev = &devices[0];
                    let avail_bytes = MemoryUsage.get_memory_available(primary_dev)?;
                    let cap = device_cap(avail_bytes, primary_dev);
                    let act_overhead = non_mapped_max.max(mapped_max);
                    let budget_mb = ((cap as f64 * f as f64) as usize)
                        .saturating_sub(remaining + act_overhead)
                        / (1024 * 1024);
                    MemoryGpuConfig::MbAmount(budget_mb)
                }
                // ContextSize passes through to calculate_cache_config.
                other => other,
            };

            let cache = calculate_cache_config(
                effective_mem_gpu,
                Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
                dtype,
                cfg.cache_type,
                &*model_cfg,
                &devices[0],
                &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
                true,
                Some(total_model_size_in_bytes),
                Some(max_seq_len * max_batch_size),
            )?;
            let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
            let key_sz =
                cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
            let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
            let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
            key_sz + val_sz
        }
        None => {
            let key_shape = [
                max_batch_size,
                model_cfg.num_kv_heads(),
                max_seq_len,
                model_cfg.k_head_dim(),
            ];
            let val_shape = [
                max_batch_size,
                model_cfg.num_kv_heads(),
                max_seq_len,
                model_cfg.v_head_dim(),
            ];
            key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
        }
    };
    let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
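    // Illustrative sizing for the non-paged branch: 8 KV heads, a head dim of
    // 128, max_seq_len 4096, and max_batch_size 1 give
    // 2 * 8 * 4096 * 128 = 8_388_608 elements per layer, i.e. 16 MiB at f16.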

    // prepare available memory per device, CPU fallback last (unless unified memory)
    let has_unified_memory = devices.iter().any(crate::utils::normal::is_integrated_gpu);

    let mut avail = Vec::new();
    for dev in devices {
        let a = MemoryUsage.get_memory_available(dev)?;
        avail.push((a, dev.clone()));
    }
    // On unified memory systems (iGPUs), GPU and CPU share the same physical RAM.
    // Don't add CPU as a fallback device since it would double-count memory.
    if !has_unified_memory {
        let a = MemoryUsage.get_memory_available(&Device::Cpu)?;
        avail.push((a, Device::Cpu));
    }

    avail.reverse();
    layer_sizes_in_bytes.reverse();

    let mut mappings = Vec::new();
    info!("Using automatic device mapping parameters: {params}.");
    if let Some(subs) = loader.non_mapped_sub_models() {
        let (_, last) = avail.last().unwrap();
        info!(
            "The following sub-models will not be device mapped and will be loaded on {}: {}",
            last.device_pretty_repr(),
            subs.iter().map(|x| x.to_string()).join(", ")
        );
    }

    let mut ordinal = 0;
    let mut layer = 0;
    let avail_copy = avail.clone();
    let mut includes_cpu = false;
    while remaining > 0 && !avail.is_empty() {
        let (avail_bytes, dev) = avail
            .pop()
            .context("No more devices to map to. The model does not fit on this system.")?;

        // For GPUs/accelerators, keep a small dynamic safety reserve to avoid OOMs.
        let cap = device_cap(avail_bytes, &dev);

        // The algorithm checks, in order:
        // 1) (no mapping) whether *everything* fits on the first device (non-mapped and mapped);
        // 2) whether the mapped activations plus the remaining weights fit on the nth device;
        // 3) the common case: iteratively find the optimal number of layers to put on the nth device
        //    - if this is the first device, it must also hold the non-mapped activations and weights;
        //    - otherwise, it must hold the mapped activations.
        let required_whole_capacity = if ordinal == 0 {
            remaining + non_mapped_max.max(mapped_max) + kv_cache_bytes * (num_layers - layer)
        } else {
            remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
        };
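        // Illustrative check (hypothetical numbers): with 14 GiB of weights
        // remaining, 1 GiB of peak activations, and 16 MiB of KV cache for each
        // of 32 unplaced layers, the first device must offer
        // 14 GiB + 1 GiB + 0.5 GiB = 15.5 GiB to take everything in one shot.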

        let layers_on_dev = if cap >= required_whole_capacity {
            remaining = 0;
            num_layers - layer
        } else {
            let mut used = mapped_max;
            let mut used_weight_bytes = 0;
            let mut count = 0;
            if ordinal == 0 {
                used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
                used_weight_bytes += non_mapped_size_in_bytes;
            }
            while let Some(&sz) = layer_sizes_in_bytes.last() {
                let delta = sz + kv_cache_bytes;
                if used + delta > cap {
                    break;
                }
                layer_sizes_in_bytes.pop();
                used += delta;
                used_weight_bytes += sz;
                count += 1;
            }
            if count > 0 {
                remaining = remaining.saturating_sub(used_weight_bytes);
            } else {
                warn!(
                    "Device {} can fit 0 layers. Consider reducing the auto map params from the current: {params} (e.g. reducing max seq len or max num images)",
                    dev.device_pretty_repr(),
                );
                ordinal += 1;
                continue;
            }
            count
        };
        if !dev.is_cpu() {
            mappings.push(DeviceLayerMapMetadata {
                ordinal,
                layers: layers_on_dev,
            });
            ordinal += 1;
        } else {
            includes_cpu = true;
        }
        layer += layers_on_dev;
    }
    if remaining > 0 {
        let over = b_to_mb!(remaining);
        anyhow::bail!(
            "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
            avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
            over
        );
    }
    if paged_attn_config.is_some() && includes_cpu {
        let original_layers = layer_sizes_backup
            .take()
            .expect("layer sizes backup missing for paged attention fallback");
        // The original vector was in forward order, but `get_device_layers` handles
        // reversing internally, so we can pass it along unchanged.
        return get_device_layers(
            loader,
            config,
            num_layers,
            original_layers,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            None,
        );
    }
    Ok(DeviceMapMetadata::from_num_device_layers(mappings))
}