Skip to main content

hanzo_engine/
device_map.rs

1use std::{collections::HashMap, fmt::Debug, sync::Arc};
2
3use crate::{
4    pipeline::AutoDeviceMapParams, utils::debug::DeviceRepr, MemoryUsage, Topology, TryIntoDType,
5};
6use hanzo_ml::{DType, Device, DeviceLocation, Result, Tensor};
7use hanzo_quant::log::once_log_info;
8use hanzo_quant::ShardedVarBuilder;
9use serde::{Deserialize, Serialize};
10use tracing::info;
11
12#[derive(Debug, Default, Deserialize, Serialize, Clone)]
13pub struct DeviceLayerMapMetadata {
14    pub ordinal: usize,
15    pub layers: usize,
16}
17
18#[derive(Debug, Clone)]
19pub enum DeviceMapSetting {
20    /// Manual device mapping.
21    Map(DeviceMapMetadata),
22    /// Automatic device mapping (recommended).
23    Auto(AutoDeviceMapParams),
24    /// Dummy device mapping for a NCCL pipeline
25    DummyNccl { nm_device: Device },
26    /// Real device mapping for a NCCL pipeline
27    Nccl {
28        nm_device: Device,
29        comm: Arc<hanzo_quant::Comm>,
30    },
31}
32
33#[derive(Debug, Default, Deserialize, Clone)]
34/// Metadata to initialize the device mapper.
35pub struct DeviceMapMetadata {
36    device_layers: Option<Vec<DeviceLayerMapMetadata>>,
37    host_layers: Option<usize>,
38}
39
40impl DeviceMapMetadata {
41    pub fn from_num_device_layers(device_layers: Vec<DeviceLayerMapMetadata>) -> Self {
42        Self {
43            device_layers: Some(device_layers),
44            host_layers: None,
45        }
46    }
47    /// A device mapper to not map device.
48    pub fn dummy() -> Self {
49        Self {
50            device_layers: None,
51            host_layers: None,
52        }
53    }
54
55    pub fn device_layers(&self) -> Option<&[DeviceLayerMapMetadata]> {
56        self.device_layers.as_deref()
57    }
58
59    pub fn host_layers(&self) -> Option<usize> {
60        self.host_layers
61    }
62
63    pub fn to_cli_spec(&self) -> Option<String> {
64        let layers = self.device_layers.as_ref()?;
65        if layers.is_empty() {
66            return None;
67        }
68        Some(
69            layers
70                .iter()
71                .map(|l| format!("{}:{}", l.ordinal, l.layers))
72                .collect::<Vec<_>>()
73                .join(";"),
74        )
75    }
76}
77
78impl DeviceMapSetting {
79    /// A device mapper to not map device.
80    pub fn dummy() -> Self {
81        Self::Map(DeviceMapMetadata::dummy())
82    }
83    pub fn into_mapper(
84        &self,
85        model_layers: usize,
86        device: &Device,
87        topology: Option<&Topology>,
88        all_devices: &[Device],
89    ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
90        match self {
91            Self::Nccl { nm_device, comm } => {
92                once_log_info("Loading model using a NCCL-parallelized pipeline.");
93                Ok(Box::new(NcclDeviceMapper {
94                    nm_device: nm_device.clone(),
95                    model_layers,
96                    comm: Some(comm.clone()),
97                }))
98            }
99
100            Self::DummyNccl { nm_device } => {
101                once_log_info("Loading model using a NCCL-parallelized pipeline.");
102                Ok(Box::new(NcclDeviceMapper {
103                    nm_device: nm_device.clone(),
104                    model_layers,
105                    comm: None,
106                }))
107            }
108
109            Self::Map(DeviceMapMetadata {
110                device_layers,
111                host_layers,
112            }) => {
113                if let Some(topology) = topology {
114                    if topology.layers.iter().all(|x| x.is_none()) {
115                        return Ok(Box::new(DummyDeviceMapper {
116                            nm_device: device.clone(),
117                        }));
118                    } else {
119                        let layers = topology
120                            .layers
121                            .iter()
122                            .map(|layer| {
123                                layer
124                                    .as_ref()
125                                    .map(|x| x.device.clone().unwrap_or(device.clone()))
126                                    .unwrap_or(device.clone())
127                            })
128                            .collect::<Vec<_>>();
129
130                        info!("Loading model according to the following repeating layer mappings based on topology:");
131                        for (i, dev) in layers.iter().enumerate() {
132                            info!("Layer {i}: {}", dev.device_pretty_repr());
133                        }
134
135                        return Ok(Box::new(LayerDeviceMapper {
136                            mappings: layers,
137                            nm_device: device.clone(),
138                        }));
139                    }
140                }
141
142                // How many device layers
143                // Clamp to max of model layers
144                let n_device_layers = if let Some(layers) = &device_layers {
145                    layers
146                        .iter()
147                        .map(|metadata| metadata.layers)
148                        .sum::<usize>()
149                        .clamp(0, model_layers)
150                } else {
151                    return Ok(Box::new(DummyDeviceMapper {
152                        nm_device: device.clone(),
153                    }));
154                };
155                // How many host (cpu) layers, defaulting to automatically filling the rest.
156                // If n_device_layers > model_layers, n_host_layers = 0
157                let n_host_layers =
158                    host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
159                if n_device_layers + n_host_layers != model_layers {
160                    hanzo_ml::bail!("Expected the total number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
161                }
162                once_log_info(format!("Model has {model_layers} repeating layers."));
163
164                // Handle multi-GPU mapping here
165                let mut combined = Vec::with_capacity(model_layers);
166                if device_layers
167                    .as_ref()
168                    .is_some_and(|layers| layers.len() == 1)
169                {
170                    combined.extend(vec![device.clone(); n_device_layers]);
171                } else {
172                    let original_seed = if !device.is_cpu() {
173                        Some(device.get_current_seed()?)
174                    } else {
175                        None
176                    };
177                    for DeviceLayerMapMetadata { ordinal, layers } in
178                        device_layers.as_ref().unwrap()
179                    {
180                        let dev = match device.location() {
181                            DeviceLocation::Cpu => Device::Cpu,
182                            #[cfg(feature = "rocm")]
183                            DeviceLocation::Rocm { .. } => device.clone(),
184                            #[cfg(feature = "vulkan")]
185                            DeviceLocation::Vulkan { .. } => device.clone(),
186                            DeviceLocation::Cuda { gpu_id: device_ord } => {
187                                if device_ord == *ordinal {
188                                    device.clone()
189                                } else {
190                                    let cuda_device = all_devices
191                                        .iter()
192                                        .filter(|d| d.is_cuda())
193                                        .map(|d| {
194                                            // should implement this in hanzo-ml and get the ordinal back from the device location directly
195                                            let ordinal = match d.location() {
196                                                DeviceLocation::Cpu => 0,
197                                                DeviceLocation::Cuda { gpu_id } => gpu_id,
198                                                DeviceLocation::Metal { gpu_id } => gpu_id,
199                                                #[cfg(feature = "rocm")]
200                                                DeviceLocation::Rocm { gpu_id } => gpu_id,
201                                                #[cfg(feature = "vulkan")]
202                                                DeviceLocation::Vulkan { gpu_id } => gpu_id,
203                                            };
204                                            (d.clone(), ordinal)
205                                        })
206                                        .find(|(_, other_device_ordinal)| {
207                                            other_device_ordinal == ordinal
208                                        });
209
210                                    if let Some((device, _)) = cuda_device {
211                                        device
212                                    } else {
213                                        hanzo_ml::bail!(
214                                            "Could not find cuda device with ordinal {}",
215                                            ordinal
216                                        )
217                                    }
218                                }
219                            }
220                            DeviceLocation::Metal { gpu_id: device_ord } => {
221                                if device_ord == *ordinal {
222                                    device.clone()
223                                } else {
224                                    Device::new_metal(*ordinal)?
225                                }
226                            }
227                        };
228                        if !device.is_cpu() {
229                            dev.set_seed(original_seed.unwrap())?;
230                        }
231                        combined.extend(vec![dev; *layers]);
232                    }
233                }
234
235                // Always put the CPU layers at the end so that we reduce dtoh and htod copies
236                combined.extend(vec![Device::Cpu; n_host_layers]);
237
238                // Sanity
239                assert_eq!(combined.len(), model_layers);
240
241                // Print it out
242                {
243                    once_log_info(
244                        "Loading model according to the following repeating layer mappings:",
245                    );
246                    let mut start_index = 0;
247                    let mut current_dev = &combined[0];
248
249                    // Iterate starting from index 1 to detect when the variant changes
250                    for (i, variant) in combined.iter().enumerate().skip(1) {
251                        // If the variant changes, print the previous continuous block
252                        if !variant.same_device(current_dev) {
253                            once_log_info(format!(
254                                "Layers {}-{}: {} ({} GB)",
255                                start_index,
256                                i - 1,
257                                current_dev.device_pretty_repr(),
258                                MemoryUsage
259                                    .query(current_dev)?
260                                    .total()
261                                    .div_ceil(1024 * 1024 * 1024),
262                            ));
263                            start_index = i; // start a new range
264                            current_dev = variant;
265                        }
266                    }
267
268                    once_log_info(format!(
269                        "Layers {}-{}: {} ({} GB)",
270                        start_index,
271                        combined.len() - 1,
272                        current_dev.device_pretty_repr(),
273                        MemoryUsage
274                            .query(current_dev)?
275                            .total()
276                            .div_ceil(1024 * 1024 * 1024),
277                    ));
278                }
279
280                Ok(Box::new(LayerDeviceMapper {
281                    mappings: combined,
282                    nm_device: device.clone(),
283                }))
284            }
285            Self::Auto(_) => {
286                hanzo_ml::bail!(".into_mapper does not work on Auto device map, convert it to a Map with DeviceMappedModelLoader::get_device_layers")
287            }
288        }
289    }
290}
291
292pub trait DeviceMapper: Debug {
293    // === DURING RUNTIME ===
294    /// Map during runtime
295    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor>;
296
297    // === DURING LOADING TIME ===
298    /// If ISQ layer, then do not change the device. *They will do it later in NormalModel::quantize*
299    fn set_device(
300        &self,
301        layer: usize,
302        varbuilder: ShardedVarBuilder,
303        loading_isq: bool,
304    ) -> ShardedVarBuilder;
305    /// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
306    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device>;
307    fn get_unique_devices(&self) -> Vec<Device>;
308    /// If ISQ layer, then do not change the device (return None). *They will do it later in NormalModel::quantize*
309    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor>;
310    /// Set non mapped layer device. This is for ISQ + device mapping support
311    /// If ISQ layer, then do not change the device. *They will do it later in NormalModel::quantize*
312    fn set_nm_device(&self, varbuilder: ShardedVarBuilder, loading_isq: bool) -> ShardedVarBuilder;
313    fn num_device_mapping_layers(&self) -> usize;
314    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>>;
315
316    // === IMMEDIATELY AFTER INIT ===
317    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType>;
318}
319
320#[derive(Debug)]
321/// A device mapper which does device mapping per hidden layer.
322pub struct LayerDeviceMapper {
323    mappings: Vec<Device>,
324    nm_device: Device,
325}
326
327impl DeviceMapper for LayerDeviceMapper {
328    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
329        input.to_device(&self.mappings[layer])
330    }
331    fn set_device<'a>(
332        &self,
333        layer: usize,
334        varbuilder: ShardedVarBuilder,
335        loading_isq: bool,
336    ) -> ShardedVarBuilder {
337        if loading_isq {
338            return varbuilder;
339        }
340        varbuilder.set_device(self.mappings[layer].clone())
341    }
342    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
343        if loading_isq {
344            return Some(&self.nm_device);
345        }
346        self.mappings.get(layer)
347    }
348    fn get_unique_devices(&self) -> Vec<Device> {
349        self.mappings.iter().fold(Vec::new(), |mut acc, device| {
350            if !acc.iter().any(|d| d.same_device(device)) {
351                acc.push(device.clone());
352            }
353            acc
354        })
355    }
356    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
357        if loading_isq {
358            x.to_device(&Device::Cpu)
359        } else {
360            x.to_device(&self.nm_device)
361        }
362    }
363    fn set_nm_device<'a>(
364        &self,
365        varbuilder: ShardedVarBuilder,
366        loading_isq: bool,
367    ) -> ShardedVarBuilder {
368        if loading_isq {
369            varbuilder
370        } else {
371            varbuilder.set_device(self.nm_device.clone())
372        }
373    }
374    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
375        dtype
376            .try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
377            .map_err(hanzo_ml::Error::msg)
378    }
379    fn num_device_mapping_layers(&self) -> usize {
380        self.mappings.len()
381    }
382    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
383        let id = hanzo_quant::Id::new();
384        Ok(Arc::new(hanzo_quant::Comm::from_device(
385            id,
386            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
387            0,
388            1,
389        )?))
390    }
391}
392
393#[derive(Debug)]
394pub struct DummyDeviceMapper {
395    pub(crate) nm_device: Device,
396}
397
398impl DeviceMapper for DummyDeviceMapper {
399    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
400        Ok(input)
401    }
402    fn set_device<'a>(
403        &self,
404        _: usize,
405        varbuilder: ShardedVarBuilder,
406        loading_isq: bool,
407    ) -> ShardedVarBuilder {
408        if loading_isq {
409            varbuilder.set_device(Device::Cpu)
410        } else {
411            varbuilder.set_device(self.nm_device.clone())
412        }
413    }
414    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
415        Some(&self.nm_device)
416    }
417    fn get_unique_devices(&self) -> Vec<Device> {
418        vec![self.nm_device.clone()]
419    }
420    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
421        if loading_isq {
422            x.to_device(&Device::Cpu)
423        } else {
424            x.to_device(&self.nm_device)
425        }
426    }
427    fn set_nm_device<'a>(
428        &self,
429        varbuilder: ShardedVarBuilder,
430        loading_isq: bool,
431    ) -> ShardedVarBuilder {
432        if loading_isq {
433            varbuilder.set_device(Device::Cpu)
434        } else {
435            varbuilder.set_device(self.nm_device.clone())
436        }
437    }
438    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
439        dtype
440            .try_into_dtype(&[&self.nm_device])
441            .map_err(hanzo_ml::Error::msg)
442    }
443    fn num_device_mapping_layers(&self) -> usize {
444        // Effectively one layer
445        1
446    }
447    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
448        let id = hanzo_quant::Id::new();
449        Ok(Arc::new(hanzo_quant::Comm::from_device(
450            id,
451            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
452            0,
453            1,
454        )?))
455    }
456}
457
458#[derive(Debug)]
459pub struct NcclDeviceMapper {
460    nm_device: Device,
461    model_layers: usize,
462    comm: Option<Arc<hanzo_quant::Comm>>,
463}
464
465impl DeviceMapper for NcclDeviceMapper {
466    fn map(&self, input: Tensor, _: usize) -> Result<Tensor> {
467        Ok(input)
468    }
469    fn set_device<'a>(
470        &self,
471        _: usize,
472        varbuilder: ShardedVarBuilder,
473        loading_isq: bool,
474    ) -> ShardedVarBuilder {
475        if loading_isq {
476            varbuilder.set_device(Device::Cpu)
477        } else {
478            varbuilder.set_device(self.nm_device.clone())
479        }
480    }
481    fn device_for(&self, _: usize, _loading_isq: bool) -> Option<&Device> {
482        Some(&self.nm_device)
483    }
484    fn get_unique_devices(&self) -> Vec<Device> {
485        vec![self.nm_device.clone()]
486    }
487    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
488        if loading_isq {
489            x.to_device(&Device::Cpu)
490        } else {
491            x.to_device(&self.nm_device)
492        }
493    }
494    fn set_nm_device<'a>(
495        &self,
496        varbuilder: ShardedVarBuilder,
497        loading_isq: bool,
498    ) -> ShardedVarBuilder {
499        if loading_isq {
500            varbuilder.set_device(Device::Cpu)
501        } else {
502            varbuilder.set_device(self.nm_device.clone())
503        }
504    }
505    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
506        dtype
507            .try_into_dtype(&[&self.nm_device])
508            .map_err(hanzo_ml::Error::msg)
509    }
510    fn num_device_mapping_layers(&self) -> usize {
511        self.model_layers
512    }
513    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
514        if let Some(comm) = &self.comm {
515            Ok(comm.clone())
516        } else {
517            let id = hanzo_quant::Id::new();
518            Ok(Arc::new(hanzo_quant::Comm::from_device(
519                id,
520                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
521                0,
522                1,
523            )?))
524        }
525    }
526}
527
528#[derive(Debug)]
529#[allow(dead_code)]
530/// A device mapper which does device mapping per hidden layer.
531pub struct NcclPipelineParallelMapper {
532    mappings: Vec<(Arc<hanzo_quant::Comm>, Device)>,
533    nm_device: Device,
534}
535
536impl DeviceMapper for NcclPipelineParallelMapper {
537    fn map(&self, input: Tensor, layer: usize) -> Result<Tensor> {
538        input.to_device(&self.mappings[layer].1)
539    }
540    fn set_device<'a>(
541        &self,
542        layer: usize,
543        varbuilder: ShardedVarBuilder,
544        loading_isq: bool,
545    ) -> ShardedVarBuilder {
546        if loading_isq {
547            return varbuilder;
548        }
549        varbuilder.set_device(self.mappings[layer].1.clone())
550    }
551    fn device_for(&self, layer: usize, loading_isq: bool) -> Option<&Device> {
552        if loading_isq {
553            return Some(&self.nm_device);
554        }
555        self.mappings.get(layer).map(|(_, x)| x)
556    }
557    fn get_unique_devices(&self) -> Vec<Device> {
558        self.mappings
559            .iter()
560            .fold(Vec::new(), |mut acc, (_, device)| {
561                if !acc.iter().any(|d| d.same_device(device)) {
562                    acc.push(device.clone());
563                }
564                acc
565            })
566    }
567    fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
568        if loading_isq {
569            x.to_device(&Device::Cpu)
570        } else {
571            x.to_device(&self.nm_device)
572        }
573    }
574    fn set_nm_device<'a>(
575        &self,
576        varbuilder: ShardedVarBuilder,
577        loading_isq: bool,
578    ) -> ShardedVarBuilder {
579        if loading_isq {
580            varbuilder
581        } else {
582            varbuilder.set_device(self.nm_device.clone())
583        }
584    }
585    fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
586        dtype
587            .try_into_dtype(&self.mappings.iter().map(|(_, x)| x).collect::<Vec<_>>())
588            .map_err(hanzo_ml::Error::msg)
589    }
590    fn num_device_mapping_layers(&self) -> usize {
591        self.mappings.len()
592    }
593    fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<hanzo_quant::Comm>> {
594        Ok(self.mappings[layer_idx].0.clone())
595    }
596}
597
598/// Pre-creates one copy of an attention mask per unique device used by a `DeviceMapper`.
599///
600/// Instead of calling `mask.to_device(xs.device())` inside every layer loop iteration
601/// (which allocates new GPU storage each time when src != dst device), create a
602/// `DeviceMappedMask` once before the loop and call `.get(device)` inside the loop
603/// for zero-allocation mask lookup.
604pub enum DeviceMappedMask {
605    /// No masking.
606    None,
607    /// Flash attention handles causality. No tensor needed.
608    CausalFlash,
609    /// Explicit mask tensor, replicated to each device.
610    Custom(HashMap<DeviceLocation, Tensor>),
611}
612
613impl DeviceMappedMask {
614    /// Build a device-mapped mask from an [`AttentionMask`].
615    pub fn new(mask: crate::attention::AttentionMask, mapper: &dyn DeviceMapper) -> Result<Self> {
616        match mask {
617            crate::attention::AttentionMask::None => Ok(Self::None),
618            crate::attention::AttentionMask::CausalFlash => Ok(Self::CausalFlash),
619            crate::attention::AttentionMask::Custom(tensor) => {
620                let mut masks = HashMap::new();
621                for device in mapper.get_unique_devices() {
622                    let loc = device.location();
623                    if let std::collections::hash_map::Entry::Vacant(e) = masks.entry(loc) {
624                        e.insert(tensor.to_device(&device)?);
625                    }
626                }
627                Ok(Self::Custom(masks))
628            }
629        }
630    }
631
632    /// Build a device-mapped mask from a single tensor on its current device.
633    pub fn from_single(mask: crate::attention::AttentionMask) -> Self {
634        match mask {
635            crate::attention::AttentionMask::None => Self::None,
636            crate::attention::AttentionMask::CausalFlash => Self::CausalFlash,
637            crate::attention::AttentionMask::Custom(tensor) => {
638                let mut masks = HashMap::new();
639                masks.insert(tensor.device().location(), tensor);
640                Self::Custom(masks)
641            }
642        }
643    }
644
645    /// Look up the [`AttentionMask`] for the given device.
646    pub fn get(&self, device: &Device) -> crate::attention::AttentionMask {
647        match self {
648            Self::None => crate::attention::AttentionMask::None,
649            Self::CausalFlash => crate::attention::AttentionMask::CausalFlash,
650            Self::Custom(masks) => {
651                let tensor = masks
652                    .get(&device.location())
653                    .expect("DeviceMappedMask: device not in mapper's unique devices");
654                crate::attention::AttentionMask::Custom(tensor.clone())
655            }
656        }
657    }
658}
659
660/// Get all devices on the same device type but different ordinals
661pub fn get_all_similar_devices(base: &Device) -> Result<Vec<Device>> {
662    let mut devices = Vec::new();
663    match base {
664        Device::Cpu => return Ok(vec![Device::Cpu]),
665        #[cfg(feature = "rocm")]
666        Device::Rocm(_) => return Ok(vec![base.clone()]),
667        #[cfg(feature = "vulkan")]
668        Device::Vulkan(_) => return Ok(vec![base.clone()]),
669        Device::Cuda(_) => {
670            let mut ord = 0;
671            let DeviceLocation::Cuda { gpu_id: base_ord } = base.location() else {
672                hanzo_ml::bail!("location and device do not match");
673            };
674            loop {
675                if base_ord == ord {
676                    devices.push(base.clone());
677                    ord += 1;
678                    continue;
679                }
680                // Needs to be without a stream as PagedAttention doesn't like it otherwise.
681                if let Ok(dev) = Device::new_cuda(ord) {
682                    devices.push(dev);
683                    ord += 1;
684                } else {
685                    break;
686                }
687            }
688        }
689        #[cfg(not(feature = "metal"))]
690        Device::Metal(_) => {
691            hanzo_ml::bail!("Not compiled with metal features, but have a metal device.");
692        }
693        #[cfg(feature = "metal")]
694        Device::Metal(_) => {
695            #[cfg(feature = "metal")]
696            let total_ords = hanzo_metal_kernels::metal::Device::all().len();
697            #[cfg(not(feature = "metal"))]
698            let total_ords = 0;
699            let mut ord = 0;
700            let DeviceLocation::Metal { gpu_id: base_ord } = base.location() else {
701                hanzo_ml::bail!("location and device do not match");
702            };
703            loop {
704                if base_ord == ord {
705                    devices.push(base.clone());
706                    ord += 1;
707                    continue;
708                }
709                if total_ords == ord {
710                    break;
711                }
712                if let Ok(dev) = Device::new_metal(ord) {
713                    devices.push(dev);
714                    ord += 1;
715                } else {
716                    break;
717                }
718            }
719        }
720    }
721    Ok(devices)
722}