use crate::error::Result;
use pmetal_bridge::compat::{Array, Param, ops};
/// Snake activation, `x + sin²(α·x) / α`, with one learnable `α` per channel.
#[derive(Debug)]
pub struct Snake {
    /// Per-channel `α`, created with shape `[1, channels, 1]`. Holds `ln(α)` when
    /// `alpha_logscale` is set.
    pub alpha: Param<Array>,
    /// Whether `alpha` is stored in log space and must be exponentiated before use.
    pub alpha_logscale: bool,
}

impl Snake {
    /// Creates a Snake activation for `channels` channels with an effective `α` of 1.
    ///
    /// With `alpha_logscale` the stored parameter starts at `0.0` (so `exp` yields 1);
    /// otherwise it starts at `1.0`.
    ///
    /// # Errors
    /// Never fails as written.
    pub fn new(channels: i32, alpha_logscale: bool) -> Result<Self> {
        let init_val = if alpha_logscale { 0.0 } else { 1.0 };
        let alpha = ops::broadcast_to(&Array::from_f32(init_val), &[1, channels, 1]);
        Ok(Self {
            alpha: Param::new(alpha),
            alpha_logscale,
        })
    }

    /// Applies `x + sin²(α·x) / (α + 1e-9)`.
    ///
    /// Assumes `x` is laid out `[batch, channels, time]` so that `alpha` broadcasts per
    /// channel — TODO confirm against callers.
    ///
    /// # Errors
    /// Never fails as written.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let alpha = if self.alpha_logscale {
            self.alpha.value.exp()
        } else {
            self.alpha.value.clone()
        };
        let sin_ax = x.multiply(&alpha).sin();
        // A linear-scale alpha can be trained to exactly zero; the small epsilon (the same
        // guard BigVGAN uses) keeps the reciprocal finite without changing typical values.
        let denom = alpha.add(&Array::from_f32(1e-9));
        let scaled = sin_ax.multiply(&sin_ax).divide(&denom);
        Ok(x.add(&scaled))
    }
}
/// SnakeBeta activation, `x + sin²(α·x) / β`: `α` scales the argument of `sin`, `β`
/// divides the `sin²` term; both are learnable per channel.
#[derive(Debug)]
pub struct SnakeBeta {
    /// Per-channel `α`, created with shape `[1, channels, 1]`; holds `ln(α)` when `logscale`.
    pub alpha: Param<Array>,
    /// Per-channel `β`, created with shape `[1, channels, 1]`; holds `ln(β)` when `logscale`.
    pub beta: Param<Array>,
    /// Whether `alpha` and `beta` are stored in log space and must be exponentiated.
    pub logscale: bool,
}

impl SnakeBeta {
    /// Creates a SnakeBeta activation for `channels` channels with effective `α = β = 1`.
    ///
    /// With `logscale` both parameters start at `0.0` (so `exp` yields 1); otherwise at `1.0`.
    ///
    /// # Errors
    /// Never fails as written.
    pub fn new(channels: i32, logscale: bool) -> Result<Self> {
        let init_val = if logscale { 0.0 } else { 1.0 };
        let init = Array::from_f32(init_val);
        let alpha = ops::broadcast_to(&init, &[1, channels, 1]);
        let beta = ops::broadcast_to(&init, &[1, channels, 1]);
        Ok(Self {
            alpha: Param::new(alpha),
            beta: Param::new(beta),
            logscale,
        })
    }

    /// Applies `x + sin²(α·x) / (β + 1e-9)`.
    ///
    /// Assumes `x` is laid out `[batch, channels, time]` so the parameters broadcast per
    /// channel — TODO confirm against callers.
    ///
    /// # Errors
    /// Never fails as written.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let (alpha, beta) = if self.logscale {
            (self.alpha.value.exp(), self.beta.value.exp())
        } else {
            (self.alpha.value.clone(), self.beta.value.clone())
        };
        let sin_ax = x.multiply(&alpha).sin();
        // A linear-scale beta can be trained to exactly zero; the epsilon (as in BigVGAN)
        // keeps the division finite without changing typical values.
        let denom = beta.add(&Array::from_f32(1e-9));
        let scaled = sin_ax.multiply(&sin_ax).divide(&denom);
        Ok(x.add(&scaled))
    }
}
/// Wraps a pointwise activation `A` with resampling: the activation is applied at a higher
/// sample rate (`up_ratio`) and the result is brought back down by `down_ratio`.
#[derive(Debug)]
pub struct Activation1d<A> {
    /// The pointwise activation (e.g. `Snake` or `SnakeBeta`).
    pub activation: A,
    /// Integer upsampling factor applied before the activation.
    pub up_ratio: i32,
    /// Integer downsampling factor applied after the activation.
    pub down_ratio: i32,
    /// Low-pass filter, shape `[1, 1, taps]`, handed to `downsample_1d` by `forward`.
    pub filter: Array,
}

