smelte_rs/cpu/f32/
traits.rs

1use super::ops;
2use super::tensor::Tensor;
3use crate::traits::{
4    Tensor as TensorTrait, TensorAdd, TensorGelu, TensorMatmul, TensorMatmulT, TensorMul,
5    TensorNormalize, TensorOps, TensorSelect, TensorSoftmax, TensorTanh,
6};
7use crate::SmeltError;
8
9impl<'a> TensorTrait for Tensor<'a> {
10    fn shape(&self) -> &[usize] {
11        &self.shape
12    }
13    fn zeros(shape: Vec<usize>) -> Self {
14        Self::zeros(shape)
15    }
16}
17
18impl<'a> TensorAdd<Tensor<'a>> for Tensor<'a> {
19    fn add(x: &Self, y: &mut Self) -> Result<(), SmeltError> {
20        ops::add(x, y)
21    }
22}
23
24impl<'a> TensorMul<Tensor<'a>> for Tensor<'a> {
25    fn mul(x: &Self, y: &mut Self) -> Result<(), SmeltError> {
26        ops::mul(x, y)
27    }
28}
29
30impl<'a> TensorNormalize<Tensor<'a>> for Tensor<'a> {
31    fn normalize(x: &mut Self, epsilon: f32) -> Result<(), SmeltError> {
32        ops::normalize(x, epsilon)
33    }
34}
35
36impl<'a> TensorMatmul<Tensor<'a>> for Tensor<'a> {
37    fn matmul(x: &Self, y: &Self, out: &mut Self) -> Result<(), SmeltError> {
38        ops::matmul(x, y, out)
39    }
40}
41
42impl<'a> TensorMatmulT<Tensor<'a>> for Tensor<'a> {
43    fn matmul_t(x: &Self, y: &Self, out: &mut Self) -> Result<(), SmeltError> {
44        ops::matmul_t(x, y, out)
45    }
46}
47
48impl<'a> TensorSelect<Tensor<'a>> for Tensor<'a> {
49    fn select(x: &[usize], weight: &Self, out: &mut Self) -> Result<(), SmeltError> {
50        ops::select(x, weight, out)
51    }
52}
53
54impl<'a> TensorGelu<Tensor<'a>> for Tensor<'a> {
55    fn gelu(x: &mut Tensor<'a>) -> Result<(), SmeltError> {
56        ops::apply(x, ops::gelu);
57        Ok(())
58    }
59}
60
61impl<'a> TensorTanh<Tensor<'a>> for Tensor<'a> {
62    fn tanh(x: &mut Tensor<'a>) -> Result<(), SmeltError> {
63        ops::apply(x, ops::inline_tanh);
64        Ok(())
65    }
66}
67
68impl<'a> TensorSoftmax<Tensor<'a>> for Tensor<'a> {
69    fn softmax(x: &mut Tensor<'a>) -> Result<(), SmeltError> {
70        ops::softmax(x)
71    }
72}
73
74impl<'a> TensorOps<Tensor<'a>> for Tensor<'a> {}