use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
#[cfg(feature = "rayon")]
use rayon::prelude::*;

impl Tensor {
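    /// Matrix product of two tensors.
    ///
    /// Both inputs must be at least 2-D. Two 2-D tensors use the dense 2-D
    /// kernel; higher-rank inputs are treated as stacks of matrices, with the
    /// leading (batch) dimensions broadcast against each other.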
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        if a_dims.len() < 2 || b_dims.len() < 2 {
            return Err(GhostError::InvalidOperation(
                "matmul requires at least 2D tensors".to_string(),
            ));
        }

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let k2 = b_dims[b_dims.len() - 2];
        let n = b_dims[b_dims.len() - 1];

        if k != k2 {
            return Err(GhostError::ShapeMismatch {
                expected: vec![k, n],
                got: vec![k2, n],
            });
        }

        if a_dims.len() == 2 && b_dims.len() == 2 {
            return self.matmul_2d(other, m, k, n);
        }

        self.batched_matmul(other)
    }

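    /// 2-D matrix multiply: dispatches to BLAS for sufficiently large
    /// matrices when the `blas` feature is enabled, otherwise uses the
    /// cache-blocked fallback kernel.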
    fn matmul_2d(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        #[cfg(feature = "blas")]
        {
            const BLAS_THRESHOLD: usize = 64;
            if m >= BLAS_THRESHOLD && n >= BLAS_THRESHOLD && k >= BLAS_THRESHOLD {
                return self.matmul_blas(other, m, k, n);
            }
        }

        self.matmul_blocked(other, m, k, n)
    }

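    /// Single-precision GEMM via the system BLAS (row-major, no transposes).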
    #[cfg(feature = "blas")]
    fn matmul_blas(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        use cblas::*;

        let a = self.data_f32();
        let b = other.data_f32();
        let mut c = vec![0.0f32; m * n];

        // C = 1.0 * A * B + 0.0 * C, row-major, with leading dimensions
        // k, n, n for A, B, C respectively.
        unsafe {
            sgemm(
                Layout::RowMajor,
                Transpose::None,
                Transpose::None,
                m as i32,
                n as i32,
                k as i32,
                1.0,
                &a,
                k as i32,
                &b,
                n as i32,
                0.0,
                &mut c,
                n as i32,
            );
        }

        Tensor::from_slice(&c, &[m, n])
    }

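    /// Cache-blocked fallback kernel: accumulates `C[i][j] += A[i][kk] * B[kk][j]`
    /// over `BLOCK_SIZE` tiles of the `j` and `k` dimensions, one output row at a time.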
    fn matmul_blocked(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let mut c = vec![0.0f32; m * n];

        const BLOCK_SIZE: usize = 64;

        // Each output row is independent; iterate rows in parallel when the
        // `rayon` feature is enabled, serially otherwise.
        #[cfg(feature = "rayon")]
        let rows = c.par_chunks_mut(n).enumerate();
        #[cfg(not(feature = "rayon"))]
        let rows = c.chunks_mut(n).enumerate();

        rows.for_each(|(i, row)| {
            for jb in (0..n).step_by(BLOCK_SIZE) {
                let j_end = (jb + BLOCK_SIZE).min(n);

                for kb in (0..k).step_by(BLOCK_SIZE) {
                    let k_end = (kb + BLOCK_SIZE).min(k);

                    for kk in kb..k_end {
                        let a_ik = a[i * k + kk];
                        for j in jb..j_end {
                            row[j] += a_ik * b[kk * n + j];
                        }
                    }
                }
            }
        });

        Tensor::from_slice(&c, &[m, n])
    }

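    /// Matmul over stacks of matrices, broadcasting the leading batch
    /// dimensions. Operand batches are mapped to output batches with a modulo
    /// index, which assumes each operand is either the full broadcast batch or
    /// a single matrix that gets reused across it.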
    fn batched_matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let n = b_dims[b_dims.len() - 1];

        let a_batch: Vec<usize> = a_dims[..a_dims.len() - 2].to_vec();
        let b_batch: Vec<usize> = b_dims[..b_dims.len() - 2].to_vec();

        let batch_dims = broadcast_batch_dims(&a_batch, &b_batch)?;
        let batch_size: usize = batch_dims.iter().product();

        let a = self.data_f32();
        let b = other.data_f32();

        let a_batch_stride = m * k;
        let b_batch_stride = k * n;
        let c_batch_stride = m * n;

        let mut result = vec![0.0f32; batch_size * m * n];

        result
            .chunks_mut(c_batch_stride)
            .enumerate()
            .for_each(|(batch_idx, c_batch)| {
                let a_idx = batch_idx % (a_batch.iter().product::<usize>().max(1));
                let b_idx = batch_idx % (b_batch.iter().product::<usize>().max(1));

                let a_start = a_idx * a_batch_stride;
                let b_start = b_idx * b_batch_stride;

                for i in 0..m {
                    for j in 0..n {
                        let mut sum = 0.0f32;
                        for kk in 0..k {
                            sum += a[a_start + i * k + kk] * b[b_start + kk * n + j];
                        }
                        c_batch[i * n + j] = sum;
                    }
                }
            });

        let mut out_shape = batch_dims;
        out_shape.push(m);
        out_shape.push(n);

        Tensor::from_slice(&result, &out_shape)
    }

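    /// Dot product of two equal-length 1-D tensors, returned as a scalar
    /// (0-D) tensor.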
    pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "dot requires 1D tensors".to_string(),
            ));
        }

        if self.numel() != other.numel() {
            return Err(GhostError::ShapeMismatch {
                expected: self.dims().to_vec(),
                got: other.dims().to_vec(),
            });
        }

        let a = self.data_f32();
        let b = other.data_f32();

        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();

        Tensor::from_slice(&[dot], &[])
    }

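    /// Outer product of two 1-D tensors: `result[i][j] = a[i] * b[j]`,
    /// with shape `[a.len(), b.len()]`.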
    pub fn outer(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "outer requires 1D tensors".to_string(),
            ));
        }

        let a = self.data_f32();
        let b = other.data_f32();
        let m = a.len();
        let n = b.len();

        let result: Vec<f32> = (0..m)
            .flat_map(|i| {
                let ai = a[i];
                b.iter().map(move |&bj| ai * bj)
            })
            .collect();

        Tensor::from_slice(&result, &[m, n])
    }

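    /// Matrix-vector product: an `[m, n]` matrix times an `[n]` vector
    /// yields an `[m]` vector.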
    pub fn mv(&self, vec: &Tensor) -> Result<Tensor> {
        if self.ndim() != 2 || vec.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "mv requires 2D matrix and 1D vector".to_string(),
            ));
        }

        let m = self.dims()[0];
        let n = self.dims()[1];

        if vec.numel() != n {
            return Err(GhostError::ShapeMismatch {
                expected: vec![n],
                got: vec.dims().to_vec(),
            });
        }

        let mat = self.data_f32();
        let v = vec.data_f32();

        let result: Vec<f32> = (0..m)
            .map(|i| (0..n).map(|j| mat[i * n + j] * v[j]).sum())
            .collect();

        Tensor::from_slice(&result, &[m])
    }

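    /// Batched matrix multiply of two 3-D tensors; delegates to `matmul`.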
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 3 || other.ndim() != 3 {
            return Err(GhostError::InvalidOperation(
                "bmm requires 3D tensors".to_string(),
            ));
        }

        self.matmul(other)
    }

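    /// Sum of the main diagonal of a 2-D tensor, returned as a scalar tensor.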
    pub fn trace(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "trace requires 2D tensor".to_string(),
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let trace: f32 = (0..n).map(|i| data[i * cols + i]).sum();

        Tensor::from_slice(&[trace], &[])
    }

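    /// Main diagonal of a 2-D tensor as a 1-D tensor of length `min(rows, cols)`.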
    pub fn diag(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "diag requires 2D tensor".to_string(),
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let diag: Vec<f32> = (0..n).map(|i| data[i * cols + i]).collect();

        Tensor::from_slice(&diag, &[n])
    }
}

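/// Broadcasts two batch shapes against each other, right-aligned: equal dims
/// pass through, a dim of 1 stretches to match the other, anything else is an error.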
fn broadcast_batch_dims(a: &[usize], b: &[usize]) -> Result<Vec<usize>> {
    let max_len = a.len().max(b.len());
    let mut result = Vec::with_capacity(max_len);

    for i in 0..max_len {
        let a_dim = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let b_dim = if i < b.len() { b[b.len() - 1 - i] } else { 1 };

        if a_dim == b_dim {
            result.push(a_dim);
        } else if a_dim == 1 {
            result.push(b_dim);
        } else if b_dim == 1 {
            result.push(a_dim);
        } else {
            return Err(GhostError::BroadcastError {
                a: a.to_vec(),
                b: b.to_vec(),
            });
        }
    }

    result.reverse();
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_2d() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let c = a.matmul(&b).unwrap();
        assert_eq!(c.dims(), &[2, 2]);

        let data = c.data_f32();
        assert_eq!(data[0], 22.0);
        assert_eq!(data[1], 28.0);
        assert_eq!(data[2], 49.0);
        assert_eq!(data[3], 64.0);
    }

    #[test]
    fn test_dot() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();

        let dot = a.dot(&b).unwrap();
        assert_eq!(dot.data_f32()[0], 32.0);
    }

    #[test]
    fn test_mv() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let vec = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();

        let result = mat.mv(&vec).unwrap();
        assert_eq!(result.dims(), &[2]);
        assert_eq!(result.data_f32(), vec![5.0, 11.0]);
    }

    #[test]
    fn test_trace() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let trace = mat.trace().unwrap();
        assert_eq!(trace.data_f32()[0], 5.0);
    }
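
    // Sketch tests for `outer` and `diag` with small hand-checked values,
    // assuming `from_slice` and `data_f32` behave as in the tests above.
    #[test]
    fn test_outer() {
        let a = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        let b = Tensor::from_slice(&[3.0f32, 4.0, 5.0], &[3]).unwrap();

        let result = a.outer(&b).unwrap();
        assert_eq!(result.dims(), &[2, 3]);
        // Row i is a[i] * b: [3, 4, 5] and [6, 8, 10].
        assert_eq!(result.data_f32(), vec![3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_diag() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let diag = mat.diag().unwrap();
        // The main diagonal of a 2x3 matrix has min(2, 3) = 2 entries.
        assert_eq!(diag.dims(), &[2]);
        assert_eq!(diag.data_f32(), vec![1.0, 5.0]);
    }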
}