impl<A> Activation1d<A> {
    /// Wraps `activation` with resampling by `up_ratio` / `down_ratio` (each defaults to 2).
    ///
    /// `filter` is a 12-tap low-pass with cutoff `0.5 / down_ratio`. The cutoff comes from
    /// `down_ratio` because `self.filter` is only used by `downsample_1d`, which decimates
    /// by `down_ratio`. With the default equal ratios this matches the previous behavior.
    ///
    /// # Errors
    /// Propagates any error from building the filter.
    pub fn new(activation: A, up_ratio: Option<i32>, down_ratio: Option<i32>) -> Result<Self> {
        let up_ratio = up_ratio.unwrap_or(2);
        let down_ratio = down_ratio.unwrap_or(2);
        let filter = create_kaiser_filter(12, 0.5 / down_ratio as f32)?;
        Ok(Self {
            activation,
            up_ratio,
            down_ratio,
            filter,
        })
    }
}
impl Activation1d<Snake> {
    /// Runs the Snake activation at `up_ratio`× the input sample rate.
    ///
    /// `x` goes through `upsample_1d` with `up_ratio`, is activated, and is then handed to
    /// `downsample_1d` together with `down_ratio` and `self.filter`.
    ///
    /// # Errors
    /// Propagates the first error from upsampling, the activation, or downsampling.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        upsample_1d(x, self.up_ratio)
            .and_then(|upsampled| self.activation.forward(&upsampled))
            .and_then(|activated| downsample_1d(&activated, self.down_ratio, &self.filter))
    }
}
impl Activation1d<SnakeBeta> {
    /// Runs the SnakeBeta activation at `up_ratio`× the input sample rate.
    ///
    /// `x` goes through `upsample_1d` with `up_ratio`, is activated, and is then handed to
    /// `downsample_1d` together with `down_ratio` and `self.filter`.
    ///
    /// # Errors
    /// Propagates the first error from upsampling, the activation, or downsampling.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        upsample_1d(x, self.up_ratio)
            .and_then(|upsampled| self.activation.forward(&upsampled))
            .and_then(|activated| downsample_1d(&activated, self.down_ratio, &self.filter))
    }
}
/// Builds a windowed-sinc low-pass filter of shape `[1, 1, taps]` whose taps sum to one.
///
/// Follows BigVGAN's `kaiser_sinc_filter1d`:
/// - `cutoff` is in cycles per sample, so `0.5` is the Nyquist rate;
/// - taps sit at offsets symmetric about the filter midpoint (half-integer offsets when
///   `taps` is even), so the filter is linear-phase;
/// - the ideal response `2·cutoff·sinc(2·cutoff·t)` is shaped by a symmetric Kaiser window
///   and then rescaled to unit DC gain.
///
/// The Kaiser `beta` comes from the Kaiser design formula with a transition half-width of
/// `1.2 * cutoff`, the ratio BigVGAN uses (`cutoff = 0.5 / r`, `half_width = 0.6 / r`).
/// A non-positive `cutoff` yields an all-zero filter.
///
/// # Errors
/// Never fails as written.
///
/// # Panics
/// Panics if `taps` is not positive.
fn create_kaiser_filter(taps: i32, cutoff: f32) -> Result<Array> {
    assert!(taps > 0, "filter must have at least one tap, got {}", taps);
    let values: Vec<f32> = kaiser_sinc_taps(taps, f64::from(cutoff))
        .into_iter()
        .map(|v| v as f32)
        .collect();
    Ok(Array::from_f32_slice(&values, &[1, 1, taps]))
}

/// Host-side tap values for `create_kaiser_filter`, computed in `f64` and normalized so
/// they sum to one (unless every tap is zero).
fn kaiser_sinc_taps(taps: i32, cutoff: f64) -> Vec<f64> {
    if cutoff <= 0.0 {
        return vec![0.0; taps as usize];
    }
    // Offsets are measured from the midpoint: -5.5..=5.5 for 12 taps, -6..=6 for 13.
    let center = f64::from(taps - 1) / 2.0;
    let beta = kaiser_beta(taps / 2, 1.2 * cutoff);
    let i0_beta = bessel_i0(beta);
    let mut filter: Vec<f64> = (0..taps)
        .map(|i| {
            let t = f64::from(i) - center;
            // Symmetric (non-periodic) Kaiser window; a single tap gets weight 1.
            let window = if center > 0.0 {
                let r = t / center;
                bessel_i0(beta * (1.0 - r * r).max(0.0).sqrt()) / i0_beta
            } else {
                1.0
            };
            2.0 * cutoff * window * normalized_sinc(2.0 * cutoff * t)
        })
        .collect();
    let sum: f64 = filter.iter().sum();
    if sum != 0.0 {
        filter.iter_mut().for_each(|v| *v /= sum);
    }
    filter
}

/// Kaiser window `beta` for a filter with `half_size` taps per side and the given
/// transition half-width (cycles per sample), using the standard Kaiser design formula.
fn kaiser_beta(half_size: i32, half_width: f64) -> f64 {
    let delta_f = 4.0 * half_width;
    let attenuation =
        2.285 * f64::from(half_size - 1) * std::f64::consts::PI * delta_f + 7.95;
    if attenuation > 50.0 {
        0.1102 * (attenuation - 8.7)
    } else if attenuation >= 21.0 {
        0.5842 * (attenuation - 21.0).powf(0.4) + 0.07886 * (attenuation - 21.0)
    } else {
        0.0
    }
}

/// Modified Bessel function of the first kind, order zero, via its power series
/// `Σ ((x/2)^k / k!)²`, summed until terms are negligible relative to the total.
fn bessel_i0(x: f64) -> f64 {
    let half = x / 2.0;
    let mut term = 1.0;
    let mut sum = 1.0;
    let mut k = 1.0;
    while term > 1e-12 * sum {
        let ratio = half / k;
        term *= ratio * ratio;
        sum += term;
        k += 1.0;
    }
    sum
}

/// Normalized sinc, `sin(πx) / (πx)`, with `sinc(0) = 1`.
fn normalized_sinc(x: f64) -> f64 {
    if x == 0.0 {
        1.0
    } else {
        let px = std::f64::consts::PI * x;
        px.sin() / px
    }
}
fn upsample_1d(x: &Array, ratio: i32) -> Result<Array> {
if ratio == 1 {
return Ok(x.clone());
}
let shape = x.shape();
let batch = shape[0];
let channels = shape[1];
let length = shape[2];
let zeros = Array::zeros(&[batch, channels, length, ratio - 1], 10); let x_expanded = x.reshape(&[batch, channels, length, 1]);
let interleaved = ops::concatenate_axis(&[&x_expanded, &zeros], -1);
Ok(interleaved.reshape(&[batch, channels, length * ratio]))
}
fn downsample_1d(x: &Array, ratio: i32, filter: &Array) -> Result<Array> {
if ratio == 1 {
return Ok(x.clone());
}
let shape = x.shape();
let _batch = shape[0];
let channels = shape[1];
let length = shape[2];
let _filter_exp = ops::broadcast_to(filter, &[channels, 1, filter.dim(2)]);
let indices: Vec<i32> = (0..length).step_by(ratio as usize).collect();
let indices_arr = Array::from_f32_slice(
&indices.iter().map(|&i| i as f32).collect::<Vec<_>>(),
&[indices.len() as i32],
)
.as_dtype(7); Ok(x.take_axis(&indices_arr, 2))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_snake_forward() {
        let snake = Snake::new(4, true).unwrap();
        let x = Array::random_normal(&[1, 4, 16], 10);
        let y = snake.forward(&x).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[1, 4, 16]);
    }

    #[test]
    fn test_snakebeta_forward() {
        let snake = SnakeBeta::new(8, true).unwrap();
        let x = Array::random_normal(&[2, 8, 32], 10);
        let y = snake.forward(&x).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[2, 8, 32]);
    }

    #[test]
    fn test_snake_values() {
        // snake(0) = 0 + sin²(0) / α = 0 for every element.
        let snake = Snake::new(1, false).unwrap();
        let x = Array::zeros(&[1, 1, 4], 10);
        let y = snake.forward(&x).unwrap();
        y.eval();
        let sum = y.sum_all();
        sum.eval();
        assert!(sum.item_f32().abs() < 1e-5);
    }

    #[test]
    fn test_upsample_1d() {
        let x = Array::from_f32_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 1, 4]);
        let y = upsample_1d(&x, 2).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[1, 1, 8]);
    }

    #[test]
    fn test_downsample_odd_length() {
        // Output length is ceil(length / ratio).
        let filter = create_kaiser_filter(12, 0.25).unwrap();
        let x = Array::random_normal(&[1, 2, 7], 10);
        let y = downsample_1d(&x, 2, &filter).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[1, 2, 4]);
    }

    #[test]
    fn test_kaiser_filter_normalized() {
        let filter = create_kaiser_filter(12, 0.25).unwrap();
        filter.eval();
        assert_eq!(filter.shape(), &[1, 1, 12]);
        let sum = filter.sum_all();
        sum.eval();
        assert!((sum.item_f32() - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_activation1d() {
        let snake = Snake::new(4, true).unwrap();
        let act1d = Activation1d::new(snake, Some(2), Some(2)).unwrap();
        let x = Array::random_normal(&[1, 4, 16], 10);
        let y = act1d.forward(&x).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[1, 4, 16]);
    }

    #[test]
    fn test_activation1d_snakebeta() {
        let snake = SnakeBeta::new(4, true).unwrap();
        let act1d = Activation1d::new(snake, None, None).unwrap();
        let x = Array::random_normal(&[2, 4, 10], 10);
        let y = act1d.forward(&x).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[2, 4, 10]);
    }
}