smelte_rs/cpu/f32/
traits.rs1use super::ops;
2use super::tensor::Tensor;
3use crate::traits::{
4 Tensor as TensorTrait, TensorAdd, TensorGelu, TensorMatmul, TensorMatmulT, TensorMul,
5 TensorNormalize, TensorOps, TensorSelect, TensorSoftmax, TensorTanh,
6};
7use crate::SmeltError;
8
9impl<'a> TensorTrait for Tensor<'a> {
10 fn shape(&self) -> &[usize] {
11 &self.shape
12 }
13 fn zeros(shape: Vec<usize>) -> Self {
14 Self::zeros(shape)
15 }
16}
17
18impl<'a> TensorAdd<Tensor<'a>> for Tensor<'a> {
19 fn add(x: &Self, y: &mut Self) -> Result<(), SmeltError> {
20 ops::add(x, y)
21 }
22}
23
24impl<'a> TensorMul<Tensor<'a>> for Tensor<'a> {
25 fn mul(x: &Self, y: &mut Self) -> Result<(), SmeltError> {
26 ops::mul(x, y)
27 }
28}
29
30impl<'a> TensorNormalize<Tensor<'a>> for Tensor<'a> {
31 fn normalize(x: &mut Self, epsilon: f32) -> Result<(), SmeltError> {
32 ops::normalize(x, epsilon)
33 }
34}
35
36impl<'a> TensorMatmul<Tensor<'a>> for Tensor<'a> {
37 fn matmul(x: &Self, y: &Self, out: &mut Self) -> Result<(), SmeltError> {
38 ops::matmul(x, y, out)
39 }
40}
41
42impl<'a> TensorMatmulT<Tensor<'a>> for Tensor<'a> {
43 fn matmul_t(x: &Self, y: &Self, out: &mut Self) -> Result<(), SmeltError> {
44 ops::matmul_t(x, y, out)
45 }
46}
47
48impl<'a> TensorSelect<Tensor<'a>> for Tensor<'a> {
49 fn select(x: &[usize], weight: &Self, out: &mut Self) -> Result<(), SmeltError> {
50 ops::select(x, weight, out)
51 }
52}
53
54impl<'a> TensorGelu<Tensor<'a>> for Tensor<'a> {
55 fn gelu(x: &mut Tensor<'a>) -> Result<(), SmeltError> {
56 ops::apply(x, ops::gelu);
57 Ok(())
58 }
59}
60
61impl<'a> TensorTanh<Tensor<'a>> for Tensor<'a> {
62 fn tanh(x: &mut Tensor<'a>) -> Result<(), SmeltError> {
63 ops::apply(x, ops::inline_tanh);
64 Ok(())
65 }
66}
67
68impl<'a> TensorSoftmax<Tensor<'a>> for Tensor<'a> {
69 fn softmax(x: &mut Tensor<'a>) -> Result<(), SmeltError> {
70 ops::softmax(x)
71 }
72}
73
74impl<'a> TensorOps<Tensor<'a>> for Tensor<'a> {}