// burn_mpsgraph/device.rs

use burn_backend::{DeviceId, DeviceOps};
2
/// Handle identifying an Apple GPU for the MPSGraph backend.
///
/// The handle is a plain `Copy` value carrying a numeric device index;
/// `MpsGraphDevice::default()` has `index == 0`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MpsGraphDevice {
    /// Numeric index of the GPU this handle refers to.
    pub index: usize,
}
8
9impl DeviceOps for MpsGraphDevice {}
10
11impl burn_backend::Device for MpsGraphDevice {
12    fn from_id(id: DeviceId) -> Self { Self { index: id.index_id as usize } }
13    fn to_id(&self) -> DeviceId { DeviceId { type_id: 1, index_id: self.index as u32 } }
14    fn device_count(_type_id: u16) -> usize { 1 }
15}