use crate::kernels::KernelType;
use scirs2_core::ndarray::{s, Array2};
use thiserror::Error;

#[cfg(feature = "gpu")]
use std::collections::HashMap;
#[cfg(feature = "gpu")]
use wgpu::{
    util::DeviceExt, Adapter, BufferDescriptor, BufferUsages, ComputePassDescriptor,
    ComputePipeline, Device, DeviceDescriptor, Features, Instance, Limits, PowerPreference, Queue,
    RequestAdapterOptions, ShaderModule, ShaderModuleDescriptor, ShaderSource,
};

/// Errors that can occur when computing kernel matrices on the GPU.
#[derive(Error, Debug)]
pub enum GpuKernelError {
    #[error("GPU device not available")]
    DeviceNotAvailable,
    #[error("Insufficient GPU memory")]
    InsufficientMemory,
    #[error("GPU computation failed: {0}")]
    ComputationFailed(String),
    #[error("Shader compilation failed: {0}")]
    ShaderCompilationFailed(String),
    #[error("Buffer creation failed")]
    BufferCreationFailed,
    #[error("GPU feature not supported: {0}")]
    FeatureNotSupported(String),
    #[error("Kernel matrix dimensions mismatch")]
    DimensionMismatch,
}

/// Convenience alias for results produced by the GPU kernel code paths.
pub type GpuKernelResult<T> = Result<T, GpuKernelError>;

/// GPU-backed kernel matrix computer built on `wgpu` compute shaders.
#[cfg(feature = "gpu")]
pub struct GpuKernelComputer {
    device: Device,
    queue: Queue,
    adapter: Adapter,
    pipelines: HashMap<String, ComputePipeline>,
    shader_modules: HashMap<String, ShaderModule>,
}

#[cfg(feature = "gpu")]
impl GpuKernelComputer {
    /// Requests a high-performance adapter and device, then compiles the built-in kernel shaders.
    pub async fn new() -> GpuKernelResult<Self> {
        let instance = Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            ..Default::default()
        });

        let adapter = instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or(GpuKernelError::DeviceNotAvailable)?;

        let (device, queue) = adapter
            .request_device(
                &DeviceDescriptor {
                    label: None,
                    required_features: Features::empty(),
                    required_limits: Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                },
                None,
            )
            .await
            .map_err(|e| GpuKernelError::ComputationFailed(e.to_string()))?;

        let mut computer = Self {
            device,
            queue,
            adapter,
            pipelines: HashMap::new(),
            shader_modules: HashMap::new(),
        };

        // Pre-compile one compute pipeline per supported kernel type.
        computer.init_rbf_shader()?;
        computer.init_polynomial_shader()?;
        computer.init_linear_shader()?;
        computer.init_sigmoid_shader()?;

        Ok(computer)
    }

    fn init_rbf_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn rbf_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);
                let gamma = params[3];

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var sum_sq_diff = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    let diff = X[i * n_features + k] - Y[j * n_features + k];
                    sum_sq_diff += diff * diff;
                }

                result[i * n_y + j] = exp(-gamma * sum_sq_diff);
            }
        "#;

        let shader = self.device.create_shader_module(ShaderModuleDescriptor {
            label: Some("RBF Kernel Shader"),
            source: ShaderSource::Wgsl(shader_source.into()),
        });

        // `layout: None` lets wgpu derive the bind group layout from the shader.
        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("RBF Kernel Pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("rbf_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        self.shader_modules.insert("rbf".to_string(), shader);
        self.pipelines.insert("rbf".to_string(), pipeline);

        Ok(())
    }

    fn init_polynomial_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn polynomial_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);
                let gamma = params[3];
                let coef0 = params[4];
                let degree = params[5];

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var dot_product = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    dot_product += X[i * n_features + k] * Y[j * n_features + k];
                }

                result[i * n_y + j] = pow(gamma * dot_product + coef0, degree);
            }
        "#;

        let shader = self.device.create_shader_module(ShaderModuleDescriptor {
            label: Some("Polynomial Kernel Shader"),
            source: ShaderSource::Wgsl(shader_source.into()),
        });

        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Polynomial Kernel Pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("polynomial_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        self.shader_modules.insert("polynomial".to_string(), shader);
        self.pipelines.insert("polynomial".to_string(), pipeline);

        Ok(())
    }

    fn init_linear_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn linear_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var dot_product = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    dot_product += X[i * n_features + k] * Y[j * n_features + k];
                }

                result[i * n_y + j] = dot_product;
            }
        "#;

        let shader = self.device.create_shader_module(ShaderModuleDescriptor {
            label: Some("Linear Kernel Shader"),
            source: ShaderSource::Wgsl(shader_source.into()),
        });

        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Linear Kernel Pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("linear_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        self.shader_modules.insert("linear".to_string(), shader);
        self.pipelines.insert("linear".to_string(), pipeline);

        Ok(())
    }

    fn init_sigmoid_shader(&mut self) -> GpuKernelResult<()> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read> X: array<f32>;
            @group(0) @binding(1) var<storage, read> Y: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<storage, read> params: array<f32>;

            @compute @workgroup_size(16, 16)
            fn sigmoid_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let n_x = u32(params[0]);
                let n_y = u32(params[1]);
                let n_features = u32(params[2]);
                let gamma = params[3];
                let coef0 = params[4];

                let i = global_id.x;
                let j = global_id.y;

                if (i >= n_x || j >= n_y) {
                    return;
                }

                var dot_product = 0.0;
                for (var k = 0u; k < n_features; k++) {
                    dot_product += X[i * n_features + k] * Y[j * n_features + k];
                }

                result[i * n_y + j] = tanh(gamma * dot_product + coef0);
            }
        "#;

        let shader = self.device.create_shader_module(ShaderModuleDescriptor {
            label: Some("Sigmoid Kernel Shader"),
            source: ShaderSource::Wgsl(shader_source.into()),
        });

        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Sigmoid Kernel Pipeline"),
                layout: None,
                module: &shader,
                entry_point: Some("sigmoid_kernel"),
                compilation_options: Default::default(),
                cache: None,
            });

        self.shader_modules.insert("sigmoid".to_string(), shader);
        self.pipelines.insert("sigmoid".to_string(), pipeline);

        Ok(())
    }

    /// Computes the full kernel matrix between the rows of `X` and the rows of `Y` on the GPU.
    #[allow(non_snake_case)]
    pub async fn compute_kernel_matrix(
        &self,
        X: &Array2<f32>,
        Y: &Array2<f32>,
        kernel_type: &KernelType,
    ) -> GpuKernelResult<Array2<f32>> {
        let (n_x, n_features_x) = X.dim();
        let (n_y, n_features_y) = Y.dim();

        if n_features_x != n_features_y {
            return Err(GpuKernelError::DimensionMismatch);
        }

        // The params buffer layout must match the shaders:
        // [n_x, n_y, n_features, kernel-specific hyperparameters...].
        let (pipeline_name, params) = match kernel_type {
            KernelType::Rbf { gamma } => (
                "rbf",
                vec![n_x as f32, n_y as f32, n_features_x as f32, *gamma as f32],
            ),
            KernelType::Polynomial {
                gamma,
                coef0,
                degree,
            } => (
                "polynomial",
                vec![
                    n_x as f32,
                    n_y as f32,
                    n_features_x as f32,
                    *gamma as f32,
                    *coef0 as f32,
                    *degree as f32,
                ],
            ),
            KernelType::Linear => ("linear", vec![n_x as f32, n_y as f32, n_features_x as f32]),
            KernelType::Sigmoid { gamma, coef0 } => (
                "sigmoid",
                vec![
                    n_x as f32,
                    n_y as f32,
                    n_features_x as f32,
                    *gamma as f32,
                    *coef0 as f32,
                ],
            ),
            _ => {
                return Err(GpuKernelError::FeatureNotSupported(
                    "Kernel type not supported on GPU".to_string(),
                ))
            }
        };

        let pipeline = self.pipelines.get(pipeline_name).ok_or_else(|| {
            GpuKernelError::FeatureNotSupported(format!("Pipeline {pipeline_name} not found"))
        })?;

        // The arrays must be contiguous to be uploaded as raw byte slices.
        let x_data = X.as_slice().ok_or(GpuKernelError::BufferCreationFailed)?;
        let y_data = Y.as_slice().ok_or(GpuKernelError::BufferCreationFailed)?;

        let x_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("X Buffer"),
                contents: bytemuck::cast_slice(x_data),
                usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
            });

        let y_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Y Buffer"),
                contents: bytemuck::cast_slice(y_data),
                usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
            });

        let result_buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("Result Buffer"),
            size: (n_x * n_y * std::mem::size_of::<f32>()) as u64,
            usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let params_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Params Buffer"),
                contents: bytemuck::cast_slice(&params),
                usage: BufferUsages::STORAGE | BufferUsages::COPY_DST,
            });

        // Binding indices correspond to the @binding declarations in the WGSL shaders.
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Kernel Bind Group"),
            layout: &pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: x_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: y_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Kernel Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("Kernel Compute Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // One invocation per output element, rounded up to whole 16x16 workgroups.
            let workgroup_size = 16;
            let num_workgroups_x = (n_x + workgroup_size - 1) / workgroup_size;
            let num_workgroups_y = (n_y + workgroup_size - 1) / workgroup_size;

            compute_pass.dispatch_workgroups(num_workgroups_x as u32, num_workgroups_y as u32, 1);
        }

        let staging_buffer = self.device.create_buffer(&BufferDescriptor {
            label: Some("Staging Buffer"),
            size: (n_x * n_y * std::mem::size_of::<f32>()) as u64,
            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        encoder.copy_buffer_to_buffer(
            &result_buffer,
            0,
            &staging_buffer,
            0,
            (n_x * n_y * std::mem::size_of::<f32>()) as u64,
        );

        self.queue.submit(Some(encoder.finish()));

        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).unwrap();
        });

        self.device.poll(wgpu::Maintain::Wait);
        rx.receive()
            .await
            .unwrap()
            .map_err(|e| GpuKernelError::ComputationFailed(e.to_string()))?;

        let data = buffer_slice.get_mapped_range();
        let result_data: &[f32] = bytemuck::cast_slice(&data);

        let result_matrix = Array2::from_shape_vec((n_x, n_y), result_data.to_vec())
            .map_err(|e| GpuKernelError::ComputationFailed(e.to_string()))?;

        Ok(result_matrix)
    }

    /// Returns a short description of the GPU device in use.
    pub fn device_info(&self) -> String {
        format!("GPU: {}", self.adapter.get_info().name)
    }

    /// Returns `true` when the adapter reports support for compute shaders.
    pub fn supports_compute(&self) -> bool {
        self.adapter
            .get_downlevel_capabilities()
            .flags
            .contains(wgpu::DownlevelFlags::COMPUTE_SHADERS)
    }

    /// Returns the raw adapter information (vendor, device name, backend).
    pub fn memory_info(&self) -> wgpu::AdapterInfo {
        self.adapter.get_info()
    }
}

/// Stub used when the `gpu` feature is disabled; every operation reports the
/// feature as unsupported.
#[cfg(not(feature = "gpu"))]
pub struct GpuKernelComputer;

#[cfg(not(feature = "gpu"))]
impl GpuKernelComputer {
    pub async fn new() -> GpuKernelResult<Self> {
        Err(GpuKernelError::FeatureNotSupported(
            "GPU support not enabled".to_string(),
        ))
    }

    pub async fn compute_kernel_matrix(
        &self,
        _x: &Array2<f32>,
        _y: &Array2<f32>,
        _kernel_type: &KernelType,
    ) -> GpuKernelResult<Array2<f32>> {
        Err(GpuKernelError::FeatureNotSupported(
            "GPU support not enabled".to_string(),
        ))
    }
}

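/// Unified kernel front end: uses the GPU backend once it has been initialized
/// and falls back to the CPU implementation otherwise.
///
/// Minimal usage sketch (assumes the calling crate has a way to block on
/// futures, e.g. the `pollster` crate, which is not a dependency of this
/// module):
///
/// ```ignore
/// let mut kernel = GpuKernel::new(KernelType::Rbf { gamma: 0.5 }, true);
/// // A failed init simply leaves the CPU fallback in place.
/// let _ = pollster::block_on(kernel.init_gpu());
/// let k = pollster::block_on(kernel.compute_matrix(&x, &y));
/// ```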
pub struct GpuKernel {
    #[cfg(feature = "gpu")]
    computer: Option<GpuKernelComputer>,
    kernel_type: KernelType,
    use_gpu: bool,
}

impl GpuKernel {
    pub fn new(kernel_type: KernelType, use_gpu: bool) -> Self {
        Self {
            #[cfg(feature = "gpu")]
            computer: None,
            kernel_type,
            use_gpu,
        }
    }

    /// Initializes the GPU backend. If this fails, `compute_matrix` still
    /// works through the CPU fallback.
    #[cfg(feature = "gpu")]
    pub async fn init_gpu(&mut self) -> GpuKernelResult<()> {
        if self.use_gpu {
            self.computer = Some(GpuKernelComputer::new().await?);
        }
        Ok(())
    }

    #[cfg(not(feature = "gpu"))]
    pub async fn init_gpu(&mut self) -> GpuKernelResult<()> {
        if self.use_gpu {
            return Err(GpuKernelError::FeatureNotSupported(
                "GPU support not enabled".to_string(),
            ));
        }
        Ok(())
    }

    /// Computes the kernel matrix, preferring the GPU path and silently
    /// falling back to the CPU implementation if it is unavailable or fails.
    pub async fn compute_matrix(&self, x: &Array2<f32>, y: &Array2<f32>) -> Array2<f32> {
        #[cfg(feature = "gpu")]
        if let Some(computer) = &self.computer {
            if let Ok(result) = computer
                .compute_kernel_matrix(x, y, &self.kernel_type)
                .await
            {
                return result;
            }
        }

        self.compute_cpu_kernel_matrix(x, y)
    }

    /// CPU reference implementation of the kernel matrix computation.
    pub fn compute_cpu_kernel_matrix(&self, x: &Array2<f32>, y: &Array2<f32>) -> Array2<f32> {
        let (n_x, _n_features) = x.dim();
        let (n_y, _) = y.dim();
        let mut result = Array2::zeros((n_x, n_y));

        for i in 0..n_x {
            for j in 0..n_y {
                let x_i = x.row(i);
                let y_j = y.row(j);

                let kernel_value = match &self.kernel_type {
                    KernelType::Linear => x_i.dot(&y_j) as f64,
                    KernelType::Rbf { gamma } => {
                        let diff = &x_i - &y_j;
                        let squared_distance = diff.dot(&diff) as f64;
                        (-gamma * squared_distance).exp()
                    }
                    KernelType::Polynomial {
                        gamma,
                        coef0,
                        degree,
                    } => {
                        let dot_product = x_i.dot(&y_j) as f64;
                        (gamma * dot_product + coef0).powf(*degree)
                    }
                    KernelType::Sigmoid { gamma, coef0 } => {
                        let dot_product = x_i.dot(&y_j) as f64;
                        (gamma * dot_product + coef0).tanh()
                    }
                    // Kernel types without a CPU implementation here default to zero.
                    _ => 0.0,
                };

                result[(i, j)] = kernel_value as f32;
            }
        }

        result
    }

    /// Returns `true` once the GPU backend has been successfully initialized.
    pub fn is_gpu_available(&self) -> bool {
        #[cfg(feature = "gpu")]
        return self.computer.is_some();
        #[cfg(not(feature = "gpu"))]
        false
    }

    /// Describes the compute device currently in use ("CPU" unless a GPU is initialized).
    pub fn device_info(&self) -> String {
        #[cfg(feature = "gpu")]
        if let Some(computer) = &self.computer {
            return computer.device_info();
        }
        "CPU".to_string()
    }
}

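/// Timing and accuracy comparison between the CPU and GPU kernel paths.
///
/// Minimal usage sketch (same assumption as above: an executor such as
/// `pollster` is available in the calling crate):
///
/// ```ignore
/// let report = pollster::block_on(GpuKernelBenchmark::run(&x, &y, KernelType::Linear))?;
/// println!("CPU: {:?}, GPU: {:?}", report.cpu_time, report.gpu_time);
/// ```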
pub struct GpuKernelBenchmark {
    pub gpu_time: Option<std::time::Duration>,
    pub cpu_time: std::time::Duration,
    pub speedup: Option<f64>,
    pub accuracy: f64,
}

impl GpuKernelBenchmark {
    /// Runs the kernel computation on the CPU and, when available, on the GPU,
    /// reporting timings, speedup, and a rough agreement score.
    pub async fn run(
        x: &Array2<f32>,
        y: &Array2<f32>,
        kernel_type: KernelType,
    ) -> GpuKernelResult<Self> {
        let cpu_start = std::time::Instant::now();
        let cpu_kernel = GpuKernel::new(kernel_type.clone(), false);
        #[cfg_attr(not(feature = "gpu"), allow(unused_variables))]
        let cpu_result = cpu_kernel.compute_cpu_kernel_matrix(x, y);
        let cpu_time = cpu_start.elapsed();

        #[cfg(feature = "gpu")]
        let (gpu_time, speedup, accuracy) = {
            if let Ok(computer) = GpuKernelComputer::new().await {
                let gpu_start = std::time::Instant::now();
                let gpu_result = computer.compute_kernel_matrix(x, y, &kernel_type).await?;
                let gpu_time = gpu_start.elapsed();

                // Crude agreement score: 1 minus the RMSE between CPU and GPU results.
                let diff = &cpu_result - &gpu_result;
                let mse = diff.mapv(|v| v * v).mean().unwrap_or(0.0);
                let accuracy = 1.0 - (mse as f64).sqrt();

                let speedup = cpu_time.as_secs_f64() / gpu_time.as_secs_f64();

                (Some(gpu_time), Some(speedup), accuracy)
            } else {
                (None, None, 0.0)
            }
        };

        #[cfg(not(feature = "gpu"))]
        let (gpu_time, speedup, accuracy) = (None, None, 0.0);

        Ok(GpuKernelBenchmark {
            gpu_time,
            cpu_time,
            speedup,
            accuracy,
        })
    }
}

/// Heuristics and helpers for running kernel computations on the GPU.
pub mod gpu_utils {
    use super::*;

    /// Picks a batch size that keeps a single batch of samples within a rough
    /// GPU memory budget.
    pub fn optimal_batch_size(n_samples: usize, n_features: usize) -> usize {
        // Assume roughly 1 GiB of GPU memory is available for input data.
        let memory_limit = 1024 * 1024 * 1024;
        let sample_size = n_features * std::mem::size_of::<f32>();
        let max_batch = memory_limit / sample_size;

        (max_batch.min(n_samples)).max(1)
    }

    /// Heuristic: the GPU is only worth the transfer overhead for sufficiently
    /// large problems.
    pub fn should_use_gpu(n_samples: usize, n_features: usize) -> bool {
        let computation_size = n_samples * n_samples * n_features;
        computation_size > 1_000_000
    }

    /// Computes the kernel matrix in tiles so that large inputs can be
    /// processed without exceeding GPU memory.
    pub async fn compute_kernel_matrix_batched(
        computer: &GpuKernelComputer,
        x: &Array2<f32>,
        y: &Array2<f32>,
        kernel_type: &KernelType,
        batch_size: usize,
    ) -> GpuKernelResult<Array2<f32>> {
        let (n_x, _n_features) = x.dim();
        let (n_y, _) = y.dim();

        let mut result = Array2::zeros((n_x, n_y));

        for i in (0..n_x).step_by(batch_size) {
            let end_i = (i + batch_size).min(n_x);
            let x_batch = x.slice(s![i..end_i, ..]);

            for j in (0..n_y).step_by(batch_size) {
                let end_j = (j + batch_size).min(n_y);
                let y_batch = y.slice(s![j..end_j, ..]);

                let batch_result = computer
                    .compute_kernel_matrix(&x_batch.to_owned(), &y_batch.to_owned(), kernel_type)
                    .await?;

                result
                    .slice_mut(s![i..end_i, j..end_j])
                    .assign(&batch_result);
            }
        }

        Ok(result)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_kernel_creation() {
        let kernel = GpuKernel::new(KernelType::Rbf { gamma: 1.0 }, true);
        // The GPU backend is not available until `init_gpu` has been awaited.
        assert!(!kernel.is_gpu_available());
    }

    #[test]
    fn test_gpu_kernel_sync() {
        let kernel = GpuKernel::new(KernelType::Linear, true);
        assert_eq!(kernel.device_info(), "CPU");
    }

    #[test]
    fn test_gpu_utils() {
        let batch_size = gpu_utils::optimal_batch_size(1000, 100);
        assert!(batch_size > 0);

        let should_use = gpu_utils::should_use_gpu(1000, 1000);
        assert!(should_use);

        let should_not_use = gpu_utils::should_use_gpu(10, 10);
        assert!(!should_not_use);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_benchmark_sync() {
        let X_var = Array2::from_shape_vec((10, 5), (0..50).map(|x| x as f32).collect()).unwrap();
        let Y_var = Array2::from_shape_vec((8, 5), (0..40).map(|x| x as f32).collect()).unwrap();

        assert_eq!(X_var.dim(), (10, 5));
        assert_eq!(Y_var.dim(), (8, 5));
    }
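
    // Extra sanity check for the CPU fallback (added alongside the existing
    // tests): basic kernel identities such as K_rbf(x, x) = 1 and
    // K_linear(x, y) = <x, y> should hold.
    #[test]
    fn test_cpu_kernel_matrix_values() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();

        // RBF: a point compared with itself has zero squared distance, so exp(0) = 1.
        let rbf = GpuKernel::new(KernelType::Rbf { gamma: 0.5 }, false);
        let k_rbf = rbf.compute_cpu_kernel_matrix(&x, &x);
        for i in 0..3 {
            assert!((k_rbf[(i, i)] - 1.0).abs() < 1e-6);
        }

        // Linear: plain dot products, e.g. <(1, 1), (1, 1)> = 2 and <(1, 0), (0, 1)> = 0.
        let linear = GpuKernel::new(KernelType::Linear, false);
        let k_lin = linear.compute_cpu_kernel_matrix(&x, &x);
        assert!((k_lin[(2, 2)] - 2.0).abs() < 1e-6);
        assert!(k_lin[(0, 1)].abs() < 1e-6);
    }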
}