1use anyhow::{anyhow, Result};
34use serde::{Deserialize, Serialize};
35use std::sync::Arc;
36use tokio::sync::RwLock;
37use tracing::{debug, warn};
38
39use scirs2_core::gpu::{GpuBackend as ScirsGpuBackend, GpuContext as ScirsGpuContext};
41
/// GPU execution context that wraps an optional `scirs2` device context and
/// transparently degrades to CPU execution when no device is available.
pub struct GpuContext {
    /// Backend requested at construction time (may differ from what actually
    /// initialized; see `is_available`).
    backend: GpuBackend,
    /// Tuning parameters; currently always `GpuConfig::default()` and not
    /// consulted by any method in this file.
    config: GpuConfig,
    /// Shared, async-guarded operation counters.
    stats: Arc<RwLock<GpuStats>>,
    // Held to keep the device context alive; `Some` means a GPU initialized.
    // Not read by any method yet, hence the allow.
    #[allow(dead_code)]
    scirs_context: Option<ScirsGpuContext>,
}
50
51impl GpuContext {
52 pub fn new(backend: GpuBackend) -> Result<Self> {
54 let config = GpuConfig::default();
55
56 let scirs_context = match backend {
58 GpuBackend::Cuda => {
59 debug!("Initializing CUDA backend");
60 ScirsGpuContext::new(ScirsGpuBackend::Cuda).ok()
61 }
62 GpuBackend::Metal => {
63 debug!("Initializing Metal backend");
64 ScirsGpuContext::new(ScirsGpuBackend::Metal).ok()
65 }
66 GpuBackend::Cpu => {
67 debug!("Using CPU fallback");
68 None
69 }
70 GpuBackend::Auto => {
71 ScirsGpuContext::new(ScirsGpuBackend::Cuda)
73 .or_else(|_| ScirsGpuContext::new(ScirsGpuBackend::Metal))
74 .ok()
75 }
76 };
77
78 Ok(Self {
79 backend,
80 config,
81 stats: Arc::new(RwLock::new(GpuStats::default())),
82 scirs_context,
83 })
84 }
85
86 pub fn is_available(&self) -> bool {
88 self.scirs_context.is_some()
89 }
90
91 pub fn backend(&self) -> GpuBackend {
93 self.backend
94 }
95
96 pub async fn batch_process(&self, data: &[f32]) -> Result<Vec<f32>> {
98 let mut stats = self.stats.write().await;
99 stats.batches_processed += 1;
100
101 if let Some(_ctx) = &self.scirs_context {
102 debug!("Processing batch on GPU: {} elements", data.len());
104 stats.gpu_operations += 1;
105
106 Ok(data.to_vec())
108 } else {
109 warn!("GPU not available, falling back to CPU");
111 stats.cpu_fallbacks += 1;
112 Ok(data.to_vec())
113 }
114 }
115
116 pub async fn matrix_multiply(
118 &self,
119 a: &[f32],
120 b: &[f32],
121 m: usize,
122 n: usize,
123 k: usize,
124 ) -> Result<Vec<f32>> {
125 let mut stats = self.stats.write().await;
126 stats.matrix_operations += 1;
127
128 if let Some(_ctx) = &self.scirs_context {
129 debug!("GPU matrix multiply: {}x{} * {}x{}", m, n, n, k);
130
131 let mut result = vec![0.0f32; m * k];
133
134 for i in 0..m {
135 for j in 0..k {
136 for l in 0..n {
137 result[i * k + j] += a[i * n + l] * b[l * k + j];
138 }
139 }
140 }
141
142 Ok(result)
143 } else {
144 let mut result = vec![0.0f32; m * k];
146
147 for i in 0..m {
148 for j in 0..k {
149 for l in 0..n {
150 result[i * k + j] += a[i * n + l] * b[l * k + j];
151 }
152 }
153 }
154
155 Ok(result)
156 }
157 }
158
159 pub async fn vector_add(&self, a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
161 if a.len() != b.len() {
162 return Err(anyhow!("Vector lengths must match"));
163 }
164
165 let mut stats = self.stats.write().await;
166 stats.vector_operations += 1;
167
168 if self.is_available() {
169 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
171 } else {
172 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
174 }
175 }
176
177 pub async fn parallel_sum(&self, data: &[f32]) -> Result<f32> {
179 let mut stats = self.stats.write().await;
180 stats.aggregation_operations += 1;
181
182 if self.is_available() {
183 Ok(data.iter().sum())
185 } else {
186 Ok(data.iter().sum())
188 }
189 }
190
191 pub async fn pattern_match(&self, data: &[f32], pattern: &[f32]) -> Result<Vec<usize>> {
193 let mut stats = self.stats.write().await;
194 stats.pattern_operations += 1;
195
196 let mut matches = Vec::new();
197
198 for i in 0..=data.len().saturating_sub(pattern.len()) {
200 let window = &data[i..i + pattern.len()];
201 if window == pattern {
202 matches.push(i);
203 }
204 }
205
206 Ok(matches)
207 }
208
209 pub async fn stats(&self) -> GpuStats {
211 self.stats.read().await.clone()
212 }
213}
214
/// Compute backend selection for [`GpuContext`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// Apple Metal.
    Metal,
    /// Pure-CPU execution; no device context is created.
    Cpu,
    /// Probe CUDA first, then Metal; fall back to CPU if neither initializes.
    Auto,
}
230
/// Configuration knobs for GPU execution.
///
/// NOTE(review): these fields are not yet read anywhere in this module;
/// `GpuContext::new` stores the defaults unmodified.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuConfig {
    /// Master switch for GPU acceleration.
    pub enabled: bool,
    /// Preferred backend; `Auto` probes CUDA, then Metal.
    pub backend: GpuBackend,
    /// Elements processed per batch.
    pub batch_size: usize,
    /// Upper bound on device memory usage, in bytes.
    pub memory_limit: usize,
    /// Whether reduced-precision arithmetic may be used.
    pub mixed_precision: bool,
    /// Number of concurrent command streams.
    pub num_streams: usize,
}
252
253impl Default for GpuConfig {
254 fn default() -> Self {
255 Self {
256 enabled: true,
257 backend: GpuBackend::Auto,
258 batch_size: 1024,
259 memory_limit: 2 * 1024 * 1024 * 1024, mixed_precision: false,
261 num_streams: 2,
262 }
263 }
264}
265
/// Operation counters for a [`GpuContext`]; every field starts at zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GpuStats {
    /// Batches handled by `batch_process`.
    pub batches_processed: u64,
    /// Dispatches that took the GPU path (only `batch_process` updates this
    /// in the current code).
    pub gpu_operations: u64,
    /// Dispatches that fell back to the CPU path.
    pub cpu_fallbacks: u64,
    /// Calls to `matrix_multiply`.
    pub matrix_operations: u64,
    /// Calls to `vector_add`.
    pub vector_operations: u64,
    /// Calls to `parallel_sum`.
    pub aggregation_operations: u64,
    /// Calls to `pattern_match`.
    pub pattern_operations: u64,
    /// Cumulative GPU time in milliseconds.
    // NOTE(review): never updated anywhere in this file — always 0.0.
    pub total_gpu_time_ms: f64,
    /// Mean GPU time per operation in milliseconds.
    // NOTE(review): never updated anywhere in this file — always 0.0.
    pub avg_gpu_time_ms: f64,
}
296
297impl GpuStats {
298 pub fn gpu_utilization(&self) -> f64 {
300 let total_ops = self.gpu_operations + self.cpu_fallbacks;
301 if total_ops == 0 {
302 0.0
303 } else {
304 self.gpu_operations as f64 / total_ops as f64
305 }
306 }
307
308 pub fn cpu_fallback_rate(&self) -> f64 {
310 let total_ops = self.gpu_operations + self.cpu_fallbacks;
311 if total_ops == 0 {
312 0.0
313 } else {
314 self.cpu_fallbacks as f64 / total_ops as f64
315 }
316 }
317}
318
/// Host-side buffer with a placeholder device handle.
pub struct GpuBuffer<T> {
    /// The host copy of the data; always valid.
    data: Vec<T>,
    // Simulated device pointer: `Some` while "on device". Not a real
    // address — `to_device` stores a fixed sentinel value.
    device_ptr: Option<usize>,
}
324
325impl<T: Clone> GpuBuffer<T> {
326 pub fn new(data: Vec<T>) -> Self {
328 Self {
329 data,
330 device_ptr: None,
331 }
332 }
333
334 pub fn to_device(&mut self) -> Result<()> {
336 self.device_ptr = Some(0x1000); Ok(())
339 }
340
341 pub fn from_device(&mut self) -> Result<()> {
343 self.device_ptr = None;
345 Ok(())
346 }
347
348 pub fn is_on_device(&self) -> bool {
350 self.device_ptr.is_some()
351 }
352
353 pub fn data(&self) -> &[T] {
355 &self.data
356 }
357}
358
/// Streaming wrapper around [`GpuContext`] that applies batching policy
/// before dispatching work.
pub struct GpuStreamProcessor {
    /// Underlying compute context (GPU or CPU fallback).
    context: GpuContext,
    // Batching thresholds; only `min_batch_size` is consulted so far.
    config: GpuProcessorConfig,
}
364
365impl GpuStreamProcessor {
366 pub fn new(backend: GpuBackend, config: GpuProcessorConfig) -> Result<Self> {
368 Ok(Self {
369 context: GpuContext::new(backend)?,
370 config,
371 })
372 }
373
374 pub async fn process_batch(&self, batch: &[f32]) -> Result<Vec<f32>> {
376 if batch.len() < self.config.min_batch_size {
377 return Ok(batch.to_vec());
379 }
380
381 self.context.batch_process(batch).await
382 }
383
384 pub async fn compute_embeddings(&self, inputs: &[f32], weights: &[f32]) -> Result<Vec<f32>> {
386 let dim = weights.len() / inputs.len();
388 self.context
389 .matrix_multiply(inputs, weights, 1, inputs.len(), dim)
390 .await
391 }
392
393 pub async fn aggregate_metrics(&self, values: &[f32], operation: AggregationOp) -> Result<f32> {
395 match operation {
396 AggregationOp::Sum => self.context.parallel_sum(values).await,
397 AggregationOp::Mean => {
398 let sum = self.context.parallel_sum(values).await?;
399 Ok(sum / values.len() as f32)
400 }
401 AggregationOp::Max => Ok(values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b))),
402 AggregationOp::Min => Ok(values.iter().fold(f32::INFINITY, |a, &b| a.min(b))),
403 }
404 }
405
406 pub fn is_gpu_available(&self) -> bool {
408 self.context.is_available()
409 }
410}
411
/// Batching policy for [`GpuStreamProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuProcessorConfig {
    /// Batches smaller than this bypass GPU dispatch and are returned as-is.
    pub min_batch_size: usize,
    /// Upper batch-size bound.
    // NOTE(review): declared but not yet enforced anywhere in this module.
    pub max_batch_size: usize,
    /// Whether processing should be asynchronous.
    // NOTE(review): not yet consulted by any code in this module.
    pub async_processing: bool,
}
424
425impl Default for GpuProcessorConfig {
426 fn default() -> Self {
427 Self {
428 min_batch_size: 100,
429 max_batch_size: 10000,
430 async_processing: true,
431 }
432 }
433}
434
/// Reduction applied by [`GpuStreamProcessor::aggregate_metrics`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AggregationOp {
    /// Sum of all values.
    Sum,
    /// Arithmetic mean.
    Mean,
    /// Maximum value.
    Max,
    /// Minimum value.
    Min,
}
443
#[cfg(test)]
mod tests {
    use super::*;

    // Fully synchronous — no async runtime needed, so plain #[test]
    // instead of the previous #[tokio::test].
    #[test]
    fn test_gpu_context_creation() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();
        assert_eq!(ctx.backend(), GpuBackend::Cpu);
    }

    #[tokio::test]
    async fn test_batch_processing() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // The CPU path is an identity copy.
        let result = ctx.batch_process(&data).await.unwrap();
        assert_eq!(result, data);
    }

    #[tokio::test]
    async fn test_matrix_multiply() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();

        // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];

        let result = ctx.matrix_multiply(&a, &b, 2, 2, 2).await.unwrap();
        // Assert the actual product, not just the output length.
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
    }

    #[tokio::test]
    async fn test_vector_add() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        let result = ctx.vector_add(&a, &b).await.unwrap();
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[tokio::test]
    async fn test_parallel_sum() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let sum = ctx.parallel_sum(&data).await.unwrap();

        assert_eq!(sum, 15.0);
    }

    #[tokio::test]
    async fn test_pattern_match() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();

        // Pattern [2, 3] occurs starting at indices 1 and 3.
        let data = vec![1.0, 2.0, 3.0, 2.0, 3.0, 4.0];
        let pattern = vec![2.0, 3.0];

        let matches = ctx.pattern_match(&data, &pattern).await.unwrap();
        assert_eq!(matches, vec![1, 3]);
    }

    // GpuBuffer is entirely synchronous — plain #[test].
    #[test]
    fn test_gpu_buffer() {
        let mut buffer = GpuBuffer::new(vec![1.0, 2.0, 3.0]);

        assert!(!buffer.is_on_device());

        buffer.to_device().unwrap();
        assert!(buffer.is_on_device());

        buffer.from_device().unwrap();
        assert!(!buffer.is_on_device());
    }

    #[tokio::test]
    async fn test_stream_processor() {
        let processor =
            GpuStreamProcessor::new(GpuBackend::Cpu, GpuProcessorConfig::default()).unwrap();

        let batch = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = processor.process_batch(&batch).await.unwrap();

        assert_eq!(result.len(), batch.len());
    }

    #[tokio::test]
    async fn test_aggregation_operations() {
        let processor =
            GpuStreamProcessor::new(GpuBackend::Cpu, GpuProcessorConfig::default()).unwrap();

        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let sum = processor
            .aggregate_metrics(&values, AggregationOp::Sum)
            .await
            .unwrap();
        assert_eq!(sum, 15.0);

        let mean = processor
            .aggregate_metrics(&values, AggregationOp::Mean)
            .await
            .unwrap();
        assert_eq!(mean, 3.0);

        let max = processor
            .aggregate_metrics(&values, AggregationOp::Max)
            .await
            .unwrap();
        assert_eq!(max, 5.0);

        let min = processor
            .aggregate_metrics(&values, AggregationOp::Min)
            .await
            .unwrap();
        assert_eq!(min, 1.0);
    }

    #[tokio::test]
    async fn test_gpu_stats() {
        let ctx = GpuContext::new(GpuBackend::Cpu).unwrap();

        let _ = ctx.batch_process(&[1.0, 2.0, 3.0]).await;
        let _ = ctx.vector_add(&[1.0], &[2.0]).await;

        let stats = ctx.stats().await;
        assert!(stats.batches_processed > 0);
        assert!(stats.vector_operations > 0);
    }
}