1#![allow(missing_docs)]
2use super::{Backend, ComputeOp};
11use crate::error::TruenoError;
12
/// Dot-product compute operation over two `f32` vectors of equal length.
#[derive(Debug, Clone)]
pub struct DotOp {
    /// Declared input length.
    /// NOTE(review): `execute` validates the two inputs against each other,
    /// not against this field — confirm whether `len` is a hint only.
    pub len: usize,
}
23
24impl DotOp {
25 pub fn new(len: usize) -> Self {
26 Self { len }
27 }
28}
29
30impl ComputeOp for DotOp {
31 type Input = (Vec<f32>, Vec<f32>);
32 type Output = f32;
33
34 fn name(&self) -> &'static str {
35 "dot"
36 }
37
38 fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
39 let (a, b) = input;
40 if a.len() != b.len() {
41 return Err(TruenoError::SizeMismatch { expected: a.len(), actual: b.len() });
42 }
43 let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
45 Ok(sum)
46 }
47
48 fn tokens(&self, input: &Self::Input) -> usize {
49 input.0.len()
51 }
52}
53
/// Element-wise addition of two `f32` vectors of equal length.
#[derive(Debug, Clone)]
pub struct AddOp {
    /// Declared input length.
    /// NOTE(review): not checked by `execute`, which compares the two inputs
    /// against each other — confirm intended role of this field.
    pub len: usize,
}
64
65impl AddOp {
66 pub fn new(len: usize) -> Self {
67 Self { len }
68 }
69}
70
71impl ComputeOp for AddOp {
72 type Input = (Vec<f32>, Vec<f32>);
73 type Output = Vec<f32>;
74
75 fn name(&self) -> &'static str {
76 "add"
77 }
78
79 fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
80 let (a, b) = input;
81 if a.len() != b.len() {
82 return Err(TruenoError::SizeMismatch { expected: a.len(), actual: b.len() });
83 }
84 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
85 }
86
87 fn tokens(&self, input: &Self::Input) -> usize {
88 input.0.len()
89 }
90}
91
/// Matrix multiplication `(m x k) * (k x n) -> (m x n)` over flattened,
/// row-major `f32` buffers (layout presumed from `Matrix::from_vec_with_backend`
/// usage — confirm against `crate::Matrix` docs).
#[derive(Debug, Clone)]
pub struct MatmulOp {
    /// Rows of the left operand (and of the result).
    pub m: usize,
    /// Shared inner dimension: columns of the left, rows of the right.
    pub k: usize,
    /// Columns of the right operand (and of the result).
    pub n: usize,
}
106
107impl MatmulOp {
108 pub fn new(m: usize, k: usize, n: usize) -> Self {
109 Self { m, k, n }
110 }
111}
112
113impl ComputeOp for MatmulOp {
114 type Input = (Vec<f32>, Vec<f32>);
115 type Output = Vec<f32>;
116
117 fn name(&self) -> &'static str {
118 "matmul"
119 }
120
121 fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
122 let (a, b) = input;
123 let expected_a = self.m * self.k;
124 let expected_b = self.k * self.n;
125
126 if a.len() != expected_a {
127 return Err(TruenoError::SizeMismatch { expected: expected_a, actual: a.len() });
128 }
129 if b.len() != expected_b {
130 return Err(TruenoError::SizeMismatch { expected: expected_b, actual: b.len() });
131 }
132
133 let simd_backend = crate::Backend::select_best();
136 let mat_a = crate::Matrix::from_vec_with_backend(self.m, self.k, a, simd_backend);
137 let mat_b = crate::Matrix::from_vec_with_backend(self.k, self.n, b, simd_backend);
138
139 let result = mat_a.matmul(&mat_b)?;
140 Ok(result.data)
142 }
143
144 fn tokens(&self, _input: &Self::Input) -> usize {
145 self.m * self.n
148 }
149}
150
/// Softmax over a single `f32` vector.
#[derive(Debug, Clone)]
pub struct SoftmaxOp {
    /// Declared input length.
    /// NOTE(review): `execute` operates on the actual input length and never
    /// reads this field — confirm whether it is a sizing hint only.
    pub len: usize,
}
161
162impl SoftmaxOp {
163 pub fn new(len: usize) -> Self {
164 Self { len }
165 }
166}
167
168impl ComputeOp for SoftmaxOp {
169 type Input = Vec<f32>;
170 type Output = Vec<f32>;
171
172 fn name(&self) -> &'static str {
173 "softmax"
174 }
175
176 fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
177 if input.is_empty() {
178 return Ok(vec![]);
179 }
180
181 Ok(crate::blis::softmax::softmax_1d_alloc(&input))
185 }
186
187 fn tokens(&self, input: &Self::Input) -> usize {
188 input.len()
189 }
190}
191
192impl SoftmaxOp {
193 #[inline]
195 pub fn is_simd_backend(backend: Backend) -> bool {
196 matches!(
197 backend,
198 Backend::Avx2 | Backend::Avx512 | Backend::Sse2 | Backend::Neon | Backend::Auto
199 )
200 }
201 }
205
206#[cfg(test)]
207mod tests;