// mistralrs_core/device_map.rs

use std::{fmt::Debug, sync::Arc};

use crate::{
    pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
};
use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;
use serde::{Deserialize, Serialize};
use tracing::info;

#[derive(Debug, Default, Deserialize, Serialize, Clone)]
pub struct DeviceLayerMapMetadata {
    pub ordinal: usize,
    pub layers: usize,
}

#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
    /// Manual device mapping.
    Map(DeviceMapMetadata),
    /// Automatic device mapping (recommended).
    Auto(AutoDeviceMapParams),
    /// Dummy device mapping for an NCCL pipeline.
    DummyNccl { nm_device: Device },
    /// Real device mapping for an NCCL pipeline.
    Nccl {
        nm_device: Device,
        comm: Arc<mistralrs_quant::Comm>,
    },
}

#[derive(Debug, Default, Deserialize, Clone)]
/// Metadata to initialize the device mapper.
pub struct DeviceMapMetadata {
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    host_layers: Option<usize>,
}

impl DeviceMapMetadata {
    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
        Self {
            device_layers: Some(device_layers),
            host_layers: None,
        }
    }
    /// Metadata for a device mapper that does not perform any device mapping.
    pub fn dummy() -> Self {
        Self {
            device_layers: None,
            host_layers: None,
        }
    }

    pub fn device_layers(&self) -> Option<&[DeviceLayerMapMetadata]> {
        self.device_layers.as_deref()
    }

    pub fn host_layers(&self) -> Option<usize> {
        self.host_layers
    }

    pub fn to_cli_spec(&self) -> Option<String> {
        let layers = self.device_layers.as_ref()?;
        if layers.is_empty() {
            return None;
        }
        Some(
            layers
                .iter()
                .map(|l| format!("{}:{}", l.ordinal, l.layers))
                .collect::<Vec<_>>()
                .join(";"),
        )
    }
}
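
// Illustrative example (not part of the original source): how `DeviceMapMetadata`
// turns into the `ordinal:layers;ordinal:layers` string produced by `to_cli_spec`.
// The module/test names and layer counts are arbitrary.
#[cfg(test)]
mod device_map_metadata_example {
    use super::{DeviceLayerMapMetadata, DeviceMapMetadata};

    #[test]
    fn cli_spec_formatting() {
        // Two GPUs: 16 layers on ordinal 0 and 8 layers on ordinal 1.
        let meta = DeviceMapMetadata::from_num_device_layers(vec![
            DeviceLayerMapMetadata {
                ordinal: 0,
                layers: 16,
            },
            DeviceLayerMapMetadata {
                ordinal: 1,
                layers: 8,
            },
        ]);
        assert_eq!(meta.to_cli_spec().as_deref(), Some("0:16;1:8"));

        // A dummy mapping has no device layers, so there is no CLI spec.
        assert!(DeviceMapMetadata::dummy().to_cli_spec().is_none());
    }
}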

impl DeviceMapSetting {
    /// A device map setting for a mapper that does not perform any device mapping.
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
        all_devices: &[Device],
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                if let Some(topology) = topology {
                    if topology.layers.iter().all(|x| x.is_none()) {
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        let layers = topology
                            .layers
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

                // How many device layers
                // Clamp to max of model layers
                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                // How many host (cpu) layers, defaulting to automatically filling the rest.
                // If n_device_layers > model_layers, n_host_layers = 0
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));

                // Handle multi-GPU mapping here
                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    let cuda_device = all_devices
                                        .iter()
                                        .filter(|d| d.is_cuda())
                                        .map(|d| {
                                            // This should be implemented in candle so the ordinal can be read directly from the device location.
                                            let ordinal = match d.location() {
                                                DeviceLocation::Cpu => 0,
                                                DeviceLocation::Cuda { gpu_id } => gpu_id,
                                                DeviceLocation::Metal { gpu_id } => gpu_id,
                                            };
                                            (d.clone(), ordinal)
                                        })
                                        .find(|(_, other_device_ordinal)| {
                                            other_device_ordinal == ordinal
                                        });

                                    if let Some((device, _)) = cuda_device {
                                        device
                                    } else {
                                        candle_core::bail!(
                                            "Could not find cuda device with ordinal {}",
                                            ordinal
                                        )
                                    }
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

                // Always put the CPU layers at the end so that we reduce dtoh and htod copies
                combined.extend(vec![Device::Cpu; n_host_layers]);

                // Sanity
                assert_eq!(combined.len(), model_layers);

                // Print it out
                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    // Iterate starting from index 1 to detect when the variant changes
                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        // If the variant changes, print the previous continuous block
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i; // start a new range
                            current_dev = variant;
                        }
                    }

                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }

                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}
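
// Illustrative example (not part of the original source): a minimal sketch of converting a
// manual `DeviceMapSetting::Map` into a mapper on a CPU-only setup. With a single
// `DeviceLayerMapMetadata` entry, the remaining layers default to host (CPU) layers.
// The module/test names and layer counts are arbitrary.
#[cfg(test)]
mod into_mapper_example {
    use super::{DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, DeviceMapper};
    use candle_core::{Device, Result};

    #[test]
    fn manual_map_on_cpu() -> Result<()> {
        let setting = DeviceMapSetting::Map(DeviceMapMetadata::from_num_device_layers(vec![
            DeviceLayerMapMetadata {
                ordinal: 0,
                layers: 4,
            },
        ]));
        // 8 model layers in total: 4 "device" layers plus 4 host layers filled in by default.
        let mapper = setting.into_mapper(8, &Device::Cpu, None, &[Device::Cpu])?;
        assert_eq!(mapper.num_device_mapping_layers(), 8);
        assert!(mapper.device_for(0, false).unwrap().is_cpu());
        Ok(())
    }
}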

pub trait DeviceMapper: Debug {
    // === DURING RUNTIME ===
    /// Move `input` to the device mapped for `layer` during runtime.
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    // === DURING LOADING TIME ===
    /// If this is an ISQ layer, do not change the device; it is moved later in `NormalModel::quantize`.
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    /// If this is an ISQ layer, the non-mapped device is returned; the move happens later in `NormalModel::quantize`.
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    fn get_unique_devices(&self) -> Vec<Device>;
    /// If this is an ISQ layer, the tensor is moved to the CPU; the final device move happens later in `NormalModel::quantize`.
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    /// Set the non-mapped layer device. This is for ISQ + device mapping support.
    /// If this is an ISQ layer, do not change the device; it is moved later in `NormalModel::quantize`.
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    fn num_device_mapping_layers(&self) -> usize;
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

    // === IMMEDIATELY AFTER INIT ===
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}
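
// Illustrative sketch (not part of the original source): how a model forward pass typically
// consumes a `DeviceMapper`, moving the hidden states to each layer's device before running
// that layer. `run_layers_sketch` is a hypothetical helper used only for illustration.
#[allow(dead_code)]
fn run_layers_sketch(mapper: &dyn DeviceMapper, mut hidden: Tensor) -> Result<Tensor> {
    for layer_idx in 0..mapper.num_device_mapping_layers() {
        // Move the activations to the device that hosts this layer.
        hidden = mapper.map(hidden, layer_idx)?;
        // ... the layer's forward pass would run here on `hidden` ...
    }
    Ok(hidden)
}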

#[derive(Debug)]
/// A device mapper which does device mapping per hidden layer.
pub struct LayerDeviceMapper {
    mappings: Vec<Device>,
    nm_device: Device,
}

impl DeviceMapper for LayerDeviceMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer])
    }
    fn set_device<'a>(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
            if !acc.iter().any(|d| d.same_device(device)) {
                acc.push(device.clone());
            }
            acc
        })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}
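
// Illustrative example (not part of the original source): `LayerDeviceMapper` deduplicates
// devices via `Device::same_device`, so repeated layers on the same device collapse to one
// entry in `get_unique_devices`. The module/test names are arbitrary.
#[cfg(test)]
mod layer_device_mapper_example {
    use super::{DeviceMapper, LayerDeviceMapper};
    use candle_core::Device;

    #[test]
    fn unique_devices_are_deduplicated() {
        // Three CPU "layers" collapse to a single unique device.
        let mapper = LayerDeviceMapper {
            mappings: vec![Device::Cpu, Device::Cpu, Device::Cpu],
            nm_device: Device::Cpu,
        };
        assert_eq!(mapper.get_unique_devices().len(), 1);
        assert_eq!(mapper.num_device_mapping_layers(), 3);
    }
}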

#[derive(Debug)]
pub struct DummyDeviceMapper {
    pub(crate) nm_device: Device,
}

impl DeviceMapper for DummyDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device<'a>(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        // Effectively one layer
        1
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}

#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    model_layers: usize,
    comm: Option<Arc<mistralrs_quant::Comm>>,
}

impl DeviceMapper for NcclDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device<'a>(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.model_layers
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        if let Some(comm) = &self.comm {
            Ok(comm.clone())
        } else {
            let id = mistralrs_quant::Id::new();
            Ok(Arc::new(mistralrs_quant::Comm::from_device(
                id,
                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
                0,
                1,
            )?))
        }
    }
}

#[derive(Debug)]
#[allow(dead_code)]
/// A device mapper which does device mapping per hidden layer.
pub struct NcclPipelineParallelMapper {
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    nm_device: Device,
}

impl DeviceMapper for NcclPipelineParallelMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer].1)
    }
    fn set_device<'a>(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].1.clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer).map(|(_, x)| x)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings
            .iter()
            .fold(Vec::new(), |mut acc, (_, device)| {
                if !acc.iter().any(|d| d.same_device(device)) {
                    acc.push(device.clone());
                }
                acc
            })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device<'a>(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        Ok(self.mappings[layer_idx].0.clone())
    }
}

/// Get all devices of the same type as `base`, differing only in ordinal (the base device is included).
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                // Needs to be without a stream as PagedAttention doesn't like it otherwise.
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            #[cfg(feature = "metal")]
            let total_ords = candle_metal_kernels::metal::Device::all().len();
            #[cfg(not(feature = "metal"))]
            let total_ords = 0;
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}
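
// Illustrative example (not part of the original source): on the CPU, `get_all_similar_devices`
// simply returns the CPU device itself. The module/test names are arbitrary.
#[cfg(test)]
mod get_all_similar_devices_example {
    use super::get_all_similar_devices;
    use candle_core::Device;

    #[test]
    fn cpu_returns_itself() {
        let devices = get_all_similar_devices(&Device::Cpu).unwrap();
        assert_eq!(devices.len(), 1);
        assert!(devices[0].is_cpu());
    }
}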