1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/// CPU reference implementation of GEMM for `f32`.
///
/// Computes `Y = alpha * op(A) * op(B) + beta * C`, where `op(X)` is `X` or `X^T`
/// depending on the `trans_a` / `trans_b` flags.
///
/// Every tensor is passed as raw bytes holding row-major `f32` values in native
/// byte order. The byte slices need not be 4-byte aligned: values are decoded one
/// element at a time rather than by reinterpreting the pointer as `*const f32`,
/// which would be undefined behaviour for unaligned data. Trailing bytes that do
/// not form a whole `f32` are ignored.
///
/// # Shapes
/// - `A` is stored as `(M, K)`, or as `(K, M)` when `trans_a` is set.
/// - `B` is stored as `(K, N)`, or as `(N, K)` when `trans_b` is set.
/// - `Y` must be `(M, N)`; exactly its first `M * N` values are written.
/// - `C`, if present, is broadcast to `(M, N)`. Only its bytes are available, so
///   its layout is inferred from its element count, tested in this order:
///   `1` (scalar), `N` (row vector), `M * N` (full matrix), `M` (column vector
///   `(M, 1)`). Consequently, when `M == N` an `M`-element `C` is read as a row
///   vector. Without `C`, `beta` has no effect.
///
/// # Panics
/// - if `A`, `B` or `Y` is not 2-D, if the `K` dimensions of `op(A)` and `op(B)`
///   differ, or if `y_dims` is not `(M, N)`;
/// - if a buffer holds fewer `f32` values than its shape requires;
/// - if the element count of `C` matches none of the layouts listed above.
pub fn f32_f32_f32_f32_cpu(
    a_dims: Vec<usize>,
    b_dims: Vec<usize>,
    y_dims: Vec<usize>,
    a_bytes: &[u8],
    b_bytes: &[u8],
    c_bytes: Option<&[u8]>,
    y_bytes: &mut [u8],
    alpha: f32,
    beta: f32,
    trans_a: bool,
    trans_b: bool,
) {
    /// Decodes native-endian `f32` values without any alignment requirement.
    fn decode(bytes: &[u8]) -> Vec<f32> {
        bytes
            .chunks_exact(4)
            .map(|w| f32::from_ne_bytes([w[0], w[1], w[2], w[3]]))
            .collect()
    }

    assert_eq!(a_dims.len(), 2, "GEMM: A must be 2D");
    assert_eq!(b_dims.len(), 2, "GEMM: B must be 2D");
    assert_eq!(y_dims.len(), 2, "GEMM: Y must be 2D");

    // op(A) is (M, K) and op(B) is (K, N) regardless of how they are stored.
    let (m, k_a) = if trans_a { (a_dims[1], a_dims[0]) } else { (a_dims[0], a_dims[1]) };
    let (k_b, n) = if trans_b { (b_dims[1], b_dims[0]) } else { (b_dims[0], b_dims[1]) };
    assert_eq!(k_a, k_b, "GEMM: K dimensions must match");
    let k = k_a;
    assert_eq!(y_dims[0], m, "GEMM: Y row dimension must equal M");
    assert_eq!(y_dims[1], n, "GEMM: Y col dimension must equal N");

    let a = decode(a_bytes);
    let b = decode(b_bytes);
    assert!(
        a.len() >= m * k,
        "GEMM: A holds {} values, expected at least {}",
        a.len(),
        m * k
    );
    assert!(
        b.len() >= k * n,
        "GEMM: B holds {} values, expected at least {}",
        b.len(),
        k * n
    );
    let y_len = y_bytes.len() / 4;
    assert!(
        y_len >= m * n,
        "GEMM: Y holds {} values, expected at least {}",
        y_len,
        m * n
    );

    // (row, col) element strides of op(A)[i][p] and op(B)[p][j] inside the stored
    // row-major buffers; transposing just swaps which stride walks which index.
    let (a_rs, a_cs) = if trans_a { (1, m) } else { (k, 1) };
    let (b_rs, b_cs) = if trans_b { (1, k) } else { (n, 1) };

    // Resolve how C broadcasts onto (M, N) once: (values, row stride, col stride).
    let bias = c_bytes.map(|bytes| {
        let vals = decode(bytes);
        let len = vals.len();
        let (rs, cs) = if len == 1 {
            (0, 0) // scalar
        } else if len == n {
            (0, 1) // row vector, repeated for every row
        } else if len == m * n {
            (n, 1) // full (M, N) matrix
        } else if len == m {
            (1, 0) // (M, 1) column vector, repeated for every column
        } else if m == 0 || n == 0 {
            (0, 0) // Y is empty, so C is never read
        } else {
            panic!("GEMM: C tensor shape not broadcastable to ({}, {})", m, n);
        };
        (vals, rs, cs)
    });

    for i in 0..m {
        for j in 0..n {
            let base = match bias {
                Some((ref vals, rs, cs)) => beta * vals[i * rs + j * cs],
                None => 0.0,
            };
            // Sum the products in increasing `p` order, then add the scaled dot
            // product to `base` last: the same evaluation order as a naive triple
            // loop, so results are reproducible bit-for-bit.
            let dot = (0..k).fold(0.0f32, |acc, p| {
                acc + a[i * a_rs + p * a_cs] * b[p * b_rs + j * b_cs]
            });
            let value = base + alpha * dot;
            let at = (i * n + j) * 4;
            y_bytes[at..at + 4].copy_from_slice(&value.to_ne_bytes());
        }
    }
}