1use std::fmt::{self, Display};
2
3use crate::paged_attention::{
4 calculate_cache_config, MemoryGpuConfig, ModelConfigLike, DEFAULT_PAGED_ATTENTION_BLOCK_SIZE,
5};
6use crate::utils::debug::DeviceRepr;
7use crate::{DeviceLayerMapMetadata, DeviceMapMetadata, MemoryUsage, PagedAttentionConfig};
8use anyhow::{Context, Result};
9use candle_core::{DType, Device};
10use itertools::Itertools;
11use tracing::{info, warn};
12
13use super::DeviceMappedModelLoader;
14
15const GPU_RESERVE_FRACTION: f64 = 0.02;
16const GPU_MIN_RESERVE_BYTES: usize = 512 * 1024 * 1024; #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
21fn device_cap(avail_bytes: usize, dev: &Device) -> usize {
22 if dev.is_cpu() {
23 avail_bytes
24 } else {
25 let reserve_frac = (avail_bytes as f64 * GPU_RESERVE_FRACTION) as usize;
26 let reserve = reserve_frac.max(GPU_MIN_RESERVE_BYTES).min(avail_bytes);
27 avail_bytes.saturating_sub(reserve)
28 }
29}
30
/// Sub-models that the automatic device mapper never splits across devices;
/// they are loaded whole on a single device instead (see the mapper's log
/// message about non-mapped sub-models).
#[derive(Clone, Debug)]
pub(crate) enum NonMappedSubModel {
    /// Vision component of a multimodal model.
    Vision,
    /// Audio component of a multimodal model.
    Audio,
}
36
37impl Display for NonMappedSubModel {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 match self {
40 NonMappedSubModel::Vision => write!(f, "vision"),
41 NonMappedSubModel::Audio => write!(f, "audio"),
42 }
43 }
44}
45
/// Workload assumptions the automatic device mapper uses to budget
/// activation and KV-cache memory.
#[derive(Debug, Clone)]
pub enum AutoDeviceMapParams {
    /// Text-only workload.
    Text {
        /// Maximum sequence length to budget for.
        max_seq_len: usize,
        /// Maximum batch size to budget for.
        max_batch_size: usize,
    },
    /// Multimodal workload (text plus image/audio inputs).
    Multimodal {
        /// Maximum sequence length to budget for.
        max_seq_len: usize,
        /// Maximum batch size to budget for.
        max_batch_size: usize,
        /// Largest expected image shape — presumably (height, width) in
        /// pixels; confirm against the vision loader.
        max_image_shape: (usize, usize),
        /// Maximum number of images per request to budget for.
        max_num_images: usize,
    },
}
59
60impl AutoDeviceMapParams {
61 pub fn maybe_promote_to_multimodal(&self) -> Self {
62 match *self {
63 Self::Text {
64 max_seq_len,
65 max_batch_size,
66 } => Self::Multimodal {
67 max_seq_len,
68 max_batch_size,
69 max_image_shape: (
70 Self::DEFAULT_MAX_IMAGE_LENGTH,
71 Self::DEFAULT_MAX_IMAGE_LENGTH,
72 ),
73 max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
74 },
75 Self::Multimodal {
76 max_seq_len,
77 max_batch_size,
78 max_image_shape,
79 max_num_images,
80 } => Self::Multimodal {
81 max_seq_len,
82 max_batch_size,
83 max_image_shape,
84 max_num_images,
85 },
86 }
87 }
88
89 pub fn max_seq_len(&self) -> usize {
90 match self {
91 Self::Text { max_seq_len, .. } | Self::Multimodal { max_seq_len, .. } => *max_seq_len,
92 }
93 }
94
95 pub fn max_batch_size(&self) -> usize {
96 match self {
97 Self::Text { max_batch_size, .. } | Self::Multimodal { max_batch_size, .. } => {
98 *max_batch_size
99 }
100 }
101 }
102}
103
104impl Display for AutoDeviceMapParams {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Self::Text {
108 max_seq_len,
109 max_batch_size,
110 } => write!(
111 f,
112 "text[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}]"
113 ),
114 Self::Multimodal {
115 max_seq_len,
116 max_batch_size,
117 max_image_shape,
118 max_num_images,
119 } => write!(
120 f,
121 "multimodal[max_seq_len: {max_seq_len}, max_batch_size: {max_batch_size}, max_image_shape: {max_image_shape:?}, max_num_images: {max_num_images}]"
122 ),
123 }
124 }
125}
126
127impl AutoDeviceMapParams {
128 pub const DEFAULT_MAX_SEQ_LEN: usize = 4 * 1024;
130 pub const DEFAULT_MAX_BATCH_SIZE: usize = 1;
131 pub const DEFAULT_MAX_NUM_IMAGES: usize = 1;
132 pub const DEFAULT_MAX_IMAGE_LENGTH: usize = 1024;
133
134 pub fn default_text() -> Self {
135 Self::Text {
136 max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
137 max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
138 }
139 }
140
141 pub fn default_multimodal() -> Self {
142 Self::Multimodal {
143 max_seq_len: Self::DEFAULT_MAX_SEQ_LEN,
144 max_batch_size: Self::DEFAULT_MAX_BATCH_SIZE,
145 max_num_images: Self::DEFAULT_MAX_NUM_IMAGES,
146 max_image_shape: (
147 Self::DEFAULT_MAX_IMAGE_LENGTH,
148 Self::DEFAULT_MAX_IMAGE_LENGTH,
149 ),
150 }
151 }
152}
153
154fn calculate_key_block_shape(
155 model_config: &dyn ModelConfigLike,
156 dtype: DType,
157 block_size: usize,
158) -> (usize, usize, usize, usize) {
159 let element_size = dtype.size_in_bytes();
160 let x = 16 / element_size;
161 (
162 model_config.num_kv_heads(),
163 model_config.k_head_dim() / x,
164 block_size,
165 x,
166 )
167}
168
169fn calculate_value_block_shape(
170 model_config: &dyn ModelConfigLike,
171 block_size: usize,
172) -> (usize, usize, usize) {
173 (
174 model_config.num_kv_heads(),
175 model_config.v_head_dim(),
176 block_size,
177 )
178}
179
/// Converts a byte count to whole mebibytes (truncating integer division).
macro_rules! b_to_mb {
    ($x:expr) => {
        $x / (1024 * 1024)
    };
}
185
186#[allow(
187 clippy::too_many_arguments,
188 clippy::cast_possible_truncation,
189 clippy::cast_precision_loss
190)]
191pub fn get_device_layers(
193 loader: &dyn DeviceMappedModelLoader,
194 config: &str,
195 num_layers: usize,
196 mut layer_sizes_in_bytes: Vec<usize>,
197 non_mapped_size_in_bytes: usize,
198 total_model_size_in_bytes: usize,
199 devices: &[Device],
200 dtype: DType,
201 params: &AutoDeviceMapParams,
202 paged_attn_config: Option<&PagedAttentionConfig>,
203) -> Result<DeviceMapMetadata> {
204 let mapped_max = loader.mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
205 let non_mapped_max =
206 loader.non_mapped_max_act_size_elems(config, params)? * dtype.size_in_bytes();
207
208 let mut layer_sizes_backup = if paged_attn_config.is_some() {
209 Some(layer_sizes_in_bytes.clone())
210 } else {
211 None
212 };
213
214 let mut remaining = total_model_size_in_bytes;
215 let max_seq_len = match params {
216 AutoDeviceMapParams::Text { max_seq_len, .. }
217 | AutoDeviceMapParams::Multimodal { max_seq_len, .. } => *max_seq_len,
218 };
219 let max_batch_size = match params {
220 AutoDeviceMapParams::Text { max_batch_size, .. }
221 | AutoDeviceMapParams::Multimodal { max_batch_size, .. } => *max_batch_size,
222 };
223
224 let model_cfg = loader.model_config(config)?;
225 let kv_cache_elems = match paged_attn_config {
226 Some(cfg) => {
227 let effective_mem_gpu = match cfg.mem_gpu {
231 MemoryGpuConfig::MbAmount(user_mb) => {
232 let primary_dev = &devices[0];
234 let avail_bytes = MemoryUsage.get_memory_available(primary_dev)?;
235 let cap = device_cap(avail_bytes, primary_dev);
236 let act_overhead = non_mapped_max.max(mapped_max);
237 let budget_mb = cap.saturating_sub(act_overhead) / (1024 * 1024);
238 MemoryGpuConfig::MbAmount(budget_mb.min(user_mb))
239 }
240 MemoryGpuConfig::Utilization(f) => {
241 let primary_dev = &devices[0];
246 let avail_bytes = MemoryUsage.get_memory_available(primary_dev)?;
247 let cap = device_cap(avail_bytes, primary_dev);
248 let act_overhead = non_mapped_max.max(mapped_max);
249 let budget_mb = ((cap as f64 * f as f64) as usize)
250 .saturating_sub(remaining + act_overhead)
251 / (1024 * 1024);
252 MemoryGpuConfig::MbAmount(budget_mb)
253 }
254 other => other,
256 };
257
258 let cache = calculate_cache_config(
259 effective_mem_gpu,
260 Some(cfg.block_size.unwrap_or(DEFAULT_PAGED_ATTENTION_BLOCK_SIZE)),
261 dtype,
262 paged_attn_config
263 .map(|cfg| cfg.cache_type)
264 .unwrap_or_default(),
265 &*model_cfg,
266 &devices[0],
267 &devices.iter().map(|d| Some(d.clone())).collect::<Vec<_>>(),
268 true,
269 Some(total_model_size_in_bytes),
270 Some(max_seq_len * max_batch_size),
271 )?;
272 let key_shape = calculate_key_block_shape(&*model_cfg, dtype, cache.block_size);
273 let key_sz =
274 cache.num_gpu_blocks * key_shape.0 * key_shape.1 * key_shape.2 * key_shape.3;
275 let val_shape = calculate_value_block_shape(&*model_cfg, cache.block_size);
276 let val_sz = cache.num_gpu_blocks * val_shape.0 * val_shape.1 * val_shape.2;
277 key_sz + val_sz
278 }
279 None => {
280 let key_shape = [
281 max_batch_size,
282 model_cfg.num_kv_heads(),
283 max_seq_len,
284 model_cfg.k_head_dim(),
285 ];
286 let val_shape = [
287 max_batch_size,
288 model_cfg.num_kv_heads(),
289 max_seq_len,
290 model_cfg.v_head_dim(),
291 ];
292 key_shape.iter().product::<usize>() + val_shape.iter().product::<usize>()
293 }
294 };
295 let kv_cache_bytes = kv_cache_elems * dtype.size_in_bytes();
296
297 let has_unified_memory = devices.iter().any(crate::utils::normal::is_integrated_gpu);
299
300 let mut avail = Vec::new();
301 for dev in devices {
302 let a = MemoryUsage.get_memory_available(dev)?;
303 avail.push((a, dev.clone()));
304 }
305 if !has_unified_memory {
308 let a = MemoryUsage.get_memory_available(&Device::Cpu)?;
309 avail.push((a, Device::Cpu));
310 }
311
312 avail.reverse();
313 layer_sizes_in_bytes.reverse();
314
315 let mut mappings = Vec::new();
316 info!("Using automatic device mapping parameters: {params}.");
317 if let Some(subs) = loader.non_mapped_sub_models() {
318 let (_, last) = avail.last().unwrap();
319 info!(
320 "The following sub-models will not be device mapped and will be loaded on {}: {}",
321 last.device_pretty_repr(),
322 subs.iter().map(|x| x.to_string()).join(", ")
323 );
324 }
325
326 let mut ordinal = 0;
327 let mut layer = 0;
328 let avail_copy = avail.clone();
329 let mut includes_cpu = false;
330 while remaining > 0 && !avail.is_empty() {
331 let (avail_bytes, dev) = avail
332 .pop()
333 .context("No more devices to map to. The model does not fit on this system.")?;
334
335 let cap = device_cap(avail_bytes, &dev);
337
338 let required_whole_capacity = if ordinal == 0 {
345 remaining + non_mapped_max.max(mapped_max) + kv_cache_bytes * (num_layers - layer)
346 } else {
347 remaining + mapped_max + kv_cache_bytes * (num_layers - layer)
348 };
349
350 let layers_on_dev = if cap >= required_whole_capacity {
351 remaining = 0;
352 num_layers - layer
353 } else {
354 let mut used = mapped_max;
355 let mut used_weight_bytes = 0;
356 let mut count = 0;
357 if ordinal == 0 {
358 used = used.max(non_mapped_max) + non_mapped_size_in_bytes;
359 used_weight_bytes += non_mapped_size_in_bytes;
360 }
361 while let Some(&sz) = layer_sizes_in_bytes.last() {
362 let delta = sz + kv_cache_bytes;
363 if used + delta > cap {
364 break;
365 }
366 layer_sizes_in_bytes.pop();
367 used += delta;
368 used_weight_bytes += sz;
369 count += 1;
370 }
371 if count > 0 {
372 remaining = remaining.saturating_sub(used_weight_bytes);
373 } else {
374 warn!(
375 "Device {} can fit 0 layers. Consider reducing auto map params from current: {params} (ex. reducing max seq len or max num images)",
376 dev.device_pretty_repr(),
377 );
378 ordinal += 1;
379 continue;
380 }
381 count
382 };
383 if !dev.is_cpu() {
384 mappings.push(DeviceLayerMapMetadata {
385 ordinal,
386 layers: layers_on_dev,
387 });
388 ordinal += 1;
389 } else {
390 includes_cpu = true;
391 }
392 layer += layers_on_dev;
393 }
394 if remaining > 0 {
395 let over = b_to_mb!(remaining);
396 anyhow::bail!(
397 "This model does not fit on the devices {:?}, and exceeds total capacity by {}MB. Auto device mapping params: {params}",
398 avail_copy.iter().rev().map(|(a, d)| format!("{} (avail: {}MB)", d.device_pretty_repr(), b_to_mb!(a))).collect::<Vec<_>>(),
399 over
400 );
401 }
402 if paged_attn_config.is_some_and(|_| includes_cpu) {
403 let original_layers = layer_sizes_backup
404 .take()
405 .expect("layer sizes backup missing for paged attention fallback");
406 return get_device_layers(
409 loader,
410 config,
411 num_layers,
412 original_layers,
413 non_mapped_size_in_bytes,
414 total_model_size_in_bytes,
415 devices,
416 dtype,
417 params,
418 None,
419 );
420 }
421 Ok(DeviceMapMetadata::from_num_device_layers(mappings))
422}