use std::{fmt::Debug, sync::Arc};

use crate::{
    pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
};
use candle_core::{DType, Device, DeviceLocation, Result, Tensor};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;
use serde::{Deserialize, Serialize};
use tracing::info;
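/// The number of repeating layers to place on the device with the given ordinal.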
#[derive(Debug, Default, Deserialize, Serialize, Clone)]
pub struct DeviceLayerMapMetadata {
    pub ordinal: usize,
    pub layers: usize,
}
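/// How the model's layers should be mapped onto devices.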
#[derive(Debug, Clone)]
pub enum DeviceMapSetting {
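    /// A fixed, manually specified mapping.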
    Map(DeviceMapMetadata),
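    /// Automatic device mapping; convert to a `Map` via
    /// `DeviceMappedModelLoader::get_device_layers` before building a mapper.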
    Auto(AutoDeviceMapParams),
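    /// NCCL pipeline placeholder used before a communicator has been created.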
    DummyNccl { nm_device: Device },
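    /// NCCL pipeline with an established communicator.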
    Nccl {
        nm_device: Device,
        comm: Arc<mistralrs_quant::Comm>,
    },
}
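/// Metadata for a manual device mapping: per-device layer counts, plus an optional
/// number of layers kept on the host (CPU).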
#[derive(Debug, Default, Deserialize, Clone)]
pub struct DeviceMapMetadata {
    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    host_layers: Option<usize>,
}
impl DeviceMapMetadata {
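    /// Create a mapping from explicit per-device layer counts.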
    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
        Self {
            device_layers: Some(device_layers),
            host_layers: None,
        }
    }
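    /// A mapping with no explicit layer assignments; `into_mapper` resolves this to a
    /// `DummyDeviceMapper`.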
    pub fn dummy() -> Self {
        Self {
            device_layers: None,
            host_layers: None,
        }
    }
    pub fn device_layers(&self) -> Option<&[DeviceLayerMapMetadata]> {
        self.device_layers.as_deref()
    }

    pub fn host_layers(&self) -> Option<usize> {
        self.host_layers
    }
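    /// Render the mapping as a CLI-style spec: `ordinal:layers` pairs joined by `;`,
    /// e.g. `0:16;1:16`. Returns `None` if no device layers are set.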
    pub fn to_cli_spec(&self) -> Option<String> {
        let layers = self.device_layers.as_ref()?;
        if layers.is_empty() {
            return None;
        }
        Some(
            layers
                .iter()
                .map(|l| format!("{}:{}", l.ordinal, l.layers))
                .collect::<Vec<_>>()
                .join(";"),
        )
    }
}
impl DeviceMapSetting {
    pub fn dummy() -> Self {
        Self::Map(DeviceMapMetadata::dummy())
    }
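    /// Build the concrete `DeviceMapper` for a model with `model_layers` repeating layers.
    /// For the `Map` variant, a provided `Topology` takes precedence over the layer counts.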
    pub fn into_mapper(
        &self,
        model_layers: usize,
        device: &Device,
        topology: Option<&Topology>,
        all_devices: &[Device],
    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
        match self {
            Self::Nccl { nm_device, comm } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: Some(comm.clone()),
                }))
            }

            Self::DummyNccl { nm_device } => {
                once_log_info("Loading model using a NCCL-parallelized pipeline.");
                Ok(Box::new(NcclDeviceMapper {
                    nm_device: nm_device.clone(),
                    model_layers,
                    comm: None,
                }))
            }

            Self::Map(DeviceMapMetadata {
                device_layers,
                host_layers,
            }) => {
                if let Some(topology) = topology {
                    if topology.layers.iter().all(|x| x.is_none()) {
                        return Ok(Box::new(DummyDeviceMapper {
                            nm_device: device.clone(),
                        }));
                    } else {
                        let layers = topology
                            .layers
                            .iter()
                            .map(|layer| {
                                layer
                                    .as_ref()
                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
                                    .unwrap_or(device.clone())
                            })
                            .collect::<Vec<_>>();

                        info!("Loading model according to the following repeating layer mappings based on topology:");
                        for (i, dev) in layers.iter().enumerate() {
                            info!("Layer {i}: {}", dev.device_pretty_repr());
                        }

                        return Ok(Box::new(LayerDeviceMapper {
                            mappings: layers,
                            nm_device: device.clone(),
                        }));
                    }
                }

                let n_device_layers = if let Some(layers) = &device_layers {
                    layers
                        .iter()
                        .map(|metadata| metadata.layers)
                        .sum::<usize>()
                        .clamp(0, model_layers)
                } else {
                    return Ok(Box::new(DummyDeviceMapper {
                        nm_device: device.clone(),
                    }));
                };
                let n_host_layers =
                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
                if n_device_layers + n_host_layers != model_layers {
                    candle_core::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
                }
                once_log_info(format!("Model has {model_layers} repeating layers."));
                let mut combined = Vec::with_capacity(model_layers);
                if device_layers
                    .as_ref()
                    .is_some_and(|layers| layers.len() == 1)
                {
                    combined.extend(vec![device.clone(); n_device_layers]);
                } else {
                    let original_seed = if !device.is_cpu() {
                        Some(device.get_current_seed()?)
                    } else {
                        None
                    };
                    for DeviceLayerMapMetadata { ordinal, layers } in
                        device_layers.as_ref().unwrap()
                    {
                        let dev = match device.location() {
                            DeviceLocation::Cpu => Device::Cpu,
                            DeviceLocation::Cuda { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    let cuda_device = all_devices
                                        .iter()
                                        .filter(|d| d.is_cuda())
                                        .map(|d| {
                                            let ordinal = match d.location() {
                                                DeviceLocation::Cpu => 0,
                                                DeviceLocation::Cuda { gpu_id } => gpu_id,
                                                DeviceLocation::Metal { gpu_id } => gpu_id,
                                            };
                                            (d.clone(), ordinal)
                                        })
                                        .find(|(_, other_device_ordinal)| {
                                            other_device_ordinal == ordinal
                                        });

                                    if let Some((device, _)) = cuda_device {
                                        device
                                    } else {
                                        candle_core::bail!(
                                            "Could not find cuda device with ordinal {}",
                                            ordinal
                                        )
                                    }
                                }
                            }
                            DeviceLocation::Metal { gpu_id: device_ord } => {
                                if device_ord == *ordinal {
                                    device.clone()
                                } else {
                                    Device::new_metal(*ordinal)?
                                }
                            }
                        };
                        if !device.is_cpu() {
                            dev.set_seed(original_seed.unwrap())?;
                        }
                        combined.extend(vec![dev; *layers]);
                    }
                }

                combined.extend(vec![Device::Cpu; n_host_layers]);

                assert_eq!(combined.len(), model_layers);
                {
                    once_log_info(
                        "Loading model according to the following repeating layer mappings:",
                    );
                    let mut start_index = 0;
                    let mut current_dev = &combined[0];

                    for (i, variant) in combined.iter().enumerate().skip(1) {
                        if !variant.same_device(current_dev) {
                            once_log_info(format!(
                                "Layers {}-{}: {} ({} GB)",
                                start_index,
                                i - 1,
                                current_dev.device_pretty_repr(),
                                MemoryUsage
                                    .get_total_memory(current_dev)?
                                    .div_ceil(1024 * 1024 * 1024),
                            ));
                            start_index = i;
                            current_dev = variant;
                        }
                    }

                    once_log_info(format!(
                        "Layers {}-{}: {} ({} GB)",
                        start_index,
                        combined.len() - 1,
                        current_dev.device_pretty_repr(),
                        MemoryUsage
                            .get_total_memory(current_dev)?
                            .div_ceil(1024 * 1024 * 1024),
                    ));
                }
                Ok(Box::new(LayerDeviceMapper {
                    mappings: combined,
                    nm_device: device.clone(),
                }))
            }
            Self::Auto(_) => {
                candle_core::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
            }
        }
    }
}
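/// Moves activations and weights to the device responsible for each repeating layer.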
pub trait DeviceMapper: Debug {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;

    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder;
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
    fn get_unique_devices(&self) -> Vec<Device>;
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
    fn num_device_mapping_layers(&self) -> usize;
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>>;

    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
}
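/// A mapper backed by an explicit per-layer device list.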
#[derive(Debug)]
pub struct LayerDeviceMapper {
    mappings: Vec<Device>,
    nm_device: Device,
}
impl DeviceMapper for LayerDeviceMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer])
    }
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
            if !acc.iter().any(|d| d.same_device(device)) {
                acc.push(device.clone());
            }
            acc
        })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}
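/// A no-op mapper that keeps every layer on a single device.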
#[derive(Debug)]
pub struct DummyDeviceMapper {
    pub(crate) nm_device: Device,
}
impl DeviceMapper for DummyDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        1
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        let id = mistralrs_quant::Id::new();
        Ok(Arc::new(mistralrs_quant::Comm::from_device(
            id,
            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
            0,
            1,
        )?))
    }
}
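/// Mapper for NCCL tensor parallelism: all layers stay on `nm_device` and share one
/// communicator, if present.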
#[derive(Debug)]
pub struct NcclDeviceMapper {
    nm_device: Device,
    model_layers: usize,
    comm: Option<Arc<mistralrs_quant::Comm>>,
}
impl DeviceMapper for NcclDeviceMapper {
    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
        Ok(input)
    }
    fn set_device(
        &self,
        _: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
        Some(&self.nm_device)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        vec![self.nm_device.clone()]
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder.set_device(Device::Cpu)
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&[&self.nm_device])
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.model_layers
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        if let Some(comm) = &self.comm {
            Ok(comm.clone())
        } else {
            let id = mistralrs_quant::Id::new();
            Ok(Arc::new(mistralrs_quant::Comm::from_device(
                id,
                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
                0,
                1,
            )?))
        }
    }
}
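/// Mapper for NCCL pipeline parallelism: each layer carries its own communicator and device.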
#[derive(Debug)]
#[allow(dead_code)]
pub struct NcclPipelineParallelMapper {
    mappings: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
    nm_device: Device,
}
impl DeviceMapper for NcclPipelineParallelMapper {
    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
        input.to_device(&self.mappings[layer].1)
    }
    fn set_device(
        &self,
        layer: usize,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            return varbuilder;
        }
        varbuilder.set_device(self.mappings[layer].1.clone())
    }
    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
        if loading_isq {
            return Some(&self.nm_device);
        }
        self.mappings.get(layer).map(|(_, x)| x)
    }
    fn get_unique_devices(&self) -> Vec<Device> {
        self.mappings
            .iter()
            .fold(Vec::new(), |mut acc, (_, device)| {
                if !acc.iter().any(|d| d.same_device(device)) {
                    acc.push(device.clone());
                }
                acc
            })
    }
    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
        if loading_isq {
            x.to_device(&Device::Cpu)
        } else {
            x.to_device(&self.nm_device)
        }
    }
    fn set_nm_device(
        &self,
        varbuilder: ShardedVarBuilder,
        loading_isq: bool,
    ) -> ShardedVarBuilder {
        if loading_isq {
            varbuilder
        } else {
            varbuilder.set_device(self.nm_device.clone())
        }
    }
    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
        dtype
            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
            .map_err(candle_core::Error::msg)
    }
    fn num_device_mapping_layers(&self) -> usize {
        self.mappings.len()
    }
    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
        Ok(self.mappings[layer_idx].0.clone())
    }
}
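/// Enumerate all devices on the same backend as `base` (CPU, CUDA, or Metal), reusing
/// `base` itself for its own ordinal.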
pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
    let mut devices = Vec::new();
    match base {
        Device::Cpu => return Ok(vec![Device::Cpu]),
        Device::Cuda(_) => {
            let mut ord = 0;
            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if let Ok(dev) = Device::new_cuda(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
        #[cfg(not(feature = "metal"))]
        Device::Metal(_) => {
            candle_core::bail!("Not compiled with metal features, but have a metal device.");
        }
        #[cfg(feature = "metal")]
        Device::Metal(_) => {
            // This arm is already gated on the metal feature, so the nested cfg
            // attributes that previously guarded this binding were redundant.
            let total_ords = candle_metal_kernels::metal::Device::all().len();
            let mut ord = 0;
            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
                candle_core::bail!("location and device do not match");
            };
            loop {
                if base_ord == ord {
                    devices.push(base.clone());
                    ord += 1;
                    continue;
                }
                if total_ords == ord {
                    break;
                }
                if let Ok(dev) = Device::new_metal(ord) {
                    devices.push(dev);
                    ord += 1;
                } else {
                    break;
                }
            }
        }
    }
    Ok(devices)
}