use super::{
    ActivationFunction, BufferHandle, GpuBackendTrait, MemoryManagerTrait, NeuralIntegrationError,
    NeuralOperation, NeuralResult,
};
use std::sync::Arc;

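/// Executes a single neural operation on the GPU, dispatching to the specialized executor
/// for the given `NeuralOperation` variant. Inputs are reinterpreted as `f32` for the GPU
/// transfer, which assumes `T` is an f32-compatible `Pod` type.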
pub fn execute_operation<T>(
    operation: NeuralOperation<T>,
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    match operation {
        NeuralOperation::MatrixMultiply { a_rows, a_cols, b_cols, _phantom } => {
            execute_matrix_multiply(a_rows, a_cols, b_cols, inputs, backend, memory_manager)
        }

        NeuralOperation::VectorAdd { size, _phantom } => {
            execute_vector_add(size, inputs, backend, memory_manager)
        }

        NeuralOperation::ActivationFunction { function, size, _phantom } => {
            execute_activation_function(function, size, inputs, backend, memory_manager)
        }

        NeuralOperation::Convolution { channels, kernel_size, stride, _phantom } => {
            execute_convolution(channels, kernel_size, stride, inputs, backend, memory_manager)
        }

        NeuralOperation::ForwardPropagation { layer_sizes, _phantom } => {
            execute_forward_propagation(&layer_sizes, inputs, backend, memory_manager)
        }

        NeuralOperation::BackwardPropagation { layer_sizes, _phantom } => {
            execute_backward_propagation(&layer_sizes, inputs, backend, memory_manager)
        }

        NeuralOperation::Custom { kernel_source, name, _phantom } => {
            execute_custom_kernel(&kernel_source, &name, inputs, backend, memory_manager)
        }
    }
}

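/// Processes a batch of operations. When a GPU backend is available the operations are
/// executed in chunks of `batch_size`; otherwise each operation falls back to the CPU path.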
pub fn process_batch<T>(
    operations: Vec<NeuralOperation<T>>,
    inputs: Vec<Vec<T>>,
    backend: &Option<Arc<dyn GpuBackendTrait>>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
    batch_size: usize,
) -> NeuralResult<Vec<Vec<T>>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod + num_traits::Float,
{
    if operations.len() != inputs.len() {
        return Err(NeuralIntegrationError::OperationError(
            "Operations and inputs count mismatch".to_string(),
        ));
    }

    let mut results = Vec::with_capacity(operations.len());

    if let Some(backend) = backend {
        // Execute on the GPU in chunks of `batch_size` operations.
        for (ops, ins) in operations.chunks(batch_size).zip(inputs.chunks(batch_size)) {
            let batch_results = execute_batch_gpu(ops, ins, backend, memory_manager)?;
            results.extend(batch_results);
        }
    } else {
        // No GPU backend available: fall back to sequential CPU execution.
        for (operation, input) in operations.into_iter().zip(inputs.into_iter()) {
            let result = super::bridge::execute_cpu_fallback(operation, &input)?;
            results.push(result);
        }
    }

    Ok(results)
}

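/// Multiplies an `a_rows x a_cols` matrix by an `a_cols x b_cols` matrix on the GPU.
/// `inputs` must contain matrix A followed by matrix B in row-major order.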
fn execute_matrix_multiply<T>(
    a_rows: usize,
    a_cols: usize,
    b_cols: usize,
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    // The input slice must hold matrix A (a_rows x a_cols) followed by matrix B (a_cols x b_cols).
    let expected_size = a_rows * a_cols + a_cols * b_cols;
    if inputs.len() < expected_size {
        return Err(NeuralIntegrationError::OperationError(
            format!("Expected {} elements, got {}", expected_size, inputs.len()),
        ));
    }

    // Reinterpret the inputs as f32 for the GPU transfer (assumes T is an f32-compatible Pod type).
    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    // Generate and compile the kernel for these matrix dimensions.
    let kernel_source = generate_matrix_multiply_kernel(a_rows, a_cols, b_cols);
    let kernel = super::bridge::extract_wgsl_from_rust(&kernel_source)?;

    // Execute on the GPU and transfer the result back, converting it to T.
    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

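/// Adds two vectors of length `size` on the GPU. `inputs` must contain both operands
/// back to back (at least `2 * size` elements).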
fn execute_vector_add<T>(
    size: usize,
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    if inputs.len() < size * 2 {
        return Err(NeuralIntegrationError::OperationError(
            format!("Expected {} elements, got {}", size * 2, inputs.len()),
        ));
    }

    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    let kernel_source = generate_vector_add_kernel(size);
    let kernel = super::bridge::extract_wgsl_from_rust(&kernel_source)?;

    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .take(size)
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

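/// Applies the given activation function element-wise to the first `size` inputs on the GPU.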
fn execute_activation_function<T>(
    function: ActivationFunction,
    size: usize,
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    if inputs.len() < size {
        return Err(NeuralIntegrationError::OperationError(
            format!("Expected {} elements, got {}", size, inputs.len()),
        ));
    }

    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().take(size).map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    let kernel_source = generate_activation_kernel(function, size);
    let kernel = super::bridge::extract_wgsl_from_rust(&kernel_source)?;

    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .take(size)
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

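/// Runs a per-channel 2D convolution on the GPU with the given channel count, kernel size,
/// and stride.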
fn execute_convolution<T>(
    channels: usize,
    kernel_size: usize,
    stride: usize,
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    let kernel_source = generate_convolution_kernel(channels, kernel_size, stride);
    let kernel = super::bridge::extract_wgsl_from_rust(&kernel_source)?;

    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

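/// Runs a forward pass through a fully connected network described by `layer_sizes`.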
fn execute_forward_propagation<T>(
    layer_sizes: &[usize],
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    let kernel_source = generate_forward_propagation_kernel(layer_sizes);
    let kernel = super::bridge::extract_wgsl_from_rust(&kernel_source)?;

    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

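/// Runs a backward pass (error backpropagation) for a network described by `layer_sizes`.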
fn execute_backward_propagation<T>(
    layer_sizes: &[usize],
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    let kernel_source = generate_backward_propagation_kernel(layer_sizes);
    let kernel = super::bridge::extract_wgsl_from_rust(&kernel_source)?;

    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

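/// Executes a user-supplied kernel against the inputs. The kernel name is currently unused;
/// only the kernel source is compiled and dispatched.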
fn execute_custom_kernel<T>(
    kernel_source: &str,
    _name: &str,
    inputs: &[T],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<T>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    let input_buffer = memory_manager.transfer_to_gpu(
        &inputs.iter().map(|&x| unsafe { std::mem::transmute_copy(&x) }).collect::<Vec<f32>>(),
    )?;

    let kernel = super::bridge::extract_wgsl_from_rust(kernel_source)?;

    let output_buffer = backend.execute_kernel(&kernel, &[input_buffer])?;

    let result_f32 = memory_manager.transfer_from_gpu(output_buffer)?;
    let result: Vec<T> = result_f32
        .iter()
        .map(|&x| unsafe { std::mem::transmute_copy(&x) })
        .collect();

    Ok(result)
}

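/// Executes a chunk of operations on the GPU sequentially, collecting one result per operation.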
fn execute_batch_gpu<T>(
    operations: &[NeuralOperation<T>],
    inputs: &[Vec<T>],
    backend: &Arc<dyn GpuBackendTrait>,
    memory_manager: &Arc<dyn MemoryManagerTrait>,
) -> NeuralResult<Vec<Vec<T>>>
where
    T: Clone + Send + Sync + 'static + bytemuck::Pod,
{
    let mut results = Vec::with_capacity(operations.len());

    for (operation, input) in operations.iter().zip(inputs.iter()) {
        let result = execute_operation(operation.clone(), input, backend, memory_manager)?;
        results.push(result);
    }

    Ok(results)
}

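/// Generates CUDA-style kernel source for a matrix multiply with fixed dimensions.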
fn generate_matrix_multiply_kernel(a_rows: usize, a_cols: usize, b_cols: usize) -> String {
    format!(r#"
        __global__ void matrix_multiply(float* a, float* b, float* c) {{
            int row = blockIdx.y * blockDim.y + threadIdx.y;
            int col = blockIdx.x * blockDim.x + threadIdx.x;

            if (row < {a_rows} && col < {b_cols}) {{
                float sum = 0.0f;
                for (int k = 0; k < {a_cols}; k++) {{
                    sum += a[row * {a_cols} + k] * b[k * {b_cols} + col];
                }}
                c[row * {b_cols} + col] = sum;
            }}
        }}
    "#)
}

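/// Generates kernel source for element-wise addition of two vectors of length `size`.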
fn generate_vector_add_kernel(size: usize) -> String {
    format!(r#"
        __global__ void vector_add(float* a, float* b, float* c) {{
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            if (i < {size}) {{
                c[i] = a[i] + b[i];
            }}
        }}
    "#)
}

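/// Generates kernel source that applies the selected activation function element-wise.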
fn generate_activation_kernel(function: ActivationFunction, size: usize) -> String {
    let activation_code = match function {
        ActivationFunction::Sigmoid => "1.0f / (1.0f + expf(-x))",
        ActivationFunction::ReLU => "fmaxf(0.0f, x)",
        ActivationFunction::Tanh => "tanhf(x)",
        ActivationFunction::LeakyReLU => "x > 0.0f ? x : 0.01f * x",
        ActivationFunction::Swish => "x / (1.0f + expf(-x))",
        ActivationFunction::GELU => "0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)))",
    };

    format!(r#"
        __global__ void activation_function(float* input, float* output) {{
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            if (i < {size}) {{
                float x = input[i];
                output[i] = {activation_code};
            }}
        }}
    "#)
}

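/// Generates kernel source for a strided 2D convolution over `channels` channels.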
fn generate_convolution_kernel(channels: usize, kernel_size: usize, stride: usize) -> String {
    format!(r#"
        __global__ void convolution(float* input, float* kernel, float* output,
                                    int input_width, int input_height, int output_width, int output_height) {{
            int out_x = blockIdx.x * blockDim.x + threadIdx.x;
            int out_y = blockIdx.y * blockDim.y + threadIdx.y;
            int channel = blockIdx.z;

            if (out_x < output_width && out_y < output_height && channel < {channels}) {{
                float sum = 0.0f;

                for (int ky = 0; ky < {kernel_size}; ky++) {{
                    for (int kx = 0; kx < {kernel_size}; kx++) {{
                        int in_x = out_x * {stride} + kx;
                        int in_y = out_y * {stride} + ky;

                        if (in_x < input_width && in_y < input_height) {{
                            int input_idx = channel * input_width * input_height + in_y * input_width + in_x;
                            int kernel_idx = channel * {kernel_size} * {kernel_size} + ky * {kernel_size} + kx;
                            sum += input[input_idx] * kernel[kernel_idx];
                        }}
                    }}
                }}

                int output_idx = channel * output_width * output_height + out_y * output_width + out_x;
                output[output_idx] = sum;
            }}
        }}
    "#)
}

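/// Generates kernel source for a sigmoid-activated forward pass; one `if (layer == ...)`
/// block is emitted per layer transition.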
fn generate_forward_propagation_kernel(layer_sizes: &[usize]) -> String {
    let num_layers = layer_sizes.len();
    // Emit one block per transition i -> i + 1; each block fires when the kernel is
    // processing layer i + 1, matching the offsets computed below.
    let weights_calculation = (0..num_layers - 1).map(|i| {
        format!(r#"
            // Layer {} to {}
            if (layer == {}) {{
                for (int j = 0; j < {}; j++) {{
                    float sum = 0.0f;
                    for (int k = 0; k < {}; k++) {{
                        sum += activations[prev_layer_offset + k] * weights[weight_offset + j * {} + k];
                    }}
                    activations[current_layer_offset + j] = 1.0f / (1.0f + expf(-sum)); // Sigmoid
                }}
            }}
        "#, i, i + 1, i + 1, layer_sizes[i + 1], layer_sizes[i], layer_sizes[i])
    }).collect::<Vec<_>>().join("\n");

    format!(r#"
        __global__ void forward_propagation(float* inputs, float* weights, float* biases, float* activations) {{
            int neuron = blockIdx.x * blockDim.x + threadIdx.x;
            int layer = blockIdx.y;

            // Copy inputs to first layer
            if (layer == 0 && neuron < {}) {{
                activations[neuron] = inputs[neuron];
                return;
            }}

            // Calculate layer offsets
            int prev_layer_offset = 0;
            int current_layer_offset = 0;
            int weight_offset = 0;

            for (int i = 0; i < layer; i++) {{
                if (i < layer - 1) prev_layer_offset += layer_sizes[i];
                current_layer_offset += layer_sizes[i];
                if (i < layer - 1) weight_offset += layer_sizes[i] * layer_sizes[i + 1];
            }}

            {}
        }}
    "#, layer_sizes[0], weights_calculation)
}

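/// Generates kernel source for a simplified backward pass that computes output-layer and
/// hidden-layer errors using the sigmoid derivative.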
fn generate_backward_propagation_kernel(layer_sizes: &[usize]) -> String {
    format!(r#"
        __global__ void backward_propagation(float* activations, float* weights,
                                              float* errors, float* gradients, float* targets) {{
            int neuron = blockIdx.x * blockDim.x + threadIdx.x;
            int layer = blockIdx.y;

            // Calculate output layer errors
            if (layer == {} - 1) {{
                if (neuron < {}) {{
                    float output = activations[neuron]; // Assuming output layer offset
                    float target = targets[neuron];
                    errors[neuron] = (output - target) * output * (1.0f - output); // Sigmoid derivative
                }}
                return;
            }}

            // Backpropagate errors for hidden layers
            // Implementation depends on network architecture
            // This is a simplified version
            if (layer > 0 && neuron < layer_sizes[layer]) {{
                float error_sum = 0.0f;
                for (int next_neuron = 0; next_neuron < layer_sizes[layer + 1]; next_neuron++) {{
                    error_sum += errors[next_neuron] * weights[neuron * layer_sizes[layer + 1] + next_neuron];
                }}
                float activation = activations[neuron];
                errors[neuron] = error_sum * activation * (1.0f - activation);
            }}
        }}
    "#, layer_sizes.len(), layer_sizes.last().unwrap_or(&0))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_generation() {
        let kernel = generate_matrix_multiply_kernel(4, 4, 4);
        assert!(kernel.contains("matrix_multiply"));
        assert!(kernel.contains("blockIdx"));
        assert!(kernel.contains("threadIdx"));
    }

    #[test]
    fn test_activation_kernel_generation() {
        let kernel = generate_activation_kernel(ActivationFunction::ReLU, 128);
        assert!(kernel.contains("fmaxf"));
        assert!(kernel.contains("activation_function"));
    }

    #[test]
    fn test_vector_add_kernel() {
        let kernel = generate_vector_add_kernel(256);
        assert!(kernel.contains("vector_add"));
        assert!(kernel.contains("256"));
    }
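
    // Additional coverage for the remaining generators. These tests only assert that the
    // generated source mentions the expected entry points; they do not execute the kernels.
    #[test]
    fn test_convolution_kernel_generation() {
        let kernel = generate_convolution_kernel(3, 3, 1);
        assert!(kernel.contains("convolution"));
        assert!(kernel.contains("blockIdx"));
    }

    #[test]
    fn test_propagation_kernel_generation() {
        let layer_sizes = [4, 8, 2];
        let forward = generate_forward_propagation_kernel(&layer_sizes);
        assert!(forward.contains("forward_propagation"));
        let backward = generate_backward_propagation_kernel(&layer_sizes);
        assert!(backward.contains("backward_propagation"));
    }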
}