burn_mpsgraph/device.rs
1use burn_backend::{DeviceId, DeviceOps};
2
/// Handle identifying a device of the MPSGraph backend (Apple GPU).
///
/// The handle is a plain `Copy` value; its `Default` is index `0`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MpsGraphDevice {
    /// Index of the device this handle refers to.
    pub index: usize,
}
8
9impl DeviceOps for MpsGraphDevice {}
10
11impl burn_backend::Device for MpsGraphDevice {
12 fn from_id(id: DeviceId) -> Self { Self { index: id.index_id as usize } }
13 fn to_id(&self) -> DeviceId { DeviceId { type_id: 1, index_id: self.index as u32 } }
14 fn device_count(_type_id: u16) -> usize { 1 }
15}