//! Unified GPU infrastructure: error types, wgpu contexts, shader caching,
//! buffer pooling, and adaptive CPU/GPU dispatch.

use crate::GpuError;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use wgpu::util::DeviceExt;

/// Errors produced by the unified GPU layer.
#[derive(Error, Debug)]
pub enum UnifiedGpuError {
    #[error("GPU error: {0}")]
    Gpu(#[from] GpuError),

    #[error("Shader compilation failed: {0}")]
    ShaderCompilation(String),

    #[error("Buffer size mismatch: expected {expected}, got {actual}")]
    BufferSizeMismatch { expected: usize, actual: usize },

    #[error("Invalid operation: {0}")]
    InvalidOperation(String),

    #[error("Memory allocation failed: {0}")]
    MemoryAllocation(String),
}

/// Result type used throughout the unified GPU layer.
pub type UnifiedGpuResult<T> = Result<T, UnifiedGpuError>;

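/// Trait for types whose operations can be offloaded to the GPU.
///
/// A minimal sketch of an implementation, assuming a hypothetical `Scalars`
/// newtype around `Vec<f32>` (illustrative only, not part of this crate):
///
/// ```ignore
/// struct Scalars(Vec<f32>);
///
/// impl GpuAccelerated<Scalars> for Scalars {
///     fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer> {
///         // Upload the raw f32 data as a storage buffer.
///         Ok(context.create_buffer_with_data(
///             "scalars",
///             &self.0,
///             wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
///         ))
///     }
///
///     fn from_gpu_buffer(buffer: &wgpu::Buffer, context: &GpuContext) -> UnifiedGpuResult<Scalars> {
///         // In practice, read back asynchronously via `GpuContext::read_buffer`.
///         unimplemented!()
///     }
///
///     fn gpu_operation(
///         &self,
///         operation: &str,
///         context: &GpuContext,
///         params: &GpuOperationParams,
///     ) -> UnifiedGpuResult<Scalars> {
///         Err(UnifiedGpuError::InvalidOperation(format!("unsupported: {}", operation)))
///     }
/// }
/// ```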
pub trait GpuAccelerated<T> {
    /// Uploads this value into a GPU buffer owned by `context`.
    fn to_gpu_buffer(&self, context: &GpuContext) -> UnifiedGpuResult<wgpu::Buffer>;

    /// Reconstructs a value of type `T` from a GPU buffer.
    fn from_gpu_buffer(buffer: &wgpu::Buffer, context: &GpuContext) -> UnifiedGpuResult<T>;

    /// Runs the named GPU operation on this value with the given parameters.
    fn gpu_operation(
        &self,
        operation: &str,
        context: &GpuContext,
        params: &GpuOperationParams,
    ) -> UnifiedGpuResult<T>;
}

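/// Parameters for a single GPU operation.
///
/// A minimal usage sketch (the key name and values are illustrative):
///
/// ```ignore
/// let mut params = GpuOperationParams::default();
/// params.params.insert("scale".to_string(), GpuParam::Float(2.0));
/// params.batch_size = 1024;
/// params.workgroup_size = (64, 1, 1);
/// ```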
#[derive(Debug, Clone)]
pub struct GpuOperationParams {
    /// Named scalar, buffer, and array parameters for the operation.
    pub params: HashMap<String, GpuParam>,
    /// Number of elements processed per dispatch.
    pub batch_size: usize,
    /// Workgroup dimensions used when dispatching the compute shader.
    pub workgroup_size: (u32, u32, u32),
}

/// A single parameter value passed to a GPU operation.
#[derive(Debug, Clone)]
pub enum GpuParam {
    Float(f32),
    Double(f64),
    Integer(i32),
    UnsignedInteger(u32),
    Buffer(String),
    Array(Vec<f32>),
}

impl Default for GpuOperationParams {
    fn default() -> Self {
        Self {
            params: HashMap::new(),
            batch_size: 1,
            workgroup_size: (1, 1, 1),
        }
    }
}

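/// Owning GPU context with its own device, queue, shader cache, and buffer pool.
///
/// A minimal usage sketch (the WGSL source, bind group layout, and data are
/// placeholders, not part of this crate):
///
/// ```ignore
/// let mut ctx = GpuContext::new().await?;
/// let input = ctx.create_buffer_with_data(
///     "input",
///     &data,
///     wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
/// );
/// let pipeline = ctx.get_compute_pipeline("my_kernel", WGSL_SOURCE, &layout)?;
/// // Dispatch with `execute_compute` and read results back with `read_buffer`.
/// ```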
pub struct GpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    shader_cache: HashMap<String, wgpu::ComputePipeline>,
    #[allow(dead_code)]
    buffer_pool: GpuBufferPool,
}

impl GpuContext {
    /// Creates a new GPU context, requesting a high-performance adapter and device.
    pub async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::default();

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::Gpu(GpuError::InitializationError(
                    "No GPU adapter found".to_string(),
                ))
            })?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Amari Unified GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| UnifiedGpuError::Gpu(GpuError::InitializationError(e.to_string())))?;

        Ok(Self {
            device,
            queue,
            shader_cache: HashMap::new(),
            buffer_pool: GpuBufferPool::new(),
        })
    }

    /// Returns the compute pipeline for `shader_key`, compiling and caching it on first use.
    pub fn get_compute_pipeline(
        &mut self,
        shader_key: &str,
        shader_source: &str,
        bind_group_layout: &wgpu::BindGroupLayout,
    ) -> UnifiedGpuResult<&wgpu::ComputePipeline> {
        if !self.shader_cache.contains_key(shader_key) {
            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let compute_pipeline =
                self.device
                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                        label: Some(&format!("{} Pipeline", shader_key)),
                        layout: Some(&pipeline_layout),
                        module: &shader_module,
                        entry_point: "main",
                    });

            self.shader_cache
                .insert(shader_key.to_string(), compute_pipeline);
        }

        Ok(self
            .shader_cache
            .get(shader_key)
            .expect("Pipeline should exist"))
    }

    /// Creates and initializes a buffer from a slice of plain-old-data values.
    pub fn create_buffer_with_data<T: bytemuck::Pod>(
        &self,
        label: &str,
        data: &[T],
        usage: wgpu::BufferUsages,
    ) -> wgpu::Buffer {
        self.device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(data),
                usage,
            })
    }

    /// Creates an uninitialized buffer of `size` bytes.
    pub fn create_buffer(&self, label: &str, size: u64, usage: wgpu::BufferUsages) -> wgpu::Buffer {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size,
            usage,
            mapped_at_creation: false,
        })
    }

    /// Records and submits a single compute dispatch with the given pipeline and bind group.
    pub fn execute_compute(
        &self,
        pipeline: &wgpu::ComputePipeline,
        bind_group: &wgpu::BindGroup,
        workgroup_count: (u32, u32, u32),
    ) {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, bind_group, &[]);
            compute_pass.dispatch_workgroups(
                workgroup_count.0,
                workgroup_count.1,
                workgroup_count.2,
            );
        }

        self.queue.submit([encoder.finish()]);
    }

    /// Copies `size` bytes from `buffer` into a staging buffer and reads them back to the host.
    pub async fn read_buffer<T: bytemuck::Pod + Clone>(
        &self,
        buffer: &wgpu::Buffer,
        size: u64,
    ) -> UnifiedGpuResult<Vec<T>> {
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Copy Encoder"),
            });

        encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, size);
        self.queue.submit([encoder.finish()]);

        // Map the staging buffer and wait for the GPU to finish the copy.
        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).ok();
        });

        self.device.poll(wgpu::Maintain::Wait);

        rx.await
            .map_err(|_| UnifiedGpuError::InvalidOperation("Buffer read timeout".to_string()))?
            .map_err(|e| UnifiedGpuError::InvalidOperation(format!("Buffer map failed: {}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        staging_buffer.unmap();

        Ok(result)
    }
}

/// Simple named buffer pool held by [`GpuContext`].
pub struct GpuBufferPool {
    _pools: HashMap<String, Vec<wgpu::Buffer>>,
}

impl GpuBufferPool {
    pub fn new() -> Self {
        Self {
            _pools: HashMap::new(),
        }
    }
}

impl Default for GpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

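/// Thread-safe GPU context intended to be shared across the whole application.
///
/// A minimal usage sketch (the shader source, sizes, and labels are placeholders):
///
/// ```ignore
/// let ctx = SharedGpuContext::global().await?;
/// let pipeline = ctx.get_compute_pipeline("my_kernel", WGSL_SOURCE, "main")?;
/// let buffer = ctx.get_buffer(4096, wgpu::BufferUsages::STORAGE, Some("scratch"));
/// println!("running on {}", ctx.adapter_info().name);
/// ```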
#[derive(Clone)]
pub struct SharedGpuContext {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    adapter_info: wgpu::AdapterInfo,
    buffer_pool: Arc<std::sync::Mutex<EnhancedGpuBufferPool>>,
    shader_cache: Arc<std::sync::Mutex<HashMap<String, Arc<wgpu::ComputePipeline>>>>,
    creation_time: Instant,
}

impl SharedGpuContext {
    /// Creates a context and leaks it to obtain a `'static` reference.
    ///
    /// Note: each call builds and leaks a fresh context, so callers should
    /// cache the returned reference rather than calling this repeatedly.
    pub async fn global() -> UnifiedGpuResult<&'static Self> {
        let context = Self::new().await?;
        Ok(Box::leak(Box::new(context)))
    }

    /// Requests a high-performance adapter and a device with timestamp-query support.
    async fn new() -> UnifiedGpuResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            dx12_shader_compiler: wgpu::Dx12Compiler::default(),
            gles_minor_version: wgpu::Gles3MinorVersion::Automatic,
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| {
                UnifiedGpuError::InvalidOperation("No suitable GPU adapter found".into())
            })?;

        let adapter_info = adapter.get_info();

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Shared Amari GPU Device"),
                    required_features: wgpu::Features::TIMESTAMP_QUERY,
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| {
                UnifiedGpuError::InvalidOperation(format!("Device request failed: {:?}", e))
            })?;

        Ok(Self {
            device: Arc::new(device),
            queue: Arc::new(queue),
            adapter_info,
            buffer_pool: Arc::new(std::sync::Mutex::new(EnhancedGpuBufferPool::new())),
            shader_cache: Arc::new(std::sync::Mutex::new(HashMap::new())),
            creation_time: Instant::now(),
        })
    }

    /// Returns the shared wgpu device.
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    /// Returns the shared command queue.
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    /// Returns information about the selected adapter.
    pub fn adapter_info(&self) -> &wgpu::AdapterInfo {
        &self.adapter_info
    }

    /// Gets a buffer from the pool, creating a new one if no matching buffer is available.
    pub fn get_buffer(
        &self,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.get_or_create(&self.device, size, usage, label)
        } else {
            // Fall back to direct allocation if the pool mutex is poisoned.
            self.device.create_buffer(&wgpu::BufferDescriptor {
                label,
                size,
                usage,
                mapped_at_creation: false,
            })
        }
    }

    /// Returns a buffer to the pool for later reuse.
    pub fn return_buffer(&self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        if let Ok(mut pool) = self.buffer_pool.lock() {
            pool.return_buffer(buffer, size, usage);
        }
    }

    /// Returns a cached compute pipeline for `shader_key` and `entry_point`, compiling it
    /// with a default two-buffer (read-only input, read-write output) layout on first use.
    pub fn get_compute_pipeline(
        &self,
        shader_key: &str,
        shader_source: &str,
        entry_point: &str,
    ) -> UnifiedGpuResult<Arc<wgpu::ComputePipeline>> {
        let cache_key = format!("{}:{}", shader_key, entry_point);

        if let Ok(mut cache) = self.shader_cache.lock() {
            if let Some(pipeline) = cache.get(&cache_key) {
                return Ok(Arc::clone(pipeline));
            }

            let shader_module = self
                .device
                .create_shader_module(wgpu::ShaderModuleDescriptor {
                    label: Some(&format!("{} Shader", shader_key)),
                    source: wgpu::ShaderSource::Wgsl(shader_source.into()),
                });

            let bind_group_layout =
                self.device
                    .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                        label: Some(&format!("{} Bind Group Layout", shader_key)),
                        entries: &[
                            wgpu::BindGroupLayoutEntry {
                                binding: 0,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                            wgpu::BindGroupLayoutEntry {
                                binding: 1,
                                visibility: wgpu::ShaderStages::COMPUTE,
                                ty: wgpu::BindingType::Buffer {
                                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                                    has_dynamic_offset: false,
                                    min_binding_size: None,
                                },
                                count: None,
                            },
                        ],
                    });

            let pipeline_layout =
                self.device
                    .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some(&format!("{} Pipeline Layout", shader_key)),
                        bind_group_layouts: &[&bind_group_layout],
                        push_constant_ranges: &[],
                    });

            let pipeline = self
                .device
                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                    label: Some(&format!("{} Pipeline", shader_key)),
                    layout: Some(&pipeline_layout),
                    module: &shader_module,
                    entry_point,
                });

            let pipeline_arc = Arc::new(pipeline);
            cache.insert(cache_key, Arc::clone(&pipeline_arc));
            Ok(pipeline_arc)
        } else {
            Err(UnifiedGpuError::InvalidOperation(
                "Failed to access shader cache".into(),
            ))
        }
    }

    /// Returns a snapshot of buffer-pool statistics.
    pub fn buffer_pool_stats(&self) -> BufferPoolStats {
        if let Ok(pool) = self.buffer_pool.lock() {
            pool.get_stats()
        } else {
            BufferPoolStats::default()
        }
    }

    /// Returns how long this context has been alive.
    pub fn uptime(&self) -> std::time::Duration {
        self.creation_time.elapsed()
    }

    /// Chooses a workgroup size heuristically based on the operation type and data size.
    pub fn get_optimal_workgroup(&self, operation: &str, data_size: usize) -> (u32, u32, u32) {
        match operation {
            "matrix_multiply" | "matrix_operation" => {
                // 2D workgroups suit tiled matrix kernels.
                (16, 16, 1)
            }
            "vector_operation" | "reduce" | "scan" => {
                // Scale the 1D workgroup with the amount of data.
                let workgroup_size = if data_size > 10000 {
                    256
                } else if data_size > 1000 {
                    128
                } else {
                    64
                };
                (workgroup_size, 1, 1)
            }
            "geometric_algebra" | "clifford_algebra" => (128, 1, 1),
            "cellular_automata" | "ca_evolution" => (16, 16, 1),
            "neural_network" | "batch_processing" => (256, 1, 1),
            "information_geometry" | "fisher_information" | "bregman_divergence" => (256, 1, 1),
            "tropical_algebra" | "tropical_matrix" => (128, 1, 1),
            "dual_number" | "automatic_differentiation" => (128, 1, 1),
            "fusion_system" | "llm_evaluation" => (256, 1, 1),
            "enumerative_geometry" | "intersection_theory" => (64, 1, 1),
            // Conservative default for unknown operations.
            _ => (64, 1, 1),
        }
    }

    /// Generates the WGSL `@compute @workgroup_size(...)` attribute for an operation.
    pub fn get_workgroup_declaration(&self, operation: &str, data_size: usize) -> String {
        let (x, y, z) = self.get_optimal_workgroup(operation, data_size);

        if y == 1 && z == 1 {
            format!("@compute @workgroup_size({})", x)
        } else if z == 1 {
            format!("@compute @workgroup_size({}, {})", x, y)
        } else {
            format!("@compute @workgroup_size({}, {}, {})", x, y, z)
        }
    }
}

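/// Buffer pool keyed by (size, usage) with reuse statistics and periodic cleanup.
///
/// A minimal sketch of the get/return cycle (the device, size, and usage flags
/// are illustrative):
///
/// ```ignore
/// let mut pool = EnhancedGpuBufferPool::new();
/// let buffer = pool.get_or_create(&device, 4096, wgpu::BufferUsages::STORAGE, Some("scratch"));
/// // ... use the buffer ...
/// pool.return_buffer(buffer, 4096, wgpu::BufferUsages::STORAGE);
/// println!("hit rate: {:.1}%", pool.get_stats().hit_rate_percent);
/// ```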
pub struct EnhancedGpuBufferPool {
    pools: HashMap<(u64, wgpu::BufferUsages), Vec<wgpu::Buffer>>,
    stats: HashMap<(u64, wgpu::BufferUsages), PoolEntryStats>,
    total_created: u64,
    total_reused: u64,
    last_cleanup: Instant,
}

/// Per-(size, usage) statistics for pooled buffers.
#[derive(Debug, Clone, Default)]
pub struct PoolEntryStats {
    pub created_count: u64,
    pub reused_count: u64,
    pub last_used: Option<Instant>,
    pub total_size_bytes: u64,
}

/// Aggregate statistics for an [`EnhancedGpuBufferPool`].
#[derive(Debug, Clone, Default)]
pub struct BufferPoolStats {
    pub total_buffers_created: u64,
    pub total_buffers_reused: u64,
    pub current_pooled_count: usize,
    pub total_pooled_memory_mb: f32,
    pub hit_rate_percent: f32,
}

impl EnhancedGpuBufferPool {
    pub fn new() -> Self {
        Self {
            pools: HashMap::new(),
            stats: HashMap::new(),
            total_created: 0,
            total_reused: 0,
            last_cleanup: Instant::now(),
        }
    }
}

impl Default for EnhancedGpuBufferPool {
    fn default() -> Self {
        Self::new()
    }
}

impl EnhancedGpuBufferPool {
    /// Pops a pooled buffer matching (size, usage) if available, otherwise creates a new one.
    pub fn get_or_create(
        &mut self,
        device: &wgpu::Device,
        size: u64,
        usage: wgpu::BufferUsages,
        label: Option<&str>,
    ) -> wgpu::Buffer {
        let key = (size, usage);

        // Reuse a pooled buffer if one with the same size and usage exists.
        if let Some(buffers) = self.pools.get_mut(&key) {
            if let Some(buffer) = buffers.pop() {
                self.total_reused += 1;
                let stats = self.stats.entry(key).or_default();
                stats.reused_count += 1;
                stats.last_used = Some(Instant::now());
                return buffer;
            }
        }

        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label,
            size,
            usage,
            mapped_at_creation: false,
        });

        self.total_created += 1;
        let stats = self.stats.entry(key).or_default();
        stats.created_count += 1;
        stats.total_size_bytes += size;
        stats.last_used = Some(Instant::now());

        // Periodically evict buffers that have not been used recently.
        if self.last_cleanup.elapsed().as_secs() > 30 {
            self.cleanup_old_buffers();
        }

        buffer
    }

    /// Returns a buffer to the pool so that later requests with the same key can reuse it.
    pub fn return_buffer(&mut self, buffer: wgpu::Buffer, size: u64, usage: wgpu::BufferUsages) {
        let key = (size, usage);
        self.pools.entry(key).or_default().push(buffer);
    }

    /// Computes aggregate pool statistics, including the reuse hit rate.
    pub fn get_stats(&self) -> BufferPoolStats {
        let total_ops = self.total_created + self.total_reused;
        let hit_rate = if total_ops > 0 {
            (self.total_reused as f32 / total_ops as f32) * 100.0
        } else {
            0.0
        };

        let current_pooled_count = self.pools.values().map(|v| v.len()).sum();
        let total_pooled_memory_mb: f32 = self
            .pools
            .iter()
            .map(|((size, _usage), buffers)| {
                (*size as f32 * buffers.len() as f32) / 1024.0 / 1024.0
            })
            .sum();

        BufferPoolStats {
            total_buffers_created: self.total_created,
            total_buffers_reused: self.total_reused,
            current_pooled_count,
            total_pooled_memory_mb,
            hit_rate_percent: hit_rate,
        }
    }

    /// Drops pooled buffers whose key has not been used within the cleanup threshold.
    fn cleanup_old_buffers(&mut self) {
        let now = Instant::now();
        let cleanup_threshold = std::time::Duration::from_secs(300);

        self.pools.retain(|&key, buffers| {
            if let Some(stats) = self.stats.get(&key) {
                if let Some(last_used) = stats.last_used {
                    if now.duration_since(last_used) > cleanup_threshold {
                        buffers.clear();
                        return false;
                    }
                }
            }
            true
        });

        self.last_cleanup = now;
    }
}

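/// Dispatcher that chooses between GPU and CPU execution based on workload size.
///
/// A minimal usage sketch (`compute_on_gpu`, `compute_on_cpu`, and the workload
/// size are illustrative placeholders, not part of this crate):
///
/// ```ignore
/// let mut dispatcher = GpuDispatcher::new().await?;
/// let result = dispatcher
///     .execute(
///         50_000,
///         |ctx| {
///             // GPU path: use `ctx` to build buffers and dispatch a kernel.
///             compute_on_gpu(ctx)
///         },
///         || compute_on_cpu(),
///     )
///     .await;
/// ```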
pub struct GpuDispatcher {
    gpu_context: Option<GpuContext>,
    /// Minimum workload size at which GPU execution is considered at all.
    cpu_threshold: usize,
    /// Workload size at which GPU execution is preferred over the CPU path.
    gpu_threshold: usize,
}

impl GpuDispatcher {
    /// Creates a dispatcher, falling back to CPU-only mode if no GPU is available.
    pub async fn new() -> UnifiedGpuResult<Self> {
        // A missing GPU is not an error here; the dispatcher then runs everything on the CPU.
        let gpu_context = GpuContext::new().await.ok();

        Ok(Self {
            gpu_context,
            cpu_threshold: 100,
            gpu_threshold: 1000,
        })
    }

    /// Returns true when a GPU is available and the workload is large enough to benefit from it.
    pub fn should_use_gpu(&self, workload_size: usize) -> bool {
        self.gpu_context.is_some()
            && workload_size >= self.cpu_threshold
            && workload_size >= self.gpu_threshold
    }

    /// Runs `gpu_op` when the workload warrants GPU execution, otherwise (or if the GPU
    /// operation fails) falls back to `cpu_op`.
    pub async fn execute<T, F, G>(&mut self, workload_size: usize, gpu_op: G, cpu_op: F) -> T
    where
        F: FnOnce() -> T,
        G: FnOnce(&mut GpuContext) -> UnifiedGpuResult<T>,
    {
        if self.should_use_gpu(workload_size) {
            if let Some(ref mut ctx) = self.gpu_context {
                if let Ok(result) = gpu_op(ctx) {
                    return result;
                }
            }
        }

        // Fall back to the CPU implementation.
        cpu_op()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_context_creation() {
        // Creation may fail on machines without a GPU; this only exercises the code path.
        let _result = GpuContext::new().await;
    }

    #[tokio::test]
    #[ignore = "GPU hardware required, may fail in CI/CD environments"]
    async fn test_gpu_dispatcher() {
        let dispatcher = GpuDispatcher::new().await;
        assert!(dispatcher.is_ok());
    }

    #[test]
    fn test_gpu_operation_params() {
        let mut params = GpuOperationParams::default();
        params
            .params
            .insert("scale".to_string(), GpuParam::Float(2.0));
        params.batch_size = 100;

        assert_eq!(params.batch_size, 100);
        match params.params.get("scale") {
            Some(GpuParam::Float(val)) => assert_eq!(*val, 2.0),
            _ => panic!("Expected float parameter"),
        }
    }
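
    // Two additional CPU-only checks (no GPU required), added as minimal sketches of
    // the pool-statistics and dispatch-threshold behaviour.
    #[test]
    fn test_buffer_pool_stats_start_empty() {
        let pool = EnhancedGpuBufferPool::new();
        let stats = pool.get_stats();
        assert_eq!(stats.total_buffers_created, 0);
        assert_eq!(stats.total_buffers_reused, 0);
        assert_eq!(stats.current_pooled_count, 0);
        assert_eq!(stats.hit_rate_percent, 0.0);
    }

    #[test]
    fn test_dispatcher_requires_gpu_context() {
        // Without a GPU context, should_use_gpu must be false regardless of workload size.
        let dispatcher = GpuDispatcher {
            gpu_context: None,
            cpu_threshold: 100,
            gpu_threshold: 1000,
        };
        assert!(!dispatcher.should_use_gpu(1_000_000));
        assert!(!dispatcher.should_use_gpu(10));
    }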
}