ghostflow_core/ops/matmul.rs

//! Matrix multiplication and linear algebra operations

use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
#[cfg(feature = "rayon")]
use rayon::prelude::*;

impl Tensor {
    /// Matrix multiplication
    /// Supports:
    /// - 2D x 2D: standard matmul
    /// - Batched: broadcast batch dimensions
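    ///
    /// A minimal usage sketch (illustrative values, using the same
    /// `Tensor::from_slice` constructor as the tests below):
    /// ```ignore
    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
    /// let b = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])?;
    /// let c = a.matmul(&b)?; // [2, 3] x [3, 2] -> [2, 2]
    /// assert_eq!(c.dims(), &[2, 2]);
    /// ```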
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        if a_dims.len() < 2 || b_dims.len() < 2 {
            return Err(GhostError::InvalidOperation(
                "matmul requires at least 2D tensors".to_string()
            ));
        }

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let k2 = b_dims[b_dims.len() - 2];
        let n = b_dims[b_dims.len() - 1];

        if k != k2 {
            return Err(GhostError::ShapeMismatch {
                expected: vec![m, k],
                got: vec![k2, n],
            });
        }
        // Fast path: both operands are plain 2D matrices
        if a_dims.len() == 2 && b_dims.len() == 2 {
            return self.matmul_2d(other, m, k, n);
        }

        // Batched case
        self.batched_matmul(other)
    }

    /// 2D matrix multiplication (optimized)
    fn matmul_2d(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        // Use BLAS if available and matrix is large enough
        #[cfg(feature = "blas")]
        {
            const BLAS_THRESHOLD: usize = 64;
            if m >= BLAS_THRESHOLD && n >= BLAS_THRESHOLD && k >= BLAS_THRESHOLD {
                return self.matmul_blas(other, m, k, n);
            }
        }

        // Fallback to optimized blocked implementation
        self.matmul_blocked(other, m, k, n)
    }

    /// BLAS-accelerated matrix multiplication (10-50x faster for large matrices)
    #[cfg(feature = "blas")]
    fn matmul_blas(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        use cblas::*;

        let a = self.data_f32();
        let b = other.data_f32();
        let mut c = vec![0.0f32; m * n];

        unsafe {
            sgemm(
                Layout::RowMajor,
                Transpose::None,
                Transpose::None,
                m as i32,
                n as i32,
                k as i32,
                1.0,           // alpha
                &a,
                k as i32,      // lda
                &b,
                n as i32,      // ldb
                0.0,           // beta
                &mut c,
                n as i32,      // ldc
            );
        }

        Tensor::from_slice(&c, &[m, n])
    }

    /// Blocked/tiled matrix multiplication (cache-optimized fallback)
    fn matmul_blocked(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        // Use blocked/tiled multiplication for cache efficiency
        let mut c = vec![0.0f32; m * n];

        const BLOCK_SIZE: usize = 64;

        // Iterate over output rows (parallelized when the `rayon` feature is enabled)
        #[cfg(feature = "rayon")]
        let rows = c.par_chunks_mut(n);
        #[cfg(not(feature = "rayon"))]
        let rows = c.chunks_mut(n);

        rows.enumerate().for_each(|(i, row)| {
            for jb in (0..n).step_by(BLOCK_SIZE) {
                let j_end = (jb + BLOCK_SIZE).min(n);

                for kb in (0..k).step_by(BLOCK_SIZE) {
                    let k_end = (kb + BLOCK_SIZE).min(k);

                    for kk in kb..k_end {
                        let a_ik = a[i * k + kk];
                        for j in jb..j_end {
                            row[j] += a_ik * b[kk * n + j];
                        }
                    }
                }
            }
        });

        Tensor::from_slice(&c, &[m, n])
    }

    /// Batched matrix multiplication
    fn batched_matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let n = b_dims[b_dims.len() - 1];

        // Compute batch dimensions
        let a_batch: Vec<usize> = a_dims[..a_dims.len() - 2].to_vec();
        let b_batch: Vec<usize> = b_dims[..b_dims.len() - 2].to_vec();

        // Broadcast batch dimensions
        let batch_dims = broadcast_batch_dims(&a_batch, &b_batch)?;
        let batch_size: usize = batch_dims.iter().product();

        let a = self.data_f32();
        let b = other.data_f32();

        let a_batch_stride = m * k;
        let b_batch_stride = k * n;
        let c_batch_stride = m * n;

        let mut result = vec![0.0f32; batch_size * m * n];

        result.chunks_mut(c_batch_stride).enumerate().for_each(|(batch_idx, c_batch)| {
            // Map the flat broadcast batch index back to each operand's own batch
            // index: decompose it into per-dimension coordinates, then re-linearize
            // against the operand's batch shape, treating size-1 and missing
            // leading dimensions as broadcast (coordinate 0).
            let mut coords = vec![0usize; batch_dims.len()];
            let mut rem = batch_idx;
            for d in (0..batch_dims.len()).rev() {
                coords[d] = rem % batch_dims[d];
                rem /= batch_dims[d];
            }
            let source_index = |src: &[usize]| -> usize {
                let offset = batch_dims.len() - src.len();
                src.iter().enumerate().fold(0, |idx, (d, &dim)| {
                    idx * dim + if dim == 1 { 0 } else { coords[offset + d] }
                })
            };
            let a_idx = source_index(&a_batch);
            let b_idx = source_index(&b_batch);

            let a_start = a_idx * a_batch_stride;
            let b_start = b_idx * b_batch_stride;

            for i in 0..m {
                for j in 0..n {
                    let mut sum = 0.0f32;
                    for kk in 0..k {
                        sum += a[a_start + i * k + kk] * b[b_start + kk * n + j];
                    }
                    c_batch[i * n + j] = sum;
                }
            }
        });

        let mut out_shape = batch_dims;
        out_shape.push(m);
        out_shape.push(n);

        Tensor::from_slice(&result, &out_shape)
    }

    /// Vector dot product
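    ///
    /// Sketch (illustrative values): `[1, 2, 3] · [4, 5, 6] = 32`.
    /// ```ignore
    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3])?;
    /// let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3])?;
    /// assert_eq!(a.dot(&b)?.data_f32()[0], 32.0);
    /// ```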
    pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "dot requires 1D tensors".to_string()
            ));
        }

        if self.numel() != other.numel() {
            return Err(GhostError::ShapeMismatch {
                expected: self.dims().to_vec(),
                got: other.dims().to_vec(),
            });
        }

        let a = self.data_f32();
        let b = other.data_f32();

        let dot: f32 = a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| x * y)
            .sum();

        Tensor::from_slice(&[dot], &[])
    }

    /// Outer product of two vectors
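    ///
    /// Sketch (illustrative values): the outer product of a length-m and a
    /// length-n vector is an `[m, n]` matrix.
    /// ```ignore
    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3])?;
    /// let b = Tensor::from_slice(&[4.0f32, 5.0], &[2])?;
    /// let c = a.outer(&b)?; // [[4, 5], [8, 10], [12, 15]]
    /// assert_eq!(c.dims(), &[3, 2]);
    /// ```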
    pub fn outer(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "outer requires 1D tensors".to_string()
            ));
        }

        let a = self.data_f32();
        let b = other.data_f32();
        let m = a.len();
        let n = b.len();

        let result: Vec<f32> = (0..m)
            .flat_map(|i| {
                let ai = a[i];
                b.iter().map(move |&bj| ai * bj)
            })
            .collect();

        Tensor::from_slice(&result, &[m, n])
    }

    /// Matrix-vector multiplication
    pub fn mv(&self, vec: &Tensor) -> Result<Tensor> {
        if self.ndim() != 2 || vec.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "mv requires 2D matrix and 1D vector".to_string()
            ));
        }

        let m = self.dims()[0];
        let n = self.dims()[1];

        if vec.numel() != n {
            return Err(GhostError::ShapeMismatch {
                expected: vec![n],
                got: vec.dims().to_vec(),
            });
        }

        let mat = self.data_f32();
        let v = vec.data_f32();

        let result: Vec<f32> = (0..m)
            .map(|i| (0..n).map(|j| mat[i * n + j] * v[j]).sum())
            .collect();

        Tensor::from_slice(&result, &[m])
    }

    /// Batch matrix-matrix multiplication (bmm)
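    ///
    /// Sketch: both operands must be 3D with matching batch sizes, e.g.
    /// `[B, m, k] x [B, k, n] -> [B, m, n]`.
    /// ```ignore
    /// let a = Tensor::from_slice(&[1.0f32; 12], &[2, 2, 3])?;
    /// let b = Tensor::from_slice(&[1.0f32; 18], &[2, 3, 3])?;
    /// assert_eq!(a.bmm(&b)?.dims(), &[2, 2, 3]);
    /// ```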
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 3 || other.ndim() != 3 {
            return Err(GhostError::InvalidOperation(
                "bmm requires 3D tensors".to_string()
            ));
        }

        self.matmul(other)
    }

    /// Compute trace of a matrix
    pub fn trace(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "trace requires 2D tensor".to_string()
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let trace: f32 = (0..n).map(|i| data[i * cols + i]).sum();

        Tensor::from_slice(&[trace], &[])
    }

    /// Compute diagonal of a matrix
    pub fn diag(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "diag requires 2D tensor".to_string()
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let diag: Vec<f32> = (0..n).map(|i| data[i * cols + i]).collect();

        Tensor::from_slice(&diag, &[n])
    }
}

/// Broadcast batch dimensions
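/// (NumPy-style, aligned from the right): equal dims pass through, a dim of 1
/// broadcasts to the other side, and anything else is an error.
/// For example, `[2, 1]` and `[3]` broadcast to `[2, 3]`.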
fn broadcast_batch_dims(a: &[usize], b: &[usize]) -> Result<Vec<usize>> {
    let max_len = a.len().max(b.len());
    let mut result = Vec::with_capacity(max_len);

    for i in 0..max_len {
        let a_dim = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let b_dim = if i < b.len() { b[b.len() - 1 - i] } else { 1 };

        if a_dim == b_dim {
            result.push(a_dim);
        } else if a_dim == 1 {
            result.push(b_dim);
        } else if b_dim == 1 {
            result.push(a_dim);
        } else {
            return Err(GhostError::BroadcastError {
                a: a.to_vec(),
                b: b.to_vec(),
            });
        }
    }

    result.reverse();
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_2d() {
        // [2, 3] x [3, 2] = [2, 2]
        let a = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3]
        ).unwrap();
        let b = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2]
        ).unwrap();

        let c = a.matmul(&b).unwrap();
        assert_eq!(c.dims(), &[2, 2]);

        // Expected: [[22, 28], [49, 64]]
        let data = c.data_f32();
        assert_eq!(data[0], 22.0);
        assert_eq!(data[1], 28.0);
        assert_eq!(data[2], 49.0);
        assert_eq!(data[3], 64.0);
    }
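
    // Sketch of a check for the batched path: a [2, 2, 2] batch multiplied by a
    // broadcast 2D identity matrix should come back unchanged. Assumes the same
    // from_slice/data_f32 API used by the tests above.
    #[test]
    fn test_batched_matmul_broadcast() {
        let a = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[2, 2, 2]
        ).unwrap();
        // 2x2 identity; its (empty) batch dims broadcast across the batch of `a`
        let eye = Tensor::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2]).unwrap();

        let c = a.matmul(&eye).unwrap();
        assert_eq!(c.dims(), &[2, 2, 2]);
        assert_eq!(c.data_f32(), a.data_f32());
    }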

    #[test]
    fn test_dot() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();

        let dot = a.dot(&b).unwrap();
        assert_eq!(dot.data_f32()[0], 32.0); // 1*4 + 2*5 + 3*6
    }

    #[test]
    fn test_mv() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let vec = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();

        let result = mat.mv(&vec).unwrap();
        assert_eq!(result.dims(), &[2]);
        assert_eq!(result.data_f32(), vec![5.0, 11.0]); // [1*1+2*2, 3*1+4*2]
    }

    #[test]
    fn test_trace() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let trace = mat.trace().unwrap();
        assert_eq!(trace.data_f32()[0], 5.0); // 1 + 4
    }
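
    // Further illustrative checks, assuming the same from_slice/data_f32 API
    // used by the tests above.
    #[test]
    fn test_outer() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0], &[2]).unwrap();

        let c = a.outer(&b).unwrap();
        assert_eq!(c.dims(), &[3, 2]);
        assert_eq!(c.data_f32(), vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
    }

    #[test]
    fn test_diag() {
        // Rectangular input: the diagonal has min(rows, cols) elements
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let d = mat.diag().unwrap();
        assert_eq!(d.dims(), &[2]);
        assert_eq!(d.data_f32(), vec![1.0, 5.0]);
    }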
}