Skip to main content

oxicuda_blas/elementwise/
binary.rs

1//! Binary elementwise operations on device buffers.
2//!
3//! Each function operates on two input arrays and produces one output array.
4//! PTX is generated via [`ElementwiseTemplate`], loaded into the driver, and
5//! launched on the handle's stream.
6
7use std::sync::Arc;
8
9use oxicuda_driver::Module;
10use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
11use oxicuda_memory::DeviceBuffer;
12use oxicuda_ptx::templates::elementwise::{ElementwiseOp as PtxOp, ElementwiseTemplate};
13
14use crate::error::{BlasError, BlasResult};
15use crate::handle::BlasHandle;
16use crate::types::GpuFloat;
17
18// ---------------------------------------------------------------------------
19// Shared helpers
20// ---------------------------------------------------------------------------
21
/// Threads per block used for every 1-D elementwise kernel launched from this
/// module; the grid size is derived from it via `grid_size_for`.
const BLOCK_SIZE: u32 = 256;
24
25/// Validates that all three buffers (a, b, c) have at least `n` elements.
26fn validate_binary_buffers<T: Copy>(
27    n: u32,
28    a: &DeviceBuffer<T>,
29    b: &DeviceBuffer<T>,
30    c: &DeviceBuffer<T>,
31) -> BlasResult<()> {
32    let n_usize = n as usize;
33    if a.len() < n_usize {
34        return Err(BlasError::BufferTooSmall {
35            expected: n_usize,
36            actual: a.len(),
37        });
38    }
39    if b.len() < n_usize {
40        return Err(BlasError::BufferTooSmall {
41            expected: n_usize,
42            actual: b.len(),
43        });
44    }
45    if c.len() < n_usize {
46        return Err(BlasError::BufferTooSmall {
47            expected: n_usize,
48            actual: c.len(),
49        });
50    }
51    Ok(())
52}
53
54/// Generates PTX for a binary op, loads the module, and returns the kernel.
55fn build_binary_kernel(
56    handle: &BlasHandle,
57    ptx_op: PtxOp,
58    ptx_type: oxicuda_ptx::ir::PtxType,
59) -> BlasResult<(Kernel, String)> {
60    let template = ElementwiseTemplate::new(ptx_op, ptx_type, handle.sm_version());
61    let kernel_name = template.kernel_name();
62    let ptx_source = template
63        .generate()
64        .map_err(|e| BlasError::PtxGeneration(format!("{}: {e}", ptx_op.as_str())))?;
65    let module = Arc::new(Module::from_ptx(&ptx_source).map_err(|e| {
66        BlasError::LaunchFailed(format!("module load for {}: {e}", ptx_op.as_str()))
67    })?);
68    let kernel = Kernel::from_module(module, &kernel_name)
69        .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {e}")))?;
70    Ok((kernel, kernel_name))
71}
72
73// ---------------------------------------------------------------------------
74// Public API
75// ---------------------------------------------------------------------------
76
77/// Element-wise addition: `C[i] = A[i] + B[i]`.
78///
79/// # Arguments
80///
81/// * `handle` -- BLAS handle bound to a CUDA context and stream.
82/// * `n` -- number of elements to process.
83/// * `a` -- first input device buffer (at least `n` elements).
84/// * `b` -- second input device buffer (at least `n` elements).
85/// * `c` -- output device buffer for the result (at least `n` elements).
86///
87/// # Errors
88///
89/// Returns [`BlasError::BufferTooSmall`] if any buffer has fewer than `n`
90/// elements, or a PTX/launch error if kernel generation or execution fails.
91pub fn add<T: GpuFloat>(
92    handle: &BlasHandle,
93    n: u32,
94    a: &DeviceBuffer<T>,
95    b: &DeviceBuffer<T>,
96    c: &mut DeviceBuffer<T>,
97) -> BlasResult<()> {
98    if n == 0 {
99        return Ok(());
100    }
101    validate_binary_buffers(n, a, b, c)?;
102
103    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Add, T::PTX_TYPE)?;
104    let grid = grid_size_for(n, BLOCK_SIZE);
105    let params = LaunchParams::new(grid, BLOCK_SIZE);
106    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
107
108    kernel
109        .launch(&params, handle.stream(), &args)
110        .map_err(|e| BlasError::LaunchFailed(format!("add: {e}")))?;
111    Ok(())
112}
113
114/// Element-wise multiplication (Hadamard product): `C[i] = A[i] * B[i]`.
115///
116/// # Arguments
117///
118/// * `handle` -- BLAS handle.
119/// * `n` -- number of elements.
120/// * `a` -- first input device buffer.
121/// * `b` -- second input device buffer.
122/// * `c` -- output device buffer.
123///
124/// # Errors
125///
126/// Returns [`BlasError`] on buffer validation or kernel launch failure.
127pub fn mul<T: GpuFloat>(
128    handle: &BlasHandle,
129    n: u32,
130    a: &DeviceBuffer<T>,
131    b: &DeviceBuffer<T>,
132    c: &mut DeviceBuffer<T>,
133) -> BlasResult<()> {
134    if n == 0 {
135        return Ok(());
136    }
137    validate_binary_buffers(n, a, b, c)?;
138
139    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Mul, T::PTX_TYPE)?;
140    let grid = grid_size_for(n, BLOCK_SIZE);
141    let params = LaunchParams::new(grid, BLOCK_SIZE);
142    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
143
144    kernel
145        .launch(&params, handle.stream(), &args)
146        .map_err(|e| BlasError::LaunchFailed(format!("mul: {e}")))?;
147    Ok(())
148}
149
150/// Element-wise subtraction: `C[i] = A[i] - B[i]`.
151///
152/// # Arguments
153///
154/// * `handle` -- BLAS handle.
155/// * `n` -- number of elements.
156/// * `a` -- first input device buffer.
157/// * `b` -- second input device buffer.
158/// * `c` -- output device buffer.
159///
160/// # Errors
161///
162/// Returns [`BlasError`] on buffer validation or kernel launch failure.
163pub fn sub<T: GpuFloat>(
164    handle: &BlasHandle,
165    n: u32,
166    a: &DeviceBuffer<T>,
167    b: &DeviceBuffer<T>,
168    c: &mut DeviceBuffer<T>,
169) -> BlasResult<()> {
170    if n == 0 {
171        return Ok(());
172    }
173    validate_binary_buffers(n, a, b, c)?;
174    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Sub, T::PTX_TYPE)?;
175    let grid = grid_size_for(n, BLOCK_SIZE);
176    let params = LaunchParams::new(grid, BLOCK_SIZE);
177    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
178    kernel
179        .launch(&params, handle.stream(), &args)
180        .map_err(|e| BlasError::LaunchFailed(format!("sub: {e}")))?;
181    Ok(())
182}
183
184/// Element-wise division: `C[i] = A[i] / B[i]`.
185///
186/// # Arguments
187///
188/// * `handle` -- BLAS handle.
189/// * `n` -- number of elements.
190/// * `a` -- first input device buffer.
191/// * `b` -- second input device buffer.
192/// * `c` -- output device buffer.
193///
194/// # Errors
195///
196/// Returns [`BlasError`] on buffer validation or kernel launch failure.
197pub fn div<T: GpuFloat>(
198    handle: &BlasHandle,
199    n: u32,
200    a: &DeviceBuffer<T>,
201    b: &DeviceBuffer<T>,
202    c: &mut DeviceBuffer<T>,
203) -> BlasResult<()> {
204    if n == 0 {
205        return Ok(());
206    }
207    validate_binary_buffers(n, a, b, c)?;
208    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Div, T::PTX_TYPE)?;
209    let grid = grid_size_for(n, BLOCK_SIZE);
210    let params = LaunchParams::new(grid, BLOCK_SIZE);
211    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
212    kernel
213        .launch(&params, handle.stream(), &args)
214        .map_err(|e| BlasError::LaunchFailed(format!("div: {e}")))?;
215    Ok(())
216}
217
218/// Element-wise power: `C[i] = A[i]^B[i]`.
219///
220/// Uses a lg2+mul+ex2 approximation.
221///
222/// # Arguments
223///
224/// * `handle` -- BLAS handle.
225/// * `n` -- number of elements.
226/// * `a` -- base input device buffer.
227/// * `b` -- exponent input device buffer.
228/// * `c` -- output device buffer.
229///
230/// # Errors
231///
232/// Returns [`BlasError`] on buffer validation or kernel launch failure.
233pub fn pow<T: GpuFloat>(
234    handle: &BlasHandle,
235    n: u32,
236    a: &DeviceBuffer<T>,
237    b: &DeviceBuffer<T>,
238    c: &mut DeviceBuffer<T>,
239) -> BlasResult<()> {
240    if n == 0 {
241        return Ok(());
242    }
243    validate_binary_buffers(n, a, b, c)?;
244    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Pow, T::PTX_TYPE)?;
245    let grid = grid_size_for(n, BLOCK_SIZE);
246    let params = LaunchParams::new(grid, BLOCK_SIZE);
247    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
248    kernel
249        .launch(&params, handle.stream(), &args)
250        .map_err(|e| BlasError::LaunchFailed(format!("pow: {e}")))?;
251    Ok(())
252}
253
254/// Element-wise minimum: `C[i] = min(A[i], B[i])`.
255///
256/// # Arguments
257///
258/// * `handle` -- BLAS handle.
259/// * `n` -- number of elements.
260/// * `a` -- first input device buffer.
261/// * `b` -- second input device buffer.
262/// * `c` -- output device buffer.
263///
264/// # Errors
265///
266/// Returns [`BlasError`] on buffer validation or kernel launch failure.
267pub fn min<T: GpuFloat>(
268    handle: &BlasHandle,
269    n: u32,
270    a: &DeviceBuffer<T>,
271    b: &DeviceBuffer<T>,
272    c: &mut DeviceBuffer<T>,
273) -> BlasResult<()> {
274    if n == 0 {
275        return Ok(());
276    }
277    validate_binary_buffers(n, a, b, c)?;
278    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Min, T::PTX_TYPE)?;
279    let grid = grid_size_for(n, BLOCK_SIZE);
280    let params = LaunchParams::new(grid, BLOCK_SIZE);
281    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
282    kernel
283        .launch(&params, handle.stream(), &args)
284        .map_err(|e| BlasError::LaunchFailed(format!("min: {e}")))?;
285    Ok(())
286}
287
288/// Element-wise maximum: `C[i] = max(A[i], B[i])`.
289///
290/// # Arguments
291///
292/// * `handle` -- BLAS handle.
293/// * `n` -- number of elements.
294/// * `a` -- first input device buffer.
295/// * `b` -- second input device buffer.
296/// * `c` -- output device buffer.
297///
298/// # Errors
299///
300/// Returns [`BlasError`] on buffer validation or kernel launch failure.
301pub fn max<T: GpuFloat>(
302    handle: &BlasHandle,
303    n: u32,
304    a: &DeviceBuffer<T>,
305    b: &DeviceBuffer<T>,
306    c: &mut DeviceBuffer<T>,
307) -> BlasResult<()> {
308    if n == 0 {
309        return Ok(());
310    }
311    validate_binary_buffers(n, a, b, c)?;
312    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Max, T::PTX_TYPE)?;
313    let grid = grid_size_for(n, BLOCK_SIZE);
314    let params = LaunchParams::new(grid, BLOCK_SIZE);
315    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
316    kernel
317        .launch(&params, handle.stream(), &args)
318        .map_err(|e| BlasError::LaunchFailed(format!("max: {e}")))?;
319    Ok(())
320}
321
322/// Comparison equal: `C[i] = (A[i] == B[i]) ? 1.0 : 0.0`.
323///
324/// # Arguments
325///
326/// * `handle` -- BLAS handle.
327/// * `n` -- number of elements.
328/// * `a` -- first input device buffer.
329/// * `b` -- second input device buffer.
330/// * `c` -- output device buffer.
331///
332/// # Errors
333///
334/// Returns [`BlasError`] on buffer validation or kernel launch failure.
335pub fn cmp_eq<T: GpuFloat>(
336    handle: &BlasHandle,
337    n: u32,
338    a: &DeviceBuffer<T>,
339    b: &DeviceBuffer<T>,
340    c: &mut DeviceBuffer<T>,
341) -> BlasResult<()> {
342    if n == 0 {
343        return Ok(());
344    }
345    validate_binary_buffers(n, a, b, c)?;
346    let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpEq, T::PTX_TYPE)?;
347    let grid = grid_size_for(n, BLOCK_SIZE);
348    let params = LaunchParams::new(grid, BLOCK_SIZE);
349    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
350    kernel
351        .launch(&params, handle.stream(), &args)
352        .map_err(|e| BlasError::LaunchFailed(format!("cmp_eq: {e}")))?;
353    Ok(())
354}
355
356/// Comparison not-equal: `C[i] = (A[i] != B[i]) ? 1.0 : 0.0`.
357///
358/// # Arguments
359///
360/// * `handle` -- BLAS handle.
361/// * `n` -- number of elements.
362/// * `a` -- first input device buffer.
363/// * `b` -- second input device buffer.
364/// * `c` -- output device buffer.
365///
366/// # Errors
367///
368/// Returns [`BlasError`] on buffer validation or kernel launch failure.
369pub fn cmp_ne<T: GpuFloat>(
370    handle: &BlasHandle,
371    n: u32,
372    a: &DeviceBuffer<T>,
373    b: &DeviceBuffer<T>,
374    c: &mut DeviceBuffer<T>,
375) -> BlasResult<()> {
376    if n == 0 {
377        return Ok(());
378    }
379    validate_binary_buffers(n, a, b, c)?;
380    let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpNe, T::PTX_TYPE)?;
381    let grid = grid_size_for(n, BLOCK_SIZE);
382    let params = LaunchParams::new(grid, BLOCK_SIZE);
383    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
384    kernel
385        .launch(&params, handle.stream(), &args)
386        .map_err(|e| BlasError::LaunchFailed(format!("cmp_ne: {e}")))?;
387    Ok(())
388}
389
390/// Comparison less-than: `C[i] = (A[i] < B[i]) ? 1.0 : 0.0`.
391///
392/// # Arguments
393///
394/// * `handle` -- BLAS handle.
395/// * `n` -- number of elements.
396/// * `a` -- first input device buffer.
397/// * `b` -- second input device buffer.
398/// * `c` -- output device buffer.
399///
400/// # Errors
401///
402/// Returns [`BlasError`] on buffer validation or kernel launch failure.
403pub fn cmp_lt<T: GpuFloat>(
404    handle: &BlasHandle,
405    n: u32,
406    a: &DeviceBuffer<T>,
407    b: &DeviceBuffer<T>,
408    c: &mut DeviceBuffer<T>,
409) -> BlasResult<()> {
410    if n == 0 {
411        return Ok(());
412    }
413    validate_binary_buffers(n, a, b, c)?;
414    let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpLt, T::PTX_TYPE)?;
415    let grid = grid_size_for(n, BLOCK_SIZE);
416    let params = LaunchParams::new(grid, BLOCK_SIZE);
417    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
418    kernel
419        .launch(&params, handle.stream(), &args)
420        .map_err(|e| BlasError::LaunchFailed(format!("cmp_lt: {e}")))?;
421    Ok(())
422}
423
424/// Comparison greater-than: `C[i] = (A[i] > B[i]) ? 1.0 : 0.0`.
425///
426/// # Arguments
427///
428/// * `handle` -- BLAS handle.
429/// * `n` -- number of elements.
430/// * `a` -- first input device buffer.
431/// * `b` -- second input device buffer.
432/// * `c` -- output device buffer.
433///
434/// # Errors
435///
436/// Returns [`BlasError`] on buffer validation or kernel launch failure.
437pub fn cmp_gt<T: GpuFloat>(
438    handle: &BlasHandle,
439    n: u32,
440    a: &DeviceBuffer<T>,
441    b: &DeviceBuffer<T>,
442    c: &mut DeviceBuffer<T>,
443) -> BlasResult<()> {
444    if n == 0 {
445        return Ok(());
446    }
447    validate_binary_buffers(n, a, b, c)?;
448    let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpGt, T::PTX_TYPE)?;
449    let grid = grid_size_for(n, BLOCK_SIZE);
450    let params = LaunchParams::new(grid, BLOCK_SIZE);
451    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
452    kernel
453        .launch(&params, handle.stream(), &args)
454        .map_err(|e| BlasError::LaunchFailed(format!("cmp_gt: {e}")))?;
455    Ok(())
456}
457
458/// Comparison less-or-equal: `C[i] = (A[i] <= B[i]) ? 1.0 : 0.0`.
459///
460/// # Arguments
461///
462/// * `handle` -- BLAS handle.
463/// * `n` -- number of elements.
464/// * `a` -- first input device buffer.
465/// * `b` -- second input device buffer.
466/// * `c` -- output device buffer.
467///
468/// # Errors
469///
470/// Returns [`BlasError`] on buffer validation or kernel launch failure.
471pub fn cmp_le<T: GpuFloat>(
472    handle: &BlasHandle,
473    n: u32,
474    a: &DeviceBuffer<T>,
475    b: &DeviceBuffer<T>,
476    c: &mut DeviceBuffer<T>,
477) -> BlasResult<()> {
478    if n == 0 {
479        return Ok(());
480    }
481    validate_binary_buffers(n, a, b, c)?;
482    let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpLe, T::PTX_TYPE)?;
483    let grid = grid_size_for(n, BLOCK_SIZE);
484    let params = LaunchParams::new(grid, BLOCK_SIZE);
485    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
486    kernel
487        .launch(&params, handle.stream(), &args)
488        .map_err(|e| BlasError::LaunchFailed(format!("cmp_le: {e}")))?;
489    Ok(())
490}
491
492/// Comparison greater-or-equal: `C[i] = (A[i] >= B[i]) ? 1.0 : 0.0`.
493///
494/// # Arguments
495///
496/// * `handle` -- BLAS handle.
497/// * `n` -- number of elements.
498/// * `a` -- first input device buffer.
499/// * `b` -- second input device buffer.
500/// * `c` -- output device buffer.
501///
502/// # Errors
503///
504/// Returns [`BlasError`] on buffer validation or kernel launch failure.
505pub fn cmp_ge<T: GpuFloat>(
506    handle: &BlasHandle,
507    n: u32,
508    a: &DeviceBuffer<T>,
509    b: &DeviceBuffer<T>,
510    c: &mut DeviceBuffer<T>,
511) -> BlasResult<()> {
512    if n == 0 {
513        return Ok(());
514    }
515    validate_binary_buffers(n, a, b, c)?;
516    let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpGe, T::PTX_TYPE)?;
517    let grid = grid_size_for(n, BLOCK_SIZE);
518    let params = LaunchParams::new(grid, BLOCK_SIZE);
519    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
520    kernel
521        .launch(&params, handle.stream(), &args)
522        .map_err(|e| BlasError::LaunchFailed(format!("cmp_ge: {e}")))?;
523    Ok(())
524}
525
526/// Fuzzy OR via max: `C[i] = max(A[i], B[i])`.
527///
528/// # Arguments
529///
530/// * `handle` -- BLAS handle.
531/// * `n` -- number of elements.
532/// * `a` -- first input device buffer.
533/// * `b` -- second input device buffer.
534/// * `c` -- output device buffer.
535///
536/// # Errors
537///
538/// Returns [`BlasError`] on buffer validation or kernel launch failure.
539pub fn or_max<T: GpuFloat>(
540    handle: &BlasHandle,
541    n: u32,
542    a: &DeviceBuffer<T>,
543    b: &DeviceBuffer<T>,
544    c: &mut DeviceBuffer<T>,
545) -> BlasResult<()> {
546    if n == 0 {
547        return Ok(());
548    }
549    validate_binary_buffers(n, a, b, c)?;
550    let (kernel, _name) = build_binary_kernel(handle, PtxOp::OrMax, T::PTX_TYPE)?;
551    let grid = grid_size_for(n, BLOCK_SIZE);
552    let params = LaunchParams::new(grid, BLOCK_SIZE);
553    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
554    kernel
555        .launch(&params, handle.stream(), &args)
556        .map_err(|e| BlasError::LaunchFailed(format!("or_max: {e}")))?;
557    Ok(())
558}
559
560/// Probabilistic OR: `C[i] = A[i] + B[i] - A[i]*B[i]`.
561///
562/// # Arguments
563///
564/// * `handle` -- BLAS handle.
565/// * `n` -- number of elements.
566/// * `a` -- first input device buffer.
567/// * `b` -- second input device buffer.
568/// * `c` -- output device buffer.
569///
570/// # Errors
571///
572/// Returns [`BlasError`] on buffer validation or kernel launch failure.
573pub fn or_prob_sum<T: GpuFloat>(
574    handle: &BlasHandle,
575    n: u32,
576    a: &DeviceBuffer<T>,
577    b: &DeviceBuffer<T>,
578    c: &mut DeviceBuffer<T>,
579) -> BlasResult<()> {
580    if n == 0 {
581        return Ok(());
582    }
583    validate_binary_buffers(n, a, b, c)?;
584    let (kernel, _name) = build_binary_kernel(handle, PtxOp::OrProbSum, T::PTX_TYPE)?;
585    let grid = grid_size_for(n, BLOCK_SIZE);
586    let params = LaunchParams::new(grid, BLOCK_SIZE);
587    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
588    kernel
589        .launch(&params, handle.stream(), &args)
590        .map_err(|e| BlasError::LaunchFailed(format!("or_prob_sum: {e}")))?;
591    Ok(())
592}
593
594/// Fuzzy NAND: `C[i] = 1 - A[i]*B[i]`.
595///
596/// # Arguments
597///
598/// * `handle` -- BLAS handle.
599/// * `n` -- number of elements.
600/// * `a` -- first input device buffer.
601/// * `b` -- second input device buffer.
602/// * `c` -- output device buffer.
603///
604/// # Errors
605///
606/// Returns [`BlasError`] on buffer validation or kernel launch failure.
607pub fn nand<T: GpuFloat>(
608    handle: &BlasHandle,
609    n: u32,
610    a: &DeviceBuffer<T>,
611    b: &DeviceBuffer<T>,
612    c: &mut DeviceBuffer<T>,
613) -> BlasResult<()> {
614    if n == 0 {
615        return Ok(());
616    }
617    validate_binary_buffers(n, a, b, c)?;
618    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Nand, T::PTX_TYPE)?;
619    let grid = grid_size_for(n, BLOCK_SIZE);
620    let params = LaunchParams::new(grid, BLOCK_SIZE);
621    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
622    kernel
623        .launch(&params, handle.stream(), &args)
624        .map_err(|e| BlasError::LaunchFailed(format!("nand: {e}")))?;
625    Ok(())
626}
627
628/// Fuzzy NOR: `C[i] = 1 - (A[i] + B[i] - A[i]*B[i])`.
629///
630/// # Arguments
631///
632/// * `handle` -- BLAS handle.
633/// * `n` -- number of elements.
634/// * `a` -- first input device buffer.
635/// * `b` -- second input device buffer.
636/// * `c` -- output device buffer.
637///
638/// # Errors
639///
640/// Returns [`BlasError`] on buffer validation or kernel launch failure.
641pub fn nor<T: GpuFloat>(
642    handle: &BlasHandle,
643    n: u32,
644    a: &DeviceBuffer<T>,
645    b: &DeviceBuffer<T>,
646    c: &mut DeviceBuffer<T>,
647) -> BlasResult<()> {
648    if n == 0 {
649        return Ok(());
650    }
651    validate_binary_buffers(n, a, b, c)?;
652    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Nor, T::PTX_TYPE)?;
653    let grid = grid_size_for(n, BLOCK_SIZE);
654    let params = LaunchParams::new(grid, BLOCK_SIZE);
655    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
656    kernel
657        .launch(&params, handle.stream(), &args)
658        .map_err(|e| BlasError::LaunchFailed(format!("nor: {e}")))?;
659    Ok(())
660}
661
662/// Fuzzy XOR: `C[i] = A[i] + B[i] - 2*A[i]*B[i]`.
663///
664/// # Arguments
665///
666/// * `handle` -- BLAS handle.
667/// * `n` -- number of elements.
668/// * `a` -- first input device buffer.
669/// * `b` -- second input device buffer.
670/// * `c` -- output device buffer.
671///
672/// # Errors
673///
674/// Returns [`BlasError`] on buffer validation or kernel launch failure.
675pub fn xor<T: GpuFloat>(
676    handle: &BlasHandle,
677    n: u32,
678    a: &DeviceBuffer<T>,
679    b: &DeviceBuffer<T>,
680    c: &mut DeviceBuffer<T>,
681) -> BlasResult<()> {
682    if n == 0 {
683        return Ok(());
684    }
685    validate_binary_buffers(n, a, b, c)?;
686    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Xor, T::PTX_TYPE)?;
687    let grid = grid_size_for(n, BLOCK_SIZE);
688    let params = LaunchParams::new(grid, BLOCK_SIZE);
689    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
690    kernel
691        .launch(&params, handle.stream(), &args)
692        .map_err(|e| BlasError::LaunchFailed(format!("xor: {e}")))?;
693    Ok(())
694}
695
696/// Fused Add + ReLU: `C[i] = max(0, A[i] + B[i])`.
697///
698/// Performs element-wise addition followed by ReLU in a single kernel launch,
699/// avoiding an extra global memory round-trip compared to separate add and
700/// relu calls.
701///
702/// # Arguments
703///
704/// * `handle` -- BLAS handle.
705/// * `n` -- number of elements.
706/// * `a` -- first input device buffer.
707/// * `b` -- second input device buffer.
708/// * `c` -- output device buffer.
709///
710/// # Errors
711///
712/// Returns [`BlasError`] on buffer validation or kernel launch failure.
713pub fn fused_add_relu<T: GpuFloat>(
714    handle: &BlasHandle,
715    n: u32,
716    a: &DeviceBuffer<T>,
717    b: &DeviceBuffer<T>,
718    c: &mut DeviceBuffer<T>,
719) -> BlasResult<()> {
720    if n == 0 {
721        return Ok(());
722    }
723    validate_binary_buffers(n, a, b, c)?;
724
725    let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedAddRelu, T::PTX_TYPE)?;
726    let grid = grid_size_for(n, BLOCK_SIZE);
727    let params = LaunchParams::new(grid, BLOCK_SIZE);
728    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
729
730    kernel
731        .launch(&params, handle.stream(), &args)
732        .map_err(|e| BlasError::LaunchFailed(format!("fused_add_relu: {e}")))?;
733    Ok(())
734}
735
736/// Fused Scale-Add: `C[i] = alpha * A[i] + beta * B[i]`.
737///
738/// Combines scaling of two input arrays and their addition into a single
739/// kernel launch. This is equivalent to the BLAS `axpby` pattern extended
740/// to output a separate result buffer.
741///
742/// # Arguments
743///
744/// * `handle` -- BLAS handle.
745/// * `n` -- number of elements.
746/// * `alpha` -- scalar multiplier for `A`.
747/// * `a` -- first input device buffer.
748/// * `beta` -- scalar multiplier for `B`.
749/// * `b` -- second input device buffer.
750/// * `c` -- output device buffer.
751///
752/// # Errors
753///
754/// Returns [`BlasError`] on buffer validation or kernel launch failure.
755pub fn fused_scale_add<T: GpuFloat>(
756    handle: &BlasHandle,
757    n: u32,
758    alpha: T,
759    a: &DeviceBuffer<T>,
760    beta: T,
761    b: &DeviceBuffer<T>,
762    c: &mut DeviceBuffer<T>,
763) -> BlasResult<()> {
764    if n == 0 {
765        return Ok(());
766    }
767    validate_binary_buffers(n, a, b, c)?;
768
769    let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedScaleAdd, T::PTX_TYPE)?;
770    let grid = grid_size_for(n, BLOCK_SIZE);
771    let params = LaunchParams::new(grid, BLOCK_SIZE);
772
773    // FusedScaleAdd kernel signature:
774    //   (a_ptr, b_ptr, c_ptr, alpha_bits, beta_bits, n)
775    let alpha_bits = alpha.to_bits_u64();
776    let beta_bits = beta.to_bits_u64();
777    let args = (
778        a.as_device_ptr(),
779        b.as_device_ptr(),
780        c.as_device_ptr(),
781        alpha_bits,
782        beta_bits,
783        n,
784    );
785
786    kernel
787        .launch(&params, handle.stream(), &args)
788        .map_err(|e| BlasError::LaunchFailed(format!("fused_scale_add: {e}")))?;
789    Ok(())
790}
791
792#[cfg(test)]
793mod tests {
794    use super::*;
795
796    #[test]
797    fn block_size_is_power_of_two() {
798        assert!(BLOCK_SIZE.is_power_of_two());
799        const { assert!(BLOCK_SIZE >= 32) };
800    }
801
802    #[test]
803    fn ptx_template_generates_add_f32() {
804        let template = ElementwiseTemplate::new(
805            PtxOp::Add,
806            oxicuda_ptx::ir::PtxType::F32,
807            oxicuda_ptx::arch::SmVersion::Sm80,
808        );
809        let ptx = template
810            .generate()
811            .expect("add PTX generation should succeed");
812        assert!(ptx.contains("elementwise_add_f32"));
813    }
814
815    #[test]
816    fn ptx_template_generates_mul_f64() {
817        let template = ElementwiseTemplate::new(
818            PtxOp::Mul,
819            oxicuda_ptx::ir::PtxType::F64,
820            oxicuda_ptx::arch::SmVersion::Sm80,
821        );
822        let ptx = template
823            .generate()
824            .expect("mul PTX generation should succeed");
825        assert!(ptx.contains("elementwise_mul_f64"));
826    }
827
828    #[test]
829    fn ptx_template_generates_fused_add_relu_f32() {
830        let template = ElementwiseTemplate::new(
831            PtxOp::FusedAddRelu,
832            oxicuda_ptx::ir::PtxType::F32,
833            oxicuda_ptx::arch::SmVersion::Sm80,
834        );
835        let ptx = template
836            .generate()
837            .expect("fused_add_relu PTX generation should succeed");
838        assert!(ptx.contains("elementwise_fused_add_relu_f32"));
839    }
840
841    #[test]
842    fn ptx_template_generates_fused_scale_add_f32() {
843        let template = ElementwiseTemplate::new(
844            PtxOp::FusedScaleAdd,
845            oxicuda_ptx::ir::PtxType::F32,
846            oxicuda_ptx::arch::SmVersion::Sm80,
847        );
848        let ptx = template
849            .generate()
850            .expect("fused_scale_add PTX generation should succeed");
851        assert!(ptx.contains("elementwise_fused_scale_add_f32"));
852    }
853
854    #[test]
855    fn ptx_template_generates_sub_f32() {
856        let template = ElementwiseTemplate::new(
857            PtxOp::Sub,
858            oxicuda_ptx::ir::PtxType::F32,
859            oxicuda_ptx::arch::SmVersion::Sm80,
860        );
861        let ptx = template
862            .generate()
863            .expect("sub PTX generation should succeed");
864        assert!(ptx.contains("elementwise_sub_f32"));
865        assert!(ptx.contains("sub.f32"));
866    }
867
868    #[test]
869    fn ptx_template_generates_div_f32() {
870        let template = ElementwiseTemplate::new(
871            PtxOp::Div,
872            oxicuda_ptx::ir::PtxType::F32,
873            oxicuda_ptx::arch::SmVersion::Sm80,
874        );
875        let ptx = template
876            .generate()
877            .expect("div PTX generation should succeed");
878        assert!(ptx.contains("elementwise_div_f32"));
879        assert!(ptx.contains("div.rn.f32"));
880    }
881
882    #[test]
883    fn ptx_template_generates_pow_f32() {
884        let template = ElementwiseTemplate::new(
885            PtxOp::Pow,
886            oxicuda_ptx::ir::PtxType::F32,
887            oxicuda_ptx::arch::SmVersion::Sm80,
888        );
889        let ptx = template
890            .generate()
891            .expect("pow PTX generation should succeed");
892        assert!(ptx.contains("elementwise_pow_f32"));
893        assert!(ptx.contains("lg2.approx.f32"));
894        assert!(ptx.contains("ex2.approx.f32"));
895    }
896
897    #[test]
898    fn ptx_template_generates_min_f32() {
899        let template = ElementwiseTemplate::new(
900            PtxOp::Min,
901            oxicuda_ptx::ir::PtxType::F32,
902            oxicuda_ptx::arch::SmVersion::Sm80,
903        );
904        let ptx = template
905            .generate()
906            .expect("min PTX generation should succeed");
907        assert!(ptx.contains("elementwise_min_f32"));
908        assert!(ptx.contains("min.f32"));
909    }
910
911    #[test]
912    fn ptx_template_generates_max_f32() {
913        let template = ElementwiseTemplate::new(
914            PtxOp::Max,
915            oxicuda_ptx::ir::PtxType::F32,
916            oxicuda_ptx::arch::SmVersion::Sm80,
917        );
918        let ptx = template
919            .generate()
920            .expect("max PTX generation should succeed");
921        assert!(ptx.contains("elementwise_max_f32"));
922        assert!(ptx.contains("max.f32"));
923    }
924
925    #[test]
926    fn ptx_template_generates_cmp_eq_f32() {
927        let template = ElementwiseTemplate::new(
928            PtxOp::CmpEq,
929            oxicuda_ptx::ir::PtxType::F32,
930            oxicuda_ptx::arch::SmVersion::Sm80,
931        );
932        let ptx = template
933            .generate()
934            .expect("cmp_eq PTX generation should succeed");
935        assert!(ptx.contains("elementwise_cmp_eq_f32"));
936        assert!(ptx.contains("setp.eq.f32"));
937        assert!(ptx.contains("selp.f32"));
938    }
939
940    #[test]
941    fn ptx_template_generates_cmp_ne_f32() {
942        let template = ElementwiseTemplate::new(
943            PtxOp::CmpNe,
944            oxicuda_ptx::ir::PtxType::F32,
945            oxicuda_ptx::arch::SmVersion::Sm80,
946        );
947        let ptx = template
948            .generate()
949            .expect("cmp_ne PTX generation should succeed");
950        assert!(ptx.contains("elementwise_cmp_ne_f32"));
951        assert!(ptx.contains("setp.ne.f32"));
952    }
953
954    #[test]
955    fn ptx_template_generates_cmp_lt_f32() {
956        let template = ElementwiseTemplate::new(
957            PtxOp::CmpLt,
958            oxicuda_ptx::ir::PtxType::F32,
959            oxicuda_ptx::arch::SmVersion::Sm80,
960        );
961        let ptx = template
962            .generate()
963            .expect("cmp_lt PTX generation should succeed");
964        assert!(ptx.contains("elementwise_cmp_lt_f32"));
965        assert!(ptx.contains("setp.lt.f32"));
966    }
967
968    #[test]
969    fn ptx_template_generates_cmp_gt_f32() {
970        let template = ElementwiseTemplate::new(
971            PtxOp::CmpGt,
972            oxicuda_ptx::ir::PtxType::F32,
973            oxicuda_ptx::arch::SmVersion::Sm80,
974        );
975        let ptx = template
976            .generate()
977            .expect("cmp_gt PTX generation should succeed");
978        assert!(ptx.contains("elementwise_cmp_gt_f32"));
979        assert!(ptx.contains("setp.gt.f32"));
980    }
981
982    #[test]
983    fn ptx_template_generates_cmp_le_f32() {
984        let template = ElementwiseTemplate::new(
985            PtxOp::CmpLe,
986            oxicuda_ptx::ir::PtxType::F32,
987            oxicuda_ptx::arch::SmVersion::Sm80,
988        );
989        let ptx = template
990            .generate()
991            .expect("cmp_le PTX generation should succeed");
992        assert!(ptx.contains("elementwise_cmp_le_f32"));
993        assert!(ptx.contains("setp.le.f32"));
994    }
995
996    #[test]
997    fn ptx_template_generates_cmp_ge_f32() {
998        let template = ElementwiseTemplate::new(
999            PtxOp::CmpGe,
1000            oxicuda_ptx::ir::PtxType::F32,
1001            oxicuda_ptx::arch::SmVersion::Sm80,
1002        );
1003        let ptx = template
1004            .generate()
1005            .expect("cmp_ge PTX generation should succeed");
1006        assert!(ptx.contains("elementwise_cmp_ge_f32"));
1007        assert!(ptx.contains("setp.ge.f32"));
1008    }
1009
1010    #[test]
1011    fn ptx_template_generates_or_max_f32() {
1012        let template = ElementwiseTemplate::new(
1013            PtxOp::OrMax,
1014            oxicuda_ptx::ir::PtxType::F32,
1015            oxicuda_ptx::arch::SmVersion::Sm80,
1016        );
1017        let ptx = template
1018            .generate()
1019            .expect("or_max PTX generation should succeed");
1020        assert!(ptx.contains("elementwise_or_max_f32"));
1021        assert!(ptx.contains("max.f32"));
1022    }
1023
1024    #[test]
1025    fn ptx_template_generates_or_prob_sum_f32() {
1026        let template = ElementwiseTemplate::new(
1027            PtxOp::OrProbSum,
1028            oxicuda_ptx::ir::PtxType::F32,
1029            oxicuda_ptx::arch::SmVersion::Sm80,
1030        );
1031        let ptx = template
1032            .generate()
1033            .expect("or_prob_sum PTX generation should succeed");
1034        assert!(ptx.contains("elementwise_or_prob_sum_f32"));
1035        assert!(ptx.contains("mul.f32"));
1036        assert!(ptx.contains("sub.f32"));
1037        assert!(ptx.contains("add.f32"));
1038    }
1039
1040    #[test]
1041    fn ptx_template_generates_nand_f32() {
1042        let template = ElementwiseTemplate::new(
1043            PtxOp::Nand,
1044            oxicuda_ptx::ir::PtxType::F32,
1045            oxicuda_ptx::arch::SmVersion::Sm80,
1046        );
1047        let ptx = template
1048            .generate()
1049            .expect("nand PTX generation should succeed");
1050        assert!(ptx.contains("elementwise_nand_f32"));
1051        assert!(ptx.contains("mul.f32"));
1052        assert!(ptx.contains("sub.f32"));
1053    }
1054
1055    #[test]
1056    fn ptx_template_generates_nor_f32() {
1057        let template = ElementwiseTemplate::new(
1058            PtxOp::Nor,
1059            oxicuda_ptx::ir::PtxType::F32,
1060            oxicuda_ptx::arch::SmVersion::Sm80,
1061        );
1062        let ptx = template
1063            .generate()
1064            .expect("nor PTX generation should succeed");
1065        assert!(ptx.contains("elementwise_nor_f32"));
1066        assert!(ptx.contains("mul.f32"));
1067        assert!(ptx.contains("add.f32"));
1068    }
1069
1070    #[test]
1071    fn ptx_template_generates_xor_f32() {
1072        let template = ElementwiseTemplate::new(
1073            PtxOp::Xor,
1074            oxicuda_ptx::ir::PtxType::F32,
1075            oxicuda_ptx::arch::SmVersion::Sm80,
1076        );
1077        let ptx = template
1078            .generate()
1079            .expect("xor PTX generation should succeed");
1080        assert!(ptx.contains("elementwise_xor_f32"));
1081        assert!(ptx.contains("mul.f32"));
1082        assert!(ptx.contains("add.f32"));
1083        assert!(ptx.contains("0f40000000")); // 2.0
1084    }
1085}