ghostflow_core/ops/matmul.rs

//! Matrix multiplication and linear algebra operations

use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
use rayon::prelude::*;

impl Tensor {
    /// Matrix multiplication
    /// Supports:
    /// - 2D x 2D: standard matmul
    /// - Batched: broadcast batch dimensions
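    ///
    /// Minimal usage sketch (marked `ignore`; the import path below assumes
    /// the crate exposes `Tensor` at `ghostflow_core::tensor::Tensor`):
    ///
    /// ```ignore
    /// use ghostflow_core::tensor::Tensor;
    ///
    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// let eye = Tensor::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
    /// let c = a.matmul(&eye).unwrap();
    /// assert_eq!(c.data_f32(), vec![1.0, 2.0, 3.0, 4.0]); // A * I == A
    /// ```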
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        if a_dims.len() < 2 || b_dims.len() < 2 {
            return Err(GhostError::InvalidOperation(
                "matmul requires at least 2D tensors".to_string()
            ));
        }

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let k2 = b_dims[b_dims.len() - 2];
        let n = b_dims[b_dims.len() - 1];

        if k != k2 {
            return Err(GhostError::ShapeMismatch {
                expected: vec![m, k],
                got: vec![k2, n],
            });
        }

        // Fast path: plain 2D x 2D matmul
        if a_dims.len() == 2 && b_dims.len() == 2 {
            return self.matmul_2d(other, m, k, n);
        }

        // Batched case
        self.batched_matmul(other)
    }

    /// 2D matrix multiplication (optimized)
    fn matmul_2d(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        // Use BLAS if available and the matrices are large enough
        #[cfg(feature = "blas")]
        {
            const BLAS_THRESHOLD: usize = 64;
            if m >= BLAS_THRESHOLD && n >= BLAS_THRESHOLD && k >= BLAS_THRESHOLD {
                return self.matmul_blas(other, m, k, n);
            }
        }

        // Fallback to optimized blocked implementation
        self.matmul_blocked(other, m, k, n)
    }

    /// BLAS-accelerated matrix multiplication (10-50x faster for large matrices)
    #[cfg(feature = "blas")]
    fn matmul_blas(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        use cblas::*;

        let a = self.data_f32();
        let b = other.data_f32();
        let mut c = vec![0.0f32; m * n];

        unsafe {
            sgemm(
                Layout::RowMajor,
                Transpose::None,
                Transpose::None,
                m as i32,
                n as i32,
                k as i32,
                1.0,           // alpha
                &a,
                k as i32,      // lda
                &b,
                n as i32,      // ldb
                0.0,           // beta
                &mut c,
                n as i32,      // ldc
            );
        }

        Tensor::from_slice(&c, &[m, n])
    }

    /// Blocked/tiled matrix multiplication (cache-optimized fallback)
    fn matmul_blocked(&self, other: &Tensor, m: usize, k: usize, n: usize) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        // Use blocked/tiled multiplication for cache efficiency
        let mut c = vec![0.0f32; m * n];

        const BLOCK_SIZE: usize = 64;

        // Parallel over output rows
        c.par_chunks_mut(n).enumerate().for_each(|(i, row)| {
            for jb in (0..n).step_by(BLOCK_SIZE) {
                let j_end = (jb + BLOCK_SIZE).min(n);

                for kb in (0..k).step_by(BLOCK_SIZE) {
                    let k_end = (kb + BLOCK_SIZE).min(k);

                    for kk in kb..k_end {
                        let a_ik = a[i * k + kk];
                        for j in jb..j_end {
                            row[j] += a_ik * b[kk * n + j];
                        }
                    }
                }
            }
        });

        Tensor::from_slice(&c, &[m, n])
    }

    /// Batched matrix multiplication
    fn batched_matmul(&self, other: &Tensor) -> Result<Tensor> {
        let a_dims = self.dims();
        let b_dims = other.dims();

        let m = a_dims[a_dims.len() - 2];
        let k = a_dims[a_dims.len() - 1];
        let n = b_dims[b_dims.len() - 1];

        // Compute batch dimensions
        let a_batch: Vec<usize> = a_dims[..a_dims.len() - 2].to_vec();
        let b_batch: Vec<usize> = b_dims[..b_dims.len() - 2].to_vec();

        // Broadcast batch dimensions
        let batch_dims = broadcast_batch_dims(&a_batch, &b_batch)?;
        let batch_size: usize = batch_dims.iter().product();

        let a = self.data_f32();
        let b = other.data_f32();

        let a_batch_stride = m * k;
        let b_batch_stride = k * n;
        let c_batch_stride = m * n;

        // Map a flat index over the broadcast batch shape to the flat batch
        // index of an operand whose (right-aligned) batch dims may be 1 or
        // missing. A plain modulo is not enough once a size-1 dim sits in the
        // middle of the batch shape, so unravel and re-ravel per dimension.
        let operand_batch_index = |mut idx: usize, dims: &[usize]| -> usize {
            let offset = batch_dims.len() - dims.len();
            let mut flat = 0usize;
            let mut stride = 1usize;
            for d in (0..batch_dims.len()).rev() {
                let coord = idx % batch_dims[d];
                idx /= batch_dims[d];
                if d >= offset {
                    let dim = dims[d - offset];
                    flat += (coord % dim) * stride;
                    stride *= dim;
                }
            }
            flat
        };

        let mut result = vec![0.0f32; batch_size * m * n];

        result.par_chunks_mut(c_batch_stride).enumerate().for_each(|(batch_idx, c_batch)| {
            let a_start = operand_batch_index(batch_idx, &a_batch) * a_batch_stride;
            let b_start = operand_batch_index(batch_idx, &b_batch) * b_batch_stride;

            for i in 0..m {
                for j in 0..n {
                    let mut sum = 0.0f32;
                    for kk in 0..k {
                        sum += a[a_start + i * k + kk] * b[b_start + kk * n + j];
                    }
                    c_batch[i * n + j] = sum;
                }
            }
        });

        let mut out_shape = batch_dims;
        out_shape.push(m);
        out_shape.push(n);

        Tensor::from_slice(&result, &out_shape)
    }

    /// Vector dot product
    pub fn dot(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "dot requires 1D tensors".to_string()
            ));
        }

        if self.numel() != other.numel() {
            return Err(GhostError::ShapeMismatch {
                expected: self.dims().to_vec(),
                got: other.dims().to_vec(),
            });
        }

        let a = self.data_f32();
        let b = other.data_f32();

        let dot: f32 = a.par_iter()
            .zip(b.par_iter())
            .map(|(&x, &y)| x * y)
            .sum();

        Tensor::from_slice(&[dot], &[])
    }

    /// Outer product of two vectors
    pub fn outer(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 1 || other.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "outer requires 1D tensors".to_string()
            ));
        }

        let a = self.data_f32();
        let b = other.data_f32();
        let m = a.len();
        let n = b.len();

        let result: Vec<f32> = (0..m)
            .into_par_iter()
            .flat_map(|i| {
                b.iter().map(|&bj| a[i] * bj).collect::<Vec<_>>()
            })
            .collect();

        Tensor::from_slice(&result, &[m, n])
    }

    /// Matrix-vector multiplication
    pub fn mv(&self, vec: &Tensor) -> Result<Tensor> {
        if self.ndim() != 2 || vec.ndim() != 1 {
            return Err(GhostError::InvalidOperation(
                "mv requires 2D matrix and 1D vector".to_string()
            ));
        }

        let m = self.dims()[0];
        let n = self.dims()[1];

        if vec.numel() != n {
            return Err(GhostError::ShapeMismatch {
                expected: vec![n],
                got: vec.dims().to_vec(),
            });
        }

        let mat = self.data_f32();
        let v = vec.data_f32();

        let result: Vec<f32> = (0..m)
            .into_par_iter()
            .map(|i| {
                (0..n).map(|j| mat[i * n + j] * v[j]).sum()
            })
            .collect();

        Tensor::from_slice(&result, &[m])
    }

    /// Batch matrix-matrix multiplication (bmm)
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor> {
        if self.ndim() != 3 || other.ndim() != 3 {
            return Err(GhostError::InvalidOperation(
                "bmm requires 3D tensors".to_string()
            ));
        }

        self.matmul(other)
    }

    /// Compute trace of a matrix
    pub fn trace(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "trace requires 2D tensor".to_string()
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let trace: f32 = (0..n).map(|i| data[i * cols + i]).sum();

        Tensor::from_slice(&[trace], &[])
    }

    /// Compute diagonal of a matrix
    pub fn diag(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "diag requires 2D tensor".to_string()
            ));
        }

        let dims = self.dims();
        let n = dims[0].min(dims[1]);
        let data = self.data_f32();
        let cols = dims[1];

        let diag: Vec<f32> = (0..n).map(|i| data[i * cols + i]).collect();

        Tensor::from_slice(&diag, &[n])
    }
}

/// Broadcast batch dimensions
fn broadcast_batch_dims(a: &[usize], b: &[usize]) -> Result<Vec<usize>> {
    let max_len = a.len().max(b.len());
    let mut result = Vec::with_capacity(max_len);

    for i in 0..max_len {
        let a_dim = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let b_dim = if i < b.len() { b[b.len() - 1 - i] } else { 1 };

        if a_dim == b_dim {
            result.push(a_dim);
        } else if a_dim == 1 {
            result.push(b_dim);
        } else if b_dim == 1 {
            result.push(a_dim);
        } else {
            return Err(GhostError::BroadcastError {
                a: a.to_vec(),
                b: b.to_vec(),
            });
        }
    }

    result.reverse();
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_2d() {
        // [2, 3] x [3, 2] = [2, 2]
        let a = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3]
        ).unwrap();
        let b = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2]
        ).unwrap();

        let c = a.matmul(&b).unwrap();
        assert_eq!(c.dims(), &[2, 2]);

        // Expected: [[22, 28], [49, 64]]
        let data = c.data_f32();
        assert_eq!(data[0], 22.0);
        assert_eq!(data[1], 28.0);
        assert_eq!(data[2], 49.0);
        assert_eq!(data[3], 64.0);
    }
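
    // Added sketch test (not in the original file): checks matmul against a
    // naive triple-loop reference on a rectangular case, exercising the
    // blocked kernel with arbitrary values.
    #[test]
    fn test_matmul_matches_naive() {
        let (m, k, n) = (5, 7, 4);
        let a_data: Vec<f32> = (0..m * k).map(|i| i as f32 * 0.5 - 3.0).collect();
        let b_data: Vec<f32> = (0..k * n).map(|i| i as f32 * 0.25 + 1.0).collect();
        let a = Tensor::from_slice(&a_data, &[m, k]).unwrap();
        let b = Tensor::from_slice(&b_data, &[k, n]).unwrap();

        let c = a.matmul(&b).unwrap();
        let c_data = c.data_f32();

        for i in 0..m {
            for j in 0..n {
                let mut expected = 0.0f32;
                for kk in 0..k {
                    expected += a_data[i * k + kk] * b_data[kk * n + j];
                }
                assert!((c_data[i * n + j] - expected).abs() < 1e-4);
            }
        }
    }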

    #[test]
    fn test_dot() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();

        let dot = a.dot(&b).unwrap();
        assert_eq!(dot.data_f32()[0], 32.0); // 1*4 + 2*5 + 3*6
    }
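
    // Added sketch test for `outer` (not in the original file): outer([1,2,3],
    // [4,5]) should be a 3x2 matrix with entries a[i] * b[j].
    #[test]
    fn test_outer() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0], &[2]).unwrap();

        let result = a.outer(&b).unwrap();
        assert_eq!(result.dims(), &[3, 2]);
        assert_eq!(result.data_f32(), vec![4.0, 5.0, 8.0, 10.0, 12.0, 15.0]);
    }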

    #[test]
    fn test_mv() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let vec = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();

        let result = mat.mv(&vec).unwrap();
        assert_eq!(result.dims(), &[2]);
        assert_eq!(result.data_f32(), vec![5.0, 11.0]); // [1*1+2*2, 3*1+4*2]
    }
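
    // Added sketch test for the batched path (not in the original file): a
    // [2, 2, 3] batch times a single [3, 2] matrix should broadcast the batch
    // dimension and produce a [2, 2, 2] result.
    #[test]
    fn test_batched_matmul_broadcast() {
        let a = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0,
              7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
            &[2, 2, 3]
        ).unwrap();
        let b = Tensor::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2]
        ).unwrap();

        let c = a.matmul(&b).unwrap();
        assert_eq!(c.dims(), &[2, 2, 2]);

        let data = c.data_f32();
        // First batch matches the 2D case: [[22, 28], [49, 64]]
        assert_eq!(data[0], 22.0);
        assert_eq!(data[3], 64.0);
        // Second batch: [[7, 8, 9], [10, 11, 12]] x b = [[76, 100], [103, 136]]
        assert_eq!(data[4], 76.0);
        assert_eq!(data[7], 136.0);
    }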

    #[test]
    fn test_trace() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let trace = mat.trace().unwrap();
        assert_eq!(trace.data_f32()[0], 5.0); // 1 + 4
    }
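
    // Added sketch tests (not in the original file) for `diag` on a
    // rectangular matrix and for the batch-broadcasting helper.
    #[test]
    fn test_diag() {
        let mat = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let diag = mat.diag().unwrap();
        assert_eq!(diag.dims(), &[2]);
        assert_eq!(diag.data_f32(), vec![1.0, 5.0]); // elements (0,0) and (1,1)
    }

    #[test]
    fn test_broadcast_batch_dims() {
        assert_eq!(broadcast_batch_dims(&[2, 1], &[3]).unwrap(), vec![2, 3]);
        assert_eq!(broadcast_batch_dims(&[4], &[]).unwrap(), vec![4]);
        assert!(broadcast_batch_dims(&[2], &[3]).is_err());
    }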
}