// kyro 0.1.1
//
// A high-performance ML inference engine.
use candle_core::{Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};

/// Thin wrapper around [`candle_nn::RmsNorm`].
///
/// All work is delegated to the wrapped layer; this type adds no state of
/// its own.
pub struct RmsNorm {
    /// The wrapped `candle_nn` layer that every call is forwarded to.
    inner: candle_nn::RmsNorm,
}

impl RmsNorm {
    /// Builds the wrapper by delegating to [`candle_nn::rms_norm`].
    ///
    /// `dim`, `eps` and `vb` are passed through to `candle_nn::rms_norm`
    /// unchanged; the layer it returns is stored as-is.
    ///
    /// # Errors
    ///
    /// Returns whatever error `candle_nn::rms_norm` reports.
    #[allow(dead_code)]
    pub fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        candle_nn::rms_norm(dim, eps, vb).map(|inner| Self { inner })
    }
}

impl Module for RmsNorm {
    /// Runs `x` through the wrapped layer and returns its result untouched.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the wrapped layer's `forward`.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let Self { inner } = self;
        inner.forward(x)
    }
}

/// State for rotary position embeddings: an inverse-frequency table plus the
/// maximum sequence length supplied at construction time.
#[allow(dead_code)]
pub struct RotaryEmbedding {
    /// Inverse frequencies `1 / 10000^(2i / dim)` for `i` in `0..dim / 2`,
    /// stored as an F32 tensor (see `RotaryEmbedding::new`).
    inv_freq: Tensor,
    /// Maximum sequence length as passed to `new`; it is stored only and is
    /// not read by any code visible in this file.
    max_seq_len: usize,
}

impl RotaryEmbedding {
    /// Creates the embedding state for a head dimension of `dim`.
    ///
    /// The inverse-frequency table holds `dim / 2` entries (integer division),
    /// entry `i` being `1 / 10000^(2i / dim)`. Values are computed in `f64`,
    /// uploaded to `device`, and then converted to `F32`. `max_seq_len` is
    /// stored without further processing.
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the tensor on `device` or
    /// while converting it to `F32`.
    #[allow(dead_code)]
    pub fn new(dim: usize, max_seq_len: usize, device: &Device) -> Result<Self> {
        let width = dim as f64;
        let freqs: Vec<f64> = (0..dim / 2)
            .map(|k| {
                let exponent = 2.0 * (k as f64) / width;
                10000.0f64.powf(exponent).recip()
            })
            .collect();

        // Build in f64 first, then narrow, so rounding matches a single
        // f64 -> f32 conversion per element.
        let table = Tensor::new(freqs.as_slice(), device)?;
        let inv_freq = table.to_dtype(candle_core::DType::F32)?;

        Ok(Self {
            inv_freq,
            max_seq_len,
        })
    }

    /// Placeholder: returns a clone of `x` without modifying it.
    ///
    /// Neither `_index` nor the stored frequency table is consulted yet, so
    /// no rotation is actually applied.
    pub fn apply(&self, x: &Tensor, _index: usize) -> Result<Tensor> {
        let passthrough = x.clone();
        Ok(passthrough)
    }
}