use crate::*;
use pyo3::prelude::*;
use pyo3::types::{PyIterator, PyList, PyTuple};
use pyo3::Bound;
use zyx::{DType, Tensor, ZyxError};
type ZyxResult<T> = std::result::Result<T, ZyxError>;
/// Expands a Python 2-tuple of dimensions into a two-element shape vector,
/// preserving element order.
fn to_sh_from_tuple(t: (u64, u64)) -> Vec<u64> {
    let (first, second) = t;
    Vec::from([first, second])
}
/// Converts a tuple of Python shape arguments into a `Vec<u64>`.
///
/// Two layouts are accepted:
/// * separate integers, e.g. `(2, 3)` gives `[2, 3]`;
/// * exactly one list or tuple element, e.g. `([2, 3],)` gives `[2, 3]`.
///
/// Items that cannot be extracted as `u64`, or that fail during iteration, are
/// skipped rather than reported.
fn to_sh(shape: &Bound<'_, PyTuple>) -> Vec<u64> {
    // Unpack the single-sequence layout so `f((2, 3))` behaves like `f(2, 3)`.
    if shape.len() == 1 {
        // Index 0 always exists when the length is exactly 1.
        let inner = shape.get_item(0).unwrap();
        let is_sequence = inner.is_instance_of::<PyList>() || inner.is_instance_of::<PyTuple>();
        if is_sequence {
            // Lists and tuples are always iterable, so building the iterator cannot fail.
            let items = PyIterator::from_object(&inner).unwrap();
            return items
                .filter_map(Result::ok)
                .filter_map(|item| item.extract::<u64>().ok())
                .collect();
        }
    }
    shape
        .iter()
        .filter_map(|item| item.extract::<u64>().ok())
        .collect()
}
// Python bindings for `Linear`.
#[pymethods]
impl Linear {
    /// `Linear(in_features, out_features, bias=True, dtype=F32)`; delegates to `Linear::new`.
    #[new]
    #[pyo3(signature = (in_features, out_features, bias=true, dtype=DType::F32))]
    pub fn py_new(
        in_features: u64,
        out_features: u64,
        bias: bool,
        dtype: DType,
    ) -> ZyxResult<Self> {
        Self::new(in_features, out_features, bias, dtype)
    }

    /// Applies the layer to a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let input = x.clone();
        self.forward(input)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
// Python bindings for `Conv2d`.
#[pymethods]
impl Conv2d {
    /// `Conv2d(in_channels, out_channels, kernel_size, stride=(1, 1), padding=(0, 0),
    /// dilation=(1, 1), groups=1, bias=True, dtype=F32)`.
    ///
    /// Each 2-tuple argument is expanded into a two-element shape vector before
    /// delegating to `Conv2d::new`.
    #[new]
    #[pyo3(signature = (in_channels, out_channels, kernel_size, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1, bias=true, dtype=DType::F32))]
    pub fn py_new(
        in_channels: u64,
        out_channels: u64,
        kernel_size: (u64, u64),
        stride: (u64, u64),
        padding: (u64, u64),
        dilation: (u64, u64),
        groups: u64,
        bias: bool,
        dtype: DType,
    ) -> ZyxResult<Self> {
        let kernel = to_sh_from_tuple(kernel_size);
        let strides = to_sh_from_tuple(stride);
        let pads = to_sh_from_tuple(padding);
        let dilations = to_sh_from_tuple(dilation);
        Self::new(
            in_channels,
            out_channels,
            kernel,
            strides,
            pads,
            dilations,
            groups,
            bias,
            dtype,
        )
    }

    /// Applies the convolution to a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let input = x.clone();
        self.forward(input)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
// Python bindings for `Embedding`.
#[pymethods]
impl Embedding {
    /// `Embedding(num_embeddings, embedding_dim, dtype=F32)`; delegates to `Embedding::new`.
    #[new]
    #[pyo3(signature = (num_embeddings, embedding_dim, dtype=DType::F32))]
    pub fn py_new(num_embeddings: u64, embedding_dim: u64, dtype: DType) -> ZyxResult<Self> {
        Self::new(num_embeddings, embedding_dim, dtype)
    }

    /// Looks up embeddings for a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let indices = x.clone();
        self.forward(indices)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
#[pymethods]
impl LayerNorm {
#[new]
#[pyo3(signature = (normalized_shape, eps=1e-5, elementwise_affine=true, py_bias=true, dtype=DType::F32))]
pub fn py_new(normalized_shape: &Bound<'_, PyTuple>, eps: f64, elementwise_affine: bool, py_bias: bool, dtype: DType) -> ZyxResult<Self> {
Self::new(to_sh(normalized_shape), eps, elementwise_affine, py_bias, dtype)
}
#[pyo3(name = "forward")]
pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
self.forward(x.clone())
}
#[pyo3(name = "__call__")]
pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
self.forward(x.clone())
}
}
// Python bindings for `BatchNorm`.
#[pymethods]
impl BatchNorm {
    /// `BatchNorm(num_features, eps=1e-5, momentum=0.1, affine=True,
    /// track_running_stats=True, dtype=F32)`.
    ///
    /// Builds the struct directly from its fields:
    /// * with `affine`, `weight` is ones and `bias` is zeros (length `num_features`);
    ///   otherwise both are `None`;
    /// * with `track_running_stats`, `running_mean` is zeros, `running_var` is ones
    ///   (length `num_features`) and `num_batches_tracked` is a one-element zeros tensor;
    ///   otherwise all three are zeros tensors created with shape `0`.
    ///
    /// `eps` and `momentum` are narrowed to `f32`. Unlike the other modules in this
    /// file, no `forward`/`__call__` binding is exposed here.
    #[new]
    #[pyo3(signature = (num_features, eps=1e-5, momentum=0.1, affine=true, track_running_stats=true, dtype=DType::F32))]
    pub fn py_new(
        num_features: u64,
        eps: f64,
        momentum: f64,
        affine: bool,
        track_running_stats: bool,
        dtype: DType,
    ) -> Self {
        let (weight, bias) = if affine {
            (
                Some(Tensor::ones(num_features, dtype)),
                Some(Tensor::zeros(num_features, dtype)),
            )
        } else {
            (None, None)
        };
        let (running_mean, running_var, num_batches_tracked) = if track_running_stats {
            (
                Tensor::zeros(num_features, dtype),
                Tensor::ones(num_features, dtype),
                Tensor::zeros(1, dtype),
            )
        } else {
            (
                Tensor::zeros(0, dtype),
                Tensor::zeros(0, dtype),
                Tensor::zeros(0, dtype),
            )
        };
        Self {
            eps: eps as f32,
            momentum: momentum as f32,
            track_running_stats,
            weight,
            bias,
            running_mean,
            running_var,
            num_batches_tracked,
        }
    }
}
// Python bindings for `GroupNorm`.
#[pymethods]
impl GroupNorm {
    /// `GroupNorm(num_groups, num_channels, eps=1e-5, affine=True, dtype=F32)`;
    /// delegates to `GroupNorm::new`.
    #[new]
    #[pyo3(signature = (num_groups, num_channels, eps=1e-5, affine=true, dtype=DType::F32))]
    pub fn py_new(
        num_groups: u64,
        num_channels: u64,
        eps: f64,
        affine: bool,
        dtype: DType,
    ) -> ZyxResult<Self> {
        // NOTE(review): `eps` is accepted so the Python signature matches PyTorch, but it is
        // discarded; `GroupNorm::new` is not given it. Confirm whether it should be honored.
        let _ = eps;
        Self::new(num_groups, num_channels, affine, dtype)
    }

    /// Normalizes a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let input = x.clone();
        self.forward(input)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
// Python bindings for `RMSNorm`.
#[pymethods]
impl RMSNorm {
    /// `RMSNorm(dim, dtype=F32)`; delegates to the infallible `RMSNorm::new`.
    #[new]
    #[pyo3(signature = (dim, dtype=DType::F32))]
    pub fn py_new(dim: u64, dtype: DType) -> Self {
        Self::new(dim, dtype)
    }

    /// Normalizes a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let input = x.clone();
        self.forward(input)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
// Python bindings for `CausalSelfAttention`.
#[pymethods]
impl CausalSelfAttention {
    /// `CausalSelfAttention(embed_dim, num_heads, bias=True, dropout=0.0, dtype=F32)`;
    /// delegates to `CausalSelfAttention::new`.
    #[new]
    #[pyo3(signature = (embed_dim, num_heads, bias=true, dropout=0.0, dtype=DType::F32))]
    pub fn py_new(
        embed_dim: u64,
        num_heads: u64,
        bias: bool,
        dropout: f32,
        dtype: DType,
    ) -> ZyxResult<Self> {
        Self::new(embed_dim, num_heads, bias, dropout, dtype)
    }

    /// Runs attention over a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let input = x.clone();
        self.forward(input)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
// Python bindings for `MultiheadAttention`.
#[pymethods]
impl MultiheadAttention {
    /// `MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False,
    /// add_zero_attn=False, kdim=None, vdim=None, batch_first=False, dtype=F32)`;
    /// delegates to `MultiheadAttention::new`.
    #[new]
    #[pyo3(signature = (embed_dim, num_heads, dropout=0.0, bias=true, add_bias_kv=false, add_zero_attn=false, kdim=None, vdim=None, batch_first=false, dtype=DType::F32))]
    pub fn py_new(
        embed_dim: u64,
        num_heads: u64,
        dropout: f32,
        bias: bool,
        add_bias_kv: bool,
        add_zero_attn: bool,
        kdim: Option<u64>,
        vdim: Option<u64>,
        batch_first: bool,
        dtype: DType,
    ) -> ZyxResult<Self> {
        Self::new(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            dtype,
        )
    }

    /// Runs attention on copies of the `query`/`key`/`value` handles and returns the
    /// output together with the optional attention weights; exposed as `forward`.
    ///
    /// The remaining `forward` arguments are fixed: both tensor slots are `None` and the
    /// flags are `true`, `true`, `false` in parameter order. NOTE(review): these presumably
    /// mirror PyTorch's `key_padding_mask`, `need_weights`, `attn_mask`,
    /// `average_attn_weights` and `is_causal`; confirm against `MultiheadAttention::forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> ZyxResult<(Tensor, Option<Tensor>)> {
        let (q, k, v) = (query.clone(), key.clone(), value.clone());
        self.forward(q, k, v, None::<Tensor>, true, None::<Tensor>, true, false)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> ZyxResult<(Tensor, Option<Tensor>)> {
        self.forward_py(query, key, value)
    }
}
// Python bindings for `PositionalEncoding`.
#[pymethods]
impl PositionalEncoding {
    /// `PositionalEncoding(d_model, max_len=5000, dropout=0.1, dtype=F32)`;
    /// delegates to `PositionalEncoding::new`.
    #[new]
    #[pyo3(signature = (d_model, max_len=5000, dropout=0.1, dtype=DType::F32))]
    pub fn py_new(d_model: u64, max_len: usize, dropout: f32, dtype: DType) -> ZyxResult<Self> {
        Self::new(d_model, max_len, dropout, dtype)
    }

    /// Applies the encoding to a copy of the `x` handle; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        let input = x.clone();
        self.forward(input)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(x)
    }
}
// Python bindings for `TransformerEncoderLayer`.
#[pymethods]
impl TransformerEncoderLayer {
    /// `TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1,
    /// layer_norm_eps=1e-5, batch_first=False, norm_first=False, bias=True, dtype=F32)`.
    ///
    /// The activation handed to `TransformerEncoderLayer::new` is always `relu`; it is
    /// not selectable from Python.
    #[new]
    #[pyo3(signature = (d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5, batch_first=false, norm_first=false, bias=true, dtype=DType::F32))]
    pub fn py_new(
        d_model: u64,
        nhead: u64,
        dim_feedforward: u64,
        dropout: f32,
        layer_norm_eps: f64,
        batch_first: bool,
        norm_first: bool,
        bias: bool,
        dtype: DType,
    ) -> ZyxResult<Self> {
        Self::new(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            |t| t.relu(),
            layer_norm_eps,
            batch_first,
            norm_first,
            bias,
            dtype,
        )
    }

    /// Encodes a copy of the `src` handle with both optional `forward` arguments left as
    /// `None`; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, src: &Tensor) -> ZyxResult<Tensor> {
        let input = src.clone();
        self.forward(input, None::<Tensor>, None::<Tensor>)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, src: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(src)
    }
}
// Python bindings for `TransformerDecoderLayer`.
#[pymethods]
impl TransformerDecoderLayer {
    /// `TransformerDecoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1,
    /// layer_norm_eps=1e-5, batch_first=False, norm_first=False, bias=True, dtype=F32)`.
    ///
    /// The activation handed to `TransformerDecoderLayer::new` is always `relu`; it is
    /// not selectable from Python.
    #[new]
    #[pyo3(signature = (d_model, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5, batch_first=false, norm_first=false, bias=true, dtype=DType::F32))]
    pub fn py_new(
        d_model: u64,
        nhead: u64,
        dim_feedforward: u64,
        dropout: f32,
        layer_norm_eps: f64,
        batch_first: bool,
        norm_first: bool,
        bias: bool,
        dtype: DType,
    ) -> ZyxResult<Self> {
        Self::new(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            |t| t.relu(),
            layer_norm_eps,
            batch_first,
            norm_first,
            bias,
            dtype,
        )
    }

    /// Decodes `tgt` against `memory`; exposed to Python as `forward`.
    ///
    /// The four optional tensor arguments are `None` and both trailing flags are `false`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, tgt: &Tensor, memory: &Tensor) -> ZyxResult<Tensor> {
        self.forward(
            tgt,
            memory,
            None::<Tensor>,
            None::<Tensor>,
            None::<Tensor>,
            None::<Tensor>,
            false,
            false,
        )
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, tgt: &Tensor, memory: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(tgt, memory)
    }
}
// Python bindings for `RNNCell`.
#[pymethods]
impl RNNCell {
    /// `RNNCell(input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=F32)`.
    ///
    /// `RNNCell::new` is given a `'static` string (the previous code leaked a fresh
    /// `String` on every construction to obtain one). The names `"tanh"` and `"relu"` now
    /// map to string literals, so constructing those cells no longer leaks memory. Any
    /// other name is still leaked and forwarded unchanged, so `RNNCell::new` receives
    /// exactly what the caller passed.
    #[new]
    #[pyo3(signature = (input_size, hidden_size, bias=true, nonlinearity="tanh", dtype=DType::F32))]
    pub fn py_new(
        input_size: u64,
        hidden_size: u64,
        bias: bool,
        nonlinearity: &str,
        dtype: DType,
    ) -> ZyxResult<Self> {
        let nonlinearity: &'static str = match nonlinearity {
            "tanh" => "tanh",
            "relu" => "relu",
            // Only unrecognized names still leak; previously every call did.
            other => String::from(other).leak(),
        };
        Self::new(input_size, hidden_size, bias, nonlinearity, Some(dtype))
    }

    /// Advances the cell one step from input `x` and hidden state `hx`; exposed as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor, hx: &Tensor) -> ZyxResult<Tensor> {
        self.forward(x, hx)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor, hx: &Tensor) -> ZyxResult<Tensor> {
        self.forward(x, hx)
    }
}
// Python bindings for `GRUCell`.
#[pymethods]
impl GRUCell {
    /// `GRUCell(input_size, hidden_size, bias=True, dtype=F32)`; delegates to `GRUCell::new`.
    #[new]
    #[pyo3(signature = (input_size, hidden_size, bias=true, dtype=DType::F32))]
    pub fn py_new(input_size: u64, hidden_size: u64, bias: bool, dtype: DType) -> ZyxResult<Self> {
        Self::new(input_size, hidden_size, bias, dtype)
    }

    /// Advances the cell one step using copies of the `input` and `hx` handles;
    /// exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, input: &Tensor, hx: &Tensor) -> ZyxResult<Tensor> {
        let (x, h) = (input.clone(), hx.clone());
        self.forward(x, h)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, input: &Tensor, hx: &Tensor) -> ZyxResult<Tensor> {
        self.forward_py(input, hx)
    }
}
// Python bindings for `LSTMCell`.
#[pymethods]
impl LSTMCell {
    /// `LSTMCell(input_size, hidden_size, bias=True, dtype=F32)`; delegates to
    /// `LSTMCell::new` with the dtype wrapped in `Some`.
    #[new]
    #[pyo3(signature = (input_size, hidden_size, bias=true, dtype=DType::F32))]
    pub fn py_new(input_size: u64, hidden_size: u64, bias: bool, dtype: DType) -> ZyxResult<Self> {
        let dtype = Some(dtype);
        Self::new(input_size, hidden_size, bias, dtype)
    }

    /// Advances the cell one step from `x` and the `(h, c)` state, returning the new
    /// pair; exposed to Python as `forward`.
    #[pyo3(name = "forward")]
    pub fn forward_py(&self, x: &Tensor, h: &Tensor, c: &Tensor) -> ZyxResult<(Tensor, Tensor)> {
        self.forward(x, h, c)
    }

    /// Makes instances callable from Python; identical to `forward`.
    #[pyo3(name = "__call__")]
    pub fn call_py(&self, x: &Tensor, h: &Tensor, c: &Tensor) -> ZyxResult<(Tensor, Tensor)> {
        self.forward_py(x, h, c)
    }
}
/// Registers every neural-network module class from this file on the Python module `m`.
///
/// Classes are added in a fixed order; the first registration error is returned and
/// the remaining classes are not added.
pub fn register_nn(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Expands to one `m.add_class::<T>()?;` statement per listed type, in order.
    macro_rules! add_classes {
        ($($class:ty),+ $(,)?) => {
            $( m.add_class::<$class>()?; )+
        };
    }
    add_classes!(
        Linear,
        Conv2d,
        Embedding,
        LayerNorm,
        BatchNorm,
        GroupNorm,
        RMSNorm,
        CausalSelfAttention,
        MultiheadAttention,
        PositionalEncoding,
        TransformerEncoderLayer,
        TransformerDecoderLayer,
        RNNCell,
        GRUCell,
        LSTMCell,
    );
    Ok(())
}