use metal::foreign_types::ForeignType;
use metal::{CommandBuffer, Device};
use objc2::msg_send;
use objc2::runtime::AnyObject;
use crate::metal::{buffer::MetalBuffer, error::Result};
#[allow(dead_code)]
pub struct MPSBatchNorm {
batch_norm: *mut AnyObject,
}
impl MPSBatchNorm {
pub fn new(
device: &Device,
_num_features: usize,
_epsilon: f32,
_momentum: f32,
) -> Result<Self> {
unsafe {
let class = objc2::class!(MPSCNNBatchNormalization);
let batch_norm: *mut AnyObject = msg_send![class, alloc];
let batch_norm: *mut AnyObject = msg_send![batch_norm,
initWithDevice: device.as_ptr() as *mut AnyObject,
dataSource: std::ptr::null::<AnyObject>() ];
Ok(Self { batch_norm })
}
}
pub fn apply(
&self,
_command_buffer: &CommandBuffer,
_input: &MetalBuffer,
_output: &MetalBuffer,
_training: bool,
) -> Result<()> {
Ok(())
}
}
impl Drop for MPSBatchNorm {
fn drop(&mut self) {
unsafe {
let _: () = msg_send![self.batch_norm, release];
}
}
}
#[allow(dead_code)]
pub struct MPSLayerNorm {
kernel_manager: crate::metal::kernels::KernelManager,
normalized_shape: Vec<usize>,
epsilon: f32,
}
#[allow(dead_code)]
impl MPSLayerNorm {
pub fn new(device: &Device, normalized_shape: Vec<usize>, epsilon: f32) -> Result<Self> {
let kernel_manager = crate::metal::kernels::KernelManager::new(device)?;
Ok(Self {
kernel_manager,
normalized_shape,
epsilon,
})
}
pub fn apply(
&self,
_command_buffer: &CommandBuffer,
_input: &MetalBuffer,
_output: &MetalBuffer,
_gamma: Option<&MetalBuffer>,
_beta: Option<&MetalBuffer>,
) -> Result<()> {
Ok(())
}
}