use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
use rayon::prelude::*;

impl Tensor {
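    /// Matrix product of `self` and `other`. Both tensors must be at least
    /// 2D and their inner dimensions must agree; 2D inputs take a fast path,
    /// while higher-rank inputs go through batched matmul with broadcasting
    /// over the leading batch dimensions.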
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        if a_dims.len() < 2 || b_dims.len() < 2 {
            return Err(GhostError::InvalidOperation(
                "matmul requires at least 2D tensors".to_string(),
            ));
        }

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let k2 = b_dims[b_dims.len() - 2];
        let n = b_dims[b_dims.len() - 1];

        if k != k2 {
            // Report the trailing shape `other` would need for the inner
            // dimensions to line up.
            return Err(GhostError::ShapeMismatch {
                expected: vec![k, n],
                got: vec![k2, n],
            });
        }

        if a_dims.len() == 2 && b_dims.len() == 2 {
            return self.matmul_2d(other, m, k, n);
        }

        self.batched_matmul(other)
    }

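    /// 2D fast path: prefers BLAS when the `blas` feature is enabled and all
    /// dimensions reach the threshold, otherwise falls back to the blocked
    /// kernel.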
    fn matmul_2d(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        #[cfg(feature = "blas")]
        {
            const BLAS_THRESHOLD: usize = 64;
            if m >= BLAS_THRESHOLD && n >= BLAS_THRESHOLD && k >= BLAS_THRESHOLD {
                return self.matmul_blas(other, m, k, n);
            }
        }

        self.matmul_blocked(other, m, k, n)
    }

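    /// Row-major SGEMM: computes `c = 1.0 * a * b + 0.0 * c`, with leading
    /// dimensions k, n, and n for `a`, `b`, and `c` respectively.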
    #[cfg(feature = "blas")]
    fn matmul_blas(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        use cblas::*;

        let a = self.data_f32();
        let b = other.data_f32();
        let mut c = vec![0.0f32; m * n];

        unsafe {
            sgemm(
                Layout::RowMajor,
                Transpose::None,
                Transpose::None,
                m as i32,
                n as i32,
                k as i32,
                1.0,
                &a,
                k as i32, // lda: row stride of a in row-major layout
                &b,
                n as i32, // ldb: row stride of b
                0.0,
                &mut c,
                n as i32, // ldc: row stride of c
            );
        }

        Tensor::from_slice(&c, &[m, n])
    }

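    /// Cache-blocked fallback kernel: parallelizes over output rows and tiles
    /// the j and k loops so blocks of `b` stay hot in cache.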
    fn matmul_blocked(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let mut c = vec![0.0f32; m * n];

        const BLOCK_SIZE: usize = 64;

        c.par_chunks_mut(n).enumerate().for_each(|(i, row)| {
            for jb in (0..n).step_by(BLOCK_SIZE) {
                let j_end = (jb + BLOCK_SIZE).min(n);

                for kb in (0..k).step_by(BLOCK_SIZE) {
                    let k_end = (kb + BLOCK_SIZE).min(k);

                    for kk in kb..k_end {
                        let a_ik = a[i * k + kk];
                        for j in jb..j_end {
                            row[j] += a_ik * b[kk * n + j];
                        }
                    }
                }
            }
        });

        Tensor::from_slice(&c, &[m, n])
    }

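    /// Batched matmul over the leading dimensions, broadcasting batch shapes
    /// NumPy-style; each output batch is an independent `m x n` product.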
    fn batched_matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let n = b_dims[b_dims.len() - 1];

        let a_batch: Vec<usize> = a_dims[..a_dims.len() - 2].to_vec();
        let b_batch: Vec<usize> = b_dims[..b_dims.len() - 2].to_vec();

        let batch_dims = broadcast_batch_dims(&a_batch, &b_batch)?;
        let batch_size: usize = batch_dims.iter().product();

        let a = self.data_f32();
        let b = other.data_f32();

        let a_batch_stride = m * k;
        let b_batch_stride = k * n;
        let c_batch_stride = m * n;

        // Map a flat index in the broadcast batch shape to a flat index in
        // `dims` (right-aligned), treating size-1 or missing dimensions as
        // broadcast. A plain `batch_idx % product` is only correct when the
        // batch shapes match exactly or one side has a single batch.
        fn broadcast_flat_index(mut idx: usize, dims: &[usize], batch_dims: &[usize]) -> usize {
            let offset = batch_dims.len() - dims.len();
            let mut out = 0;
            let mut stride = 1;
            for i in (0..batch_dims.len()).rev() {
                let coord = idx % batch_dims[i];
                idx /= batch_dims[i];
                if i >= offset {
                    let d = dims[i - offset];
                    out += (coord % d) * stride;
                    stride *= d;
                }
            }
            out
        }

        let mut result = vec![0.0f32; batch_size * m * n];

        result
            .par_chunks_mut(c_batch_stride)
            .enumerate()
            .for_each(|(batch_idx, c_batch)| {
                let a_idx = broadcast_flat_index(batch_idx, &a_batch, &batch_dims);
                let b_idx = broadcast_flat_index(batch_idx, &b_batch, &batch_dims);

                let a_start = a_idx * a_batch_stride;
                let b_start = b_idx * b_batch_stride;

                for i in 0..m {
                    for j in 0..n {
                        let mut sum = 0.0f32;
                        for kk in 0..k {
                            sum += a[a_start + i * k + kk] * b[b_start + kk * n + j];
                        }
                        c_batch[i * n + j] = sum;
                    }
                }
            });

        let mut out_shape = batch_dims;
        out_shape.push(m);
        out_shape.push(n);

        Tensor::from_slice(&result, &out_shape)
    }

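    /// Dot product of two equal-length 1D tensors, returned as a scalar
    /// (0-dimensional) tensor.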
    pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "dot requires 1D tensors".to_string(),
            ));
        }

        if self.numel() != other.numel() {
            return Err(GhostError::ShapeMismatch {
                expected: self.dims().to_vec(),
                got: other.dims().to_vec(),
            });
        }

        let a = self.data_f32();
        let b = other.data_f32();

        let dot: f32 = a.par_iter()
            .zip(b.par_iter())
            .map(|(&x, &y)| x * y)
            .sum();

        Tensor::from_slice(&[dot], &[])
    }

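    /// Outer product of two 1D tensors: for inputs of lengths `m` and `n`,
    /// returns an `m x n` matrix with entries `a[i] * b[j]`.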
    pub fn outer(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "outer requires 1D tensors".to_string(),
            ));
        }

        let a = self.data_f32();
        let b = other.data_f32();
        let m = a.len();
        let n = b.len();

        let result: Vec<f32> = (0..m)
            .into_par_iter()
            .flat_map(|i| b.iter().map(|&bj| a[i] * bj).collect::<Vec<_>>())
            .collect();

        Tensor::from_slice(&result, &[m, n])
    }

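    /// Matrix-vector product: an `m x n` matrix times a length-`n` vector
    /// yields a length-`m` vector.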
    pub fn mv(&self, vec: &Tensor) -> Result<Tensor> {
        if self.ndim() != 2 || vec.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "mv requires 2D matrix and 1D vector".to_string(),
            ));
        }

        let m = self.dims()[0];
        let n = self.dims()[1];

        if vec.numel() != n {
            return Err(GhostError::ShapeMismatch {
                expected: vec![n],
                got: vec.dims().to_vec(),
            });
        }

        let mat = self.data_f32();
        let v = vec.data_f32();

        let result: Vec<f32> = (0..m)
            .into_par_iter()
            .map(|i| (0..n).map(|j| mat[i * n + j] * v[j]).sum())
            .collect();

        Tensor::from_slice(&result, &[m])
    }

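    /// Batched matrix multiply for 3D tensors; a stricter entry point that
    /// delegates to `matmul`.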
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 3 || other.ndim() != 3 {
            return Err(GhostError::InvalidOperation(
                "bmm requires 3D tensors".to_string(),
            ));
        }

        self.matmul(other)
    }

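    /// Sum of the main-diagonal entries of a 2D tensor, returned as a scalar.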
    pub fn trace(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "trace requires 2D tensor".to_string(),
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let trace: f32 = (0..n).map(|i| data[i * cols + i]).sum();

        Tensor::from_slice(&[trace], &[])
    }

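    /// Main diagonal of a 2D tensor as a 1D tensor of length `min(rows, cols)`.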
    pub fn diag(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "diag requires 2D tensor".to_string(),
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let diag: Vec<f32> = (0..n).map(|i| data[i * cols + i]).collect();

        Tensor::from_slice(&diag, &[n])
    }
}

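/// Broadcasts two batch shapes NumPy-style: dimensions are compared
/// right-aligned, and a dimension of 1 (or a missing dimension) stretches to
/// match the other shape.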
fn broadcast_batch_dims(a: &[usize], b: &[usize]) -> Result<Vec<usize>> {
    let max_len = a.len().max(b.len());
    let mut result = Vec::with_capacity(max_len);

    for i in 0..max_len {
        let a_dim = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let b_dim = if i < b.len() { b[b.len() - 1 - i] } else { 1 };

        if a_dim == b_dim {
            result.push(a_dim);
        } else if a_dim == 1 {
            result.push(b_dim);
        } else if b_dim == 1 {
            result.push(a_dim);
        } else {
            return Err(GhostError::BroadcastError {
                a: a.to_vec(),
                b: b.to_vec(),
            });
        }
    }

    result.reverse();
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_2d() {
        let a = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3],
        ).unwrap();
        let b = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2],
        ).unwrap();

        let c = a.matmul(&b).unwrap();
        assert_eq!(c.dims(), &[2, 2]);

        let data = c.data_f32();
        assert_eq!(data[0], 22.0);
        assert_eq!(data[1], 28.0);
        assert_eq!(data[2], 49.0);
        assert_eq!(data[3], 64.0);
    }

    #[test]
    fn test_dot() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();

        let dot = a.dot(&b).unwrap();
        assert_eq!(dot.data_f32()[0], 32.0);
    }

    #[test]
    fn test_mv() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let vec = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();

        let result = mat.mv(&vec).unwrap();
        assert_eq!(result.dims(), &[2]);
        assert_eq!(result.data_f32(), vec![5.0, 11.0]);
    }

    #[test]
    fn test_trace() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let trace = mat.trace().unwrap();
        assert_eq!(trace.data_f32()[0], 5.0);
    }
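
    // Coverage sketches for the remaining public ops; these assume the same
    // Tensor::from_slice / data_f32 behavior exercised by the tests above.
    #[test]
    fn test_outer() {
        let a = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        let b = Tensor::from_slice(&[3.0f32, 4.0, 5.0], &[3]).unwrap();

        let result = a.outer(&b).unwrap();
        assert_eq!(result.dims(), &[2, 3]);
        assert_eq!(result.data_f32(), vec![3.0, 4.0, 5.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_diag() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let diag = mat.diag().unwrap();
        assert_eq!(diag.dims(), &[2]);
        assert_eq!(diag.data_f32(), vec![1.0, 4.0]);
    }

    #[test]
    fn test_batched_matmul_identity() {
        // Each batch of `b` is the 2x2 identity, so the product equals `a`.
        let a = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[2, 2, 2],
        ).unwrap();
        let b = Tensor::from_slice(
            &[1.0f32, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
            &[2, 2, 2],
        ).unwrap();

        let c = a.bmm(&b).unwrap();
        assert_eq!(c.dims(), &[2, 2, 2]);
        assert_eq!(c.data_f32(), a.data_f32());
    }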
}