1#![cfg_attr(not(feature = "cuda"), allow(unused_variables))]
4
5use crate::error::{CudaError, CudaResult};
6use crate::tensor::CudaTensor;
7use crate::stream::CudaStream;
8
9pub mod elementwise {
11 use super::*;
12
13 pub fn add(a: &CudaTensor, b: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
15 #[cfg(feature = "cuda")]
16 {
17 }
22
23 let a_cpu = a.to_tensor()?;
25 let b_cpu = b.to_tensor()?;
26 let result = a_cpu.add(&b_cpu)
27 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
28 CudaTensor::from_tensor(&result, a.device_id())
29 }
30
31 pub fn sub(a: &CudaTensor, b: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
33 let a_cpu = a.to_tensor()?;
34 let b_cpu = b.to_tensor()?;
35 let result = a_cpu.sub(&b_cpu)
36 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
37 CudaTensor::from_tensor(&result, a.device_id())
38 }
39
40 pub fn mul(a: &CudaTensor, b: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
42 let a_cpu = a.to_tensor()?;
43 let b_cpu = b.to_tensor()?;
44 let result = a_cpu.mul(&b_cpu)
45 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
46 CudaTensor::from_tensor(&result, a.device_id())
47 }
48
49 pub fn div(a: &CudaTensor, b: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
51 let a_cpu = a.to_tensor()?;
52 let b_cpu = b.to_tensor()?;
53 let result = a_cpu.div(&b_cpu)
54 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
55 CudaTensor::from_tensor(&result, a.device_id())
56 }
57
58 pub fn relu(x: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
60 let x_cpu = x.to_tensor()?;
61 let result = x_cpu.relu();
62 CudaTensor::from_tensor(&result, x.device_id())
63 }
64
65 pub fn sigmoid(x: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
67 let x_cpu = x.to_tensor()?;
68 let result = x_cpu.sigmoid();
69 CudaTensor::from_tensor(&result, x.device_id())
70 }
71
72 pub fn gelu(x: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
74 let x_cpu = x.to_tensor()?;
75 let result = x_cpu.gelu();
76 CudaTensor::from_tensor(&result, x.device_id())
77 }
78
79 pub fn softmax(x: &CudaTensor, dim: i32, stream: &CudaStream) -> CudaResult<CudaTensor> {
81 let x_cpu = x.to_tensor()?;
82 let result = x_cpu.softmax(dim);
83 CudaTensor::from_tensor(&result, x.device_id())
84 }
85}
86
87pub mod matmul {
89 use super::*;
90
91 pub fn matmul(a: &CudaTensor, b: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
93 let a_cpu = a.to_tensor()?;
96 let b_cpu = b.to_tensor()?;
97 let result = a_cpu.matmul(&b_cpu)
98 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
99 CudaTensor::from_tensor(&result, a.device_id())
100 }
101
102 pub fn bmm(a: &CudaTensor, b: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
104 let a_cpu = a.to_tensor()?;
106 let b_cpu = b.to_tensor()?;
107 let result = a_cpu.matmul(&b_cpu)
108 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
109 CudaTensor::from_tensor(&result, a.device_id())
110 }
111}
112
113pub mod reduction {
115 use super::*;
116
117 pub fn sum(x: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
119 let x_cpu = x.to_tensor()?;
120 let result = x_cpu.sum();
121 CudaTensor::from_tensor(&result, x.device_id())
122 }
123
124 pub fn mean(x: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
126 let x_cpu = x.to_tensor()?;
127 let result = x_cpu.mean();
128 CudaTensor::from_tensor(&result, x.device_id())
129 }
130
131 pub fn max(x: &CudaTensor, stream: &CudaStream) -> CudaResult<CudaTensor> {
133 let x_cpu = x.to_tensor()?;
134 let result = x_cpu.max();
135 CudaTensor::from_tensor(&result, x.device_id())
136 }
137}
138
139pub mod conv {
141 use super::*;
142
143 pub fn conv2d(
145 input: &CudaTensor,
146 weight: &CudaTensor,
147 bias: Option<&CudaTensor>,
148 stride: (usize, usize),
149 padding: (usize, usize),
150 stream: &CudaStream,
151 ) -> CudaResult<CudaTensor> {
152 Err(CudaError::InvalidValue("Conv2d not yet implemented".into()))
155 }
156}
157
158pub mod attention {
160 use super::*;
161
162 pub fn flash_attention(
167 q: &CudaTensor, k: &CudaTensor,
169 v: &CudaTensor,
170 scale: f32,
171 causal: bool,
172 stream: &CudaStream,
173 ) -> CudaResult<CudaTensor> {
174 let q_cpu = q.to_tensor()?;
178 let k_cpu = k.to_tensor()?;
179 let v_cpu = v.to_tensor()?;
180
181 let k_t = k_cpu.transpose(2, 3)
183 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
184
185 let scores = q_cpu.matmul(&k_t)
186 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
187
188 let scaled = scores.mul_scalar(scale);
189
190 let masked = if causal {
192 apply_causal_mask(&scaled)?
193 } else {
194 scaled
195 };
196
197 let attn_weights = masked.softmax(-1);
198
199 let output = attn_weights.matmul(&v_cpu)
200 .map_err(|e| CudaError::InvalidValue(e.to_string()))?;
201
202 CudaTensor::from_tensor(&output, q.device_id())
203 }
204
205 fn apply_causal_mask(scores: &ghostflow_core::Tensor) -> CudaResult<ghostflow_core::Tensor> {
206 let dims = scores.dims();
207 let seq_len = dims[dims.len() - 1];
208 let mut data = scores.data_f32();
209
210 let batch_size: usize = dims[..dims.len()-2].iter().product();
212 let matrix_size = seq_len * seq_len;
213
214 for b in 0..batch_size {
215 for i in 0..seq_len {
216 for j in (i + 1)..seq_len {
217 data[b * matrix_size + i * seq_len + j] = f32::NEG_INFINITY;
218 }
219 }
220 }
221
222 ghostflow_core::Tensor::from_slice(&data, dims)
223 .map_err(|e| CudaError::InvalidValue(e.to_string()))
224 }
225}