burn-mpsgraph 0.0.1

Apple MPSGraph backend for the Burn deep learning framework
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
use burn_backend::{DeviceId, DeviceOps};

/// Handle naming one Apple GPU used by the MPSGraph backend.
///
/// The handle is a plain `Copy` value wrapping a device ordinal.
/// `MpsGraphDevice::default()` yields ordinal `0`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MpsGraphDevice {
    /// Device ordinal. The `Device` impl maps it to and from `DeviceId::index_id`.
    pub index: usize,
}

// Empty impl: nothing is overridden here, so any provided methods of
// `DeviceOps` keep their trait defaults for `MpsGraphDevice`.
impl DeviceOps for MpsGraphDevice {}

/// Type tag stored in `DeviceId::type_id` for every MPSGraph device.
const MPS_GRAPH_TYPE_ID: u16 = 1;

impl burn_backend::Device for MpsGraphDevice {
    /// Rebuilds a device handle from a [`DeviceId`].
    ///
    /// Only `index_id` is read. `type_id` is not checked.
    fn from_id(id: DeviceId) -> Self {
        Self { index: id.index_id as usize }
    }

    /// Encodes this device as a [`DeviceId`] tagged with `MPS_GRAPH_TYPE_ID`.
    ///
    /// # Panics
    ///
    /// Panics if `index` does not fit in the `u32` `index_id` field.
    fn to_id(&self) -> DeviceId {
        // A plain `as u32` cast would silently truncate large indices and
        // yield an id that refers to a different device; fail loudly instead.
        let index_id = u32::try_from(self.index)
            .expect("MpsGraphDevice index does not fit in DeviceId::index_id (u32)");
        DeviceId { type_id: MPS_GRAPH_TYPE_ID, index_id }
    }

    /// Number of devices available for `_type_id`.
    ///
    /// Always returns `1`, whatever the argument.
    fn device_count(_type_id: u16) -> usize {
        1
    }
}