1use std::sync::Arc;
8
9use oxicuda_driver::Module;
10use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
11use oxicuda_memory::DeviceBuffer;
12use oxicuda_ptx::templates::elementwise::{ElementwiseOp as PtxOp, ElementwiseTemplate};
13
14use crate::error::{BlasError, BlasResult};
15use crate::handle::BlasHandle;
16use crate::types::GpuFloat;
17
/// Threads per block used for every elementwise kernel launch in this module.
///
/// Also passed to `grid_size_for` as the per-block element count when sizing the grid.
const BLOCK_SIZE: u32 = 256;
24
25fn validate_binary_buffers<T: Copy>(
27 n: u32,
28 a: &DeviceBuffer<T>,
29 b: &DeviceBuffer<T>,
30 c: &DeviceBuffer<T>,
31) -> BlasResult<()> {
32 let n_usize = n as usize;
33 if a.len() < n_usize {
34 return Err(BlasError::BufferTooSmall {
35 expected: n_usize,
36 actual: a.len(),
37 });
38 }
39 if b.len() < n_usize {
40 return Err(BlasError::BufferTooSmall {
41 expected: n_usize,
42 actual: b.len(),
43 });
44 }
45 if c.len() < n_usize {
46 return Err(BlasError::BufferTooSmall {
47 expected: n_usize,
48 actual: c.len(),
49 });
50 }
51 Ok(())
52}
53
54fn build_binary_kernel(
56 handle: &BlasHandle,
57 ptx_op: PtxOp,
58 ptx_type: oxicuda_ptx::ir::PtxType,
59) -> BlasResult<(Kernel, String)> {
60 let template = ElementwiseTemplate::new(ptx_op, ptx_type, handle.sm_version());
61 let kernel_name = template.kernel_name();
62 let ptx_source = template
63 .generate()
64 .map_err(|e| BlasError::PtxGeneration(format!("{}: {e}", ptx_op.as_str())))?;
65 let module = Arc::new(Module::from_ptx(&ptx_source).map_err(|e| {
66 BlasError::LaunchFailed(format!("module load for {}: {e}", ptx_op.as_str()))
67 })?);
68 let kernel = Kernel::from_module(module, &kernel_name)
69 .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {e}")))?;
70 Ok((kernel, kernel_name))
71}
72
73pub fn add<T: GpuFloat>(
92 handle: &BlasHandle,
93 n: u32,
94 a: &DeviceBuffer<T>,
95 b: &DeviceBuffer<T>,
96 c: &mut DeviceBuffer<T>,
97) -> BlasResult<()> {
98 if n == 0 {
99 return Ok(());
100 }
101 validate_binary_buffers(n, a, b, c)?;
102
103 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Add, T::PTX_TYPE)?;
104 let grid = grid_size_for(n, BLOCK_SIZE);
105 let params = LaunchParams::new(grid, BLOCK_SIZE);
106 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
107
108 kernel
109 .launch(¶ms, handle.stream(), &args)
110 .map_err(|e| BlasError::LaunchFailed(format!("add: {e}")))?;
111 Ok(())
112}
113
114pub fn mul<T: GpuFloat>(
128 handle: &BlasHandle,
129 n: u32,
130 a: &DeviceBuffer<T>,
131 b: &DeviceBuffer<T>,
132 c: &mut DeviceBuffer<T>,
133) -> BlasResult<()> {
134 if n == 0 {
135 return Ok(());
136 }
137 validate_binary_buffers(n, a, b, c)?;
138
139 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Mul, T::PTX_TYPE)?;
140 let grid = grid_size_for(n, BLOCK_SIZE);
141 let params = LaunchParams::new(grid, BLOCK_SIZE);
142 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
143
144 kernel
145 .launch(¶ms, handle.stream(), &args)
146 .map_err(|e| BlasError::LaunchFailed(format!("mul: {e}")))?;
147 Ok(())
148}
149
150pub fn sub<T: GpuFloat>(
164 handle: &BlasHandle,
165 n: u32,
166 a: &DeviceBuffer<T>,
167 b: &DeviceBuffer<T>,
168 c: &mut DeviceBuffer<T>,
169) -> BlasResult<()> {
170 if n == 0 {
171 return Ok(());
172 }
173 validate_binary_buffers(n, a, b, c)?;
174 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Sub, T::PTX_TYPE)?;
175 let grid = grid_size_for(n, BLOCK_SIZE);
176 let params = LaunchParams::new(grid, BLOCK_SIZE);
177 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
178 kernel
179 .launch(¶ms, handle.stream(), &args)
180 .map_err(|e| BlasError::LaunchFailed(format!("sub: {e}")))?;
181 Ok(())
182}
183
184pub fn div<T: GpuFloat>(
198 handle: &BlasHandle,
199 n: u32,
200 a: &DeviceBuffer<T>,
201 b: &DeviceBuffer<T>,
202 c: &mut DeviceBuffer<T>,
203) -> BlasResult<()> {
204 if n == 0 {
205 return Ok(());
206 }
207 validate_binary_buffers(n, a, b, c)?;
208 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Div, T::PTX_TYPE)?;
209 let grid = grid_size_for(n, BLOCK_SIZE);
210 let params = LaunchParams::new(grid, BLOCK_SIZE);
211 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
212 kernel
213 .launch(¶ms, handle.stream(), &args)
214 .map_err(|e| BlasError::LaunchFailed(format!("div: {e}")))?;
215 Ok(())
216}
217
218pub fn pow<T: GpuFloat>(
234 handle: &BlasHandle,
235 n: u32,
236 a: &DeviceBuffer<T>,
237 b: &DeviceBuffer<T>,
238 c: &mut DeviceBuffer<T>,
239) -> BlasResult<()> {
240 if n == 0 {
241 return Ok(());
242 }
243 validate_binary_buffers(n, a, b, c)?;
244 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Pow, T::PTX_TYPE)?;
245 let grid = grid_size_for(n, BLOCK_SIZE);
246 let params = LaunchParams::new(grid, BLOCK_SIZE);
247 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
248 kernel
249 .launch(¶ms, handle.stream(), &args)
250 .map_err(|e| BlasError::LaunchFailed(format!("pow: {e}")))?;
251 Ok(())
252}
253
254pub fn min<T: GpuFloat>(
268 handle: &BlasHandle,
269 n: u32,
270 a: &DeviceBuffer<T>,
271 b: &DeviceBuffer<T>,
272 c: &mut DeviceBuffer<T>,
273) -> BlasResult<()> {
274 if n == 0 {
275 return Ok(());
276 }
277 validate_binary_buffers(n, a, b, c)?;
278 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Min, T::PTX_TYPE)?;
279 let grid = grid_size_for(n, BLOCK_SIZE);
280 let params = LaunchParams::new(grid, BLOCK_SIZE);
281 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
282 kernel
283 .launch(¶ms, handle.stream(), &args)
284 .map_err(|e| BlasError::LaunchFailed(format!("min: {e}")))?;
285 Ok(())
286}
287
288pub fn max<T: GpuFloat>(
302 handle: &BlasHandle,
303 n: u32,
304 a: &DeviceBuffer<T>,
305 b: &DeviceBuffer<T>,
306 c: &mut DeviceBuffer<T>,
307) -> BlasResult<()> {
308 if n == 0 {
309 return Ok(());
310 }
311 validate_binary_buffers(n, a, b, c)?;
312 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Max, T::PTX_TYPE)?;
313 let grid = grid_size_for(n, BLOCK_SIZE);
314 let params = LaunchParams::new(grid, BLOCK_SIZE);
315 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
316 kernel
317 .launch(¶ms, handle.stream(), &args)
318 .map_err(|e| BlasError::LaunchFailed(format!("max: {e}")))?;
319 Ok(())
320}
321
322pub fn cmp_eq<T: GpuFloat>(
336 handle: &BlasHandle,
337 n: u32,
338 a: &DeviceBuffer<T>,
339 b: &DeviceBuffer<T>,
340 c: &mut DeviceBuffer<T>,
341) -> BlasResult<()> {
342 if n == 0 {
343 return Ok(());
344 }
345 validate_binary_buffers(n, a, b, c)?;
346 let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpEq, T::PTX_TYPE)?;
347 let grid = grid_size_for(n, BLOCK_SIZE);
348 let params = LaunchParams::new(grid, BLOCK_SIZE);
349 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
350 kernel
351 .launch(¶ms, handle.stream(), &args)
352 .map_err(|e| BlasError::LaunchFailed(format!("cmp_eq: {e}")))?;
353 Ok(())
354}
355
356pub fn cmp_ne<T: GpuFloat>(
370 handle: &BlasHandle,
371 n: u32,
372 a: &DeviceBuffer<T>,
373 b: &DeviceBuffer<T>,
374 c: &mut DeviceBuffer<T>,
375) -> BlasResult<()> {
376 if n == 0 {
377 return Ok(());
378 }
379 validate_binary_buffers(n, a, b, c)?;
380 let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpNe, T::PTX_TYPE)?;
381 let grid = grid_size_for(n, BLOCK_SIZE);
382 let params = LaunchParams::new(grid, BLOCK_SIZE);
383 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
384 kernel
385 .launch(¶ms, handle.stream(), &args)
386 .map_err(|e| BlasError::LaunchFailed(format!("cmp_ne: {e}")))?;
387 Ok(())
388}
389
390pub fn cmp_lt<T: GpuFloat>(
404 handle: &BlasHandle,
405 n: u32,
406 a: &DeviceBuffer<T>,
407 b: &DeviceBuffer<T>,
408 c: &mut DeviceBuffer<T>,
409) -> BlasResult<()> {
410 if n == 0 {
411 return Ok(());
412 }
413 validate_binary_buffers(n, a, b, c)?;
414 let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpLt, T::PTX_TYPE)?;
415 let grid = grid_size_for(n, BLOCK_SIZE);
416 let params = LaunchParams::new(grid, BLOCK_SIZE);
417 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
418 kernel
419 .launch(¶ms, handle.stream(), &args)
420 .map_err(|e| BlasError::LaunchFailed(format!("cmp_lt: {e}")))?;
421 Ok(())
422}
423
424pub fn cmp_gt<T: GpuFloat>(
438 handle: &BlasHandle,
439 n: u32,
440 a: &DeviceBuffer<T>,
441 b: &DeviceBuffer<T>,
442 c: &mut DeviceBuffer<T>,
443) -> BlasResult<()> {
444 if n == 0 {
445 return Ok(());
446 }
447 validate_binary_buffers(n, a, b, c)?;
448 let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpGt, T::PTX_TYPE)?;
449 let grid = grid_size_for(n, BLOCK_SIZE);
450 let params = LaunchParams::new(grid, BLOCK_SIZE);
451 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
452 kernel
453 .launch(¶ms, handle.stream(), &args)
454 .map_err(|e| BlasError::LaunchFailed(format!("cmp_gt: {e}")))?;
455 Ok(())
456}
457
458pub fn cmp_le<T: GpuFloat>(
472 handle: &BlasHandle,
473 n: u32,
474 a: &DeviceBuffer<T>,
475 b: &DeviceBuffer<T>,
476 c: &mut DeviceBuffer<T>,
477) -> BlasResult<()> {
478 if n == 0 {
479 return Ok(());
480 }
481 validate_binary_buffers(n, a, b, c)?;
482 let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpLe, T::PTX_TYPE)?;
483 let grid = grid_size_for(n, BLOCK_SIZE);
484 let params = LaunchParams::new(grid, BLOCK_SIZE);
485 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
486 kernel
487 .launch(¶ms, handle.stream(), &args)
488 .map_err(|e| BlasError::LaunchFailed(format!("cmp_le: {e}")))?;
489 Ok(())
490}
491
492pub fn cmp_ge<T: GpuFloat>(
506 handle: &BlasHandle,
507 n: u32,
508 a: &DeviceBuffer<T>,
509 b: &DeviceBuffer<T>,
510 c: &mut DeviceBuffer<T>,
511) -> BlasResult<()> {
512 if n == 0 {
513 return Ok(());
514 }
515 validate_binary_buffers(n, a, b, c)?;
516 let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpGe, T::PTX_TYPE)?;
517 let grid = grid_size_for(n, BLOCK_SIZE);
518 let params = LaunchParams::new(grid, BLOCK_SIZE);
519 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
520 kernel
521 .launch(¶ms, handle.stream(), &args)
522 .map_err(|e| BlasError::LaunchFailed(format!("cmp_ge: {e}")))?;
523 Ok(())
524}
525
526pub fn or_max<T: GpuFloat>(
540 handle: &BlasHandle,
541 n: u32,
542 a: &DeviceBuffer<T>,
543 b: &DeviceBuffer<T>,
544 c: &mut DeviceBuffer<T>,
545) -> BlasResult<()> {
546 if n == 0 {
547 return Ok(());
548 }
549 validate_binary_buffers(n, a, b, c)?;
550 let (kernel, _name) = build_binary_kernel(handle, PtxOp::OrMax, T::PTX_TYPE)?;
551 let grid = grid_size_for(n, BLOCK_SIZE);
552 let params = LaunchParams::new(grid, BLOCK_SIZE);
553 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
554 kernel
555 .launch(¶ms, handle.stream(), &args)
556 .map_err(|e| BlasError::LaunchFailed(format!("or_max: {e}")))?;
557 Ok(())
558}
559
560pub fn or_prob_sum<T: GpuFloat>(
574 handle: &BlasHandle,
575 n: u32,
576 a: &DeviceBuffer<T>,
577 b: &DeviceBuffer<T>,
578 c: &mut DeviceBuffer<T>,
579) -> BlasResult<()> {
580 if n == 0 {
581 return Ok(());
582 }
583 validate_binary_buffers(n, a, b, c)?;
584 let (kernel, _name) = build_binary_kernel(handle, PtxOp::OrProbSum, T::PTX_TYPE)?;
585 let grid = grid_size_for(n, BLOCK_SIZE);
586 let params = LaunchParams::new(grid, BLOCK_SIZE);
587 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
588 kernel
589 .launch(¶ms, handle.stream(), &args)
590 .map_err(|e| BlasError::LaunchFailed(format!("or_prob_sum: {e}")))?;
591 Ok(())
592}
593
594pub fn nand<T: GpuFloat>(
608 handle: &BlasHandle,
609 n: u32,
610 a: &DeviceBuffer<T>,
611 b: &DeviceBuffer<T>,
612 c: &mut DeviceBuffer<T>,
613) -> BlasResult<()> {
614 if n == 0 {
615 return Ok(());
616 }
617 validate_binary_buffers(n, a, b, c)?;
618 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Nand, T::PTX_TYPE)?;
619 let grid = grid_size_for(n, BLOCK_SIZE);
620 let params = LaunchParams::new(grid, BLOCK_SIZE);
621 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
622 kernel
623 .launch(¶ms, handle.stream(), &args)
624 .map_err(|e| BlasError::LaunchFailed(format!("nand: {e}")))?;
625 Ok(())
626}
627
628pub fn nor<T: GpuFloat>(
642 handle: &BlasHandle,
643 n: u32,
644 a: &DeviceBuffer<T>,
645 b: &DeviceBuffer<T>,
646 c: &mut DeviceBuffer<T>,
647) -> BlasResult<()> {
648 if n == 0 {
649 return Ok(());
650 }
651 validate_binary_buffers(n, a, b, c)?;
652 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Nor, T::PTX_TYPE)?;
653 let grid = grid_size_for(n, BLOCK_SIZE);
654 let params = LaunchParams::new(grid, BLOCK_SIZE);
655 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
656 kernel
657 .launch(¶ms, handle.stream(), &args)
658 .map_err(|e| BlasError::LaunchFailed(format!("nor: {e}")))?;
659 Ok(())
660}
661
662pub fn xor<T: GpuFloat>(
676 handle: &BlasHandle,
677 n: u32,
678 a: &DeviceBuffer<T>,
679 b: &DeviceBuffer<T>,
680 c: &mut DeviceBuffer<T>,
681) -> BlasResult<()> {
682 if n == 0 {
683 return Ok(());
684 }
685 validate_binary_buffers(n, a, b, c)?;
686 let (kernel, _name) = build_binary_kernel(handle, PtxOp::Xor, T::PTX_TYPE)?;
687 let grid = grid_size_for(n, BLOCK_SIZE);
688 let params = LaunchParams::new(grid, BLOCK_SIZE);
689 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
690 kernel
691 .launch(¶ms, handle.stream(), &args)
692 .map_err(|e| BlasError::LaunchFailed(format!("xor: {e}")))?;
693 Ok(())
694}
695
696pub fn fused_add_relu<T: GpuFloat>(
714 handle: &BlasHandle,
715 n: u32,
716 a: &DeviceBuffer<T>,
717 b: &DeviceBuffer<T>,
718 c: &mut DeviceBuffer<T>,
719) -> BlasResult<()> {
720 if n == 0 {
721 return Ok(());
722 }
723 validate_binary_buffers(n, a, b, c)?;
724
725 let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedAddRelu, T::PTX_TYPE)?;
726 let grid = grid_size_for(n, BLOCK_SIZE);
727 let params = LaunchParams::new(grid, BLOCK_SIZE);
728 let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
729
730 kernel
731 .launch(¶ms, handle.stream(), &args)
732 .map_err(|e| BlasError::LaunchFailed(format!("fused_add_relu: {e}")))?;
733 Ok(())
734}
735
736pub fn fused_scale_add<T: GpuFloat>(
756 handle: &BlasHandle,
757 n: u32,
758 alpha: T,
759 a: &DeviceBuffer<T>,
760 beta: T,
761 b: &DeviceBuffer<T>,
762 c: &mut DeviceBuffer<T>,
763) -> BlasResult<()> {
764 if n == 0 {
765 return Ok(());
766 }
767 validate_binary_buffers(n, a, b, c)?;
768
769 let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedScaleAdd, T::PTX_TYPE)?;
770 let grid = grid_size_for(n, BLOCK_SIZE);
771 let params = LaunchParams::new(grid, BLOCK_SIZE);
772
773 let alpha_bits = alpha.to_bits_u64();
776 let beta_bits = beta.to_bits_u64();
777 let args = (
778 a.as_device_ptr(),
779 b.as_device_ptr(),
780 c.as_device_ptr(),
781 alpha_bits,
782 beta_bits,
783 n,
784 );
785
786 kernel
787 .launch(¶ms, handle.stream(), &args)
788 .map_err(|e| BlasError::LaunchFailed(format!("fused_scale_add: {e}")))?;
789 Ok(())
790}
791
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::ir::PtxType;

    /// Generates PTX for `op` at `ty` targeting SM 8.0 and asserts that every string in
    /// `needles` occurs in it.
    ///
    /// If generation fails, panics with the same text `expect` would produce for
    /// "`label` PTX generation should succeed".
    fn assert_ptx_contains(op: PtxOp, ty: PtxType, label: &str, needles: &[&str]) {
        let ptx = ElementwiseTemplate::new(op, ty, SmVersion::Sm80)
            .generate()
            .unwrap_or_else(|e| panic!("{label} PTX generation should succeed: {e:?}"));
        for &needle in needles {
            assert!(ptx.contains(needle), "{label} PTX is missing `{needle}`");
        }
    }

    #[test]
    fn block_size_is_power_of_two() {
        // Power-of-two is checked at run time; the lower bound of 32 at compile time.
        assert!(BLOCK_SIZE.is_power_of_two());
        const { assert!(BLOCK_SIZE >= 32) };
    }

    #[test]
    fn ptx_template_generates_add_f32() {
        assert_ptx_contains(PtxOp::Add, PtxType::F32, "add", &["elementwise_add_f32"]);
    }

    #[test]
    fn ptx_template_generates_mul_f64() {
        assert_ptx_contains(PtxOp::Mul, PtxType::F64, "mul", &["elementwise_mul_f64"]);
    }

    #[test]
    fn ptx_template_generates_fused_add_relu_f32() {
        assert_ptx_contains(
            PtxOp::FusedAddRelu,
            PtxType::F32,
            "fused_add_relu",
            &["elementwise_fused_add_relu_f32"],
        );
    }

    #[test]
    fn ptx_template_generates_fused_scale_add_f32() {
        assert_ptx_contains(
            PtxOp::FusedScaleAdd,
            PtxType::F32,
            "fused_scale_add",
            &["elementwise_fused_scale_add_f32"],
        );
    }

    #[test]
    fn ptx_template_generates_sub_f32() {
        assert_ptx_contains(
            PtxOp::Sub,
            PtxType::F32,
            "sub",
            &["elementwise_sub_f32", "sub.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_div_f32() {
        assert_ptx_contains(
            PtxOp::Div,
            PtxType::F32,
            "div",
            &["elementwise_div_f32", "div.rn.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_pow_f32() {
        assert_ptx_contains(
            PtxOp::Pow,
            PtxType::F32,
            "pow",
            &["elementwise_pow_f32", "lg2.approx.f32", "ex2.approx.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_min_f32() {
        assert_ptx_contains(
            PtxOp::Min,
            PtxType::F32,
            "min",
            &["elementwise_min_f32", "min.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_max_f32() {
        assert_ptx_contains(
            PtxOp::Max,
            PtxType::F32,
            "max",
            &["elementwise_max_f32", "max.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_eq_f32() {
        assert_ptx_contains(
            PtxOp::CmpEq,
            PtxType::F32,
            "cmp_eq",
            &["elementwise_cmp_eq_f32", "setp.eq.f32", "selp.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_ne_f32() {
        assert_ptx_contains(
            PtxOp::CmpNe,
            PtxType::F32,
            "cmp_ne",
            &["elementwise_cmp_ne_f32", "setp.ne.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_lt_f32() {
        assert_ptx_contains(
            PtxOp::CmpLt,
            PtxType::F32,
            "cmp_lt",
            &["elementwise_cmp_lt_f32", "setp.lt.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_gt_f32() {
        assert_ptx_contains(
            PtxOp::CmpGt,
            PtxType::F32,
            "cmp_gt",
            &["elementwise_cmp_gt_f32", "setp.gt.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_le_f32() {
        assert_ptx_contains(
            PtxOp::CmpLe,
            PtxType::F32,
            "cmp_le",
            &["elementwise_cmp_le_f32", "setp.le.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_ge_f32() {
        assert_ptx_contains(
            PtxOp::CmpGe,
            PtxType::F32,
            "cmp_ge",
            &["elementwise_cmp_ge_f32", "setp.ge.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_or_max_f32() {
        assert_ptx_contains(
            PtxOp::OrMax,
            PtxType::F32,
            "or_max",
            &["elementwise_or_max_f32", "max.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_or_prob_sum_f32() {
        assert_ptx_contains(
            PtxOp::OrProbSum,
            PtxType::F32,
            "or_prob_sum",
            &["elementwise_or_prob_sum_f32", "mul.f32", "sub.f32", "add.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_nand_f32() {
        assert_ptx_contains(
            PtxOp::Nand,
            PtxType::F32,
            "nand",
            &["elementwise_nand_f32", "mul.f32", "sub.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_nor_f32() {
        assert_ptx_contains(
            PtxOp::Nor,
            PtxType::F32,
            "nor",
            &["elementwise_nor_f32", "mul.f32", "add.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_xor_f32() {
        // `0f40000000` is the PTX hex float literal for 2.0.
        assert_ptx_contains(
            PtxOp::Xor,
            PtxType::F32,
            "xor",
            &["elementwise_xor_f32", "mul.f32", "add.f32", "0f40000000"],
        );
    }
}