use metal::Buffer;
use std::fmt;
/// Errors reported by the Metal graph backend.
///
/// Variants that carry a `String` hold a free-form detail message, which is
/// appended after a fixed prefix when the error is displayed.
#[derive(Debug)]
pub enum MetalGraphError {
    /// No Metal-capable GPU device could be found.
    DeviceNotFound,
    /// MSL compilation failed; the payload is the detail message.
    CompilationFailed(String),
    /// A Metal buffer could not be allocated.
    BufferCreationFailed,
    /// Metal encoding failed; the payload is the detail message.
    EncodingFailed(String),
    /// Metal execution failed; the payload is the detail message.
    ExecutionFailed(String),
}

impl fmt::Display for MetalGraphError {
    /// Writes a human-readable message for the error.
    ///
    /// Unit variants emit a fixed sentence. Payload variants emit a fixed
    /// prefix followed by their detail message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Fixed messages: no formatting machinery needed.
            MetalGraphError::DeviceNotFound => f.write_str("no Metal-capable GPU device found"),
            MetalGraphError::BufferCreationFailed => f.write_str("Metal buffer allocation failed"),
            // Detail-carrying variants: prefix + payload.
            MetalGraphError::CompilationFailed(msg) => write!(f, "MSL compilation failed: {msg}"),
            MetalGraphError::EncodingFailed(msg) => write!(f, "Metal encoding failed: {msg}"),
            MetalGraphError::ExecutionFailed(msg) => write!(f, "Metal execution failed: {msg}"),
        }
    }
}

/// Relies entirely on the trait's default methods. No variant wraps another
/// error value, so `source()` always yields `None`.
impl std::error::Error for MetalGraphError {}
/// Handle to weight data held in a Metal [`Buffer`].
///
/// The buffer itself is crate-private. Outside the crate, only the recorded
/// byte length is observable, through [`MetalWeightHandle::byte_len`].
pub struct MetalWeightHandle {
    /// Backing Metal buffer (crate-private).
    pub(crate) buffer: Buffer,
    /// Byte length recorded for this handle.
    ///
    /// NOTE(review): this is stored separately from `buffer`. It is presumably
    /// the logical size of the weight data, which may differ from the buffer's
    /// allocated length; confirm against the constructor.
    pub(crate) byte_len: usize,
}

impl MetalWeightHandle {
    /// Returns the byte length recorded for this handle.
    #[must_use]
    pub fn byte_len(&self) -> usize {
        self.byte_len
    }
}

impl fmt::Debug for MetalWeightHandle {
    /// Formats the handle for diagnostics.
    ///
    /// Only `byte_len` is printed; the `buffer` field is intentionally omitted.
    /// `finish_non_exhaustive` appends a trailing `..`, so the output does not
    /// falsely suggest that `byte_len` is the struct's only field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MetalWeightHandle")
            .field("byte_len", &self.byte_len)
            .finish_non_exhaustive()
    }
}