omicsx 1.0.2

omicsx: SIMD-accelerated sequence alignment and bioinformatics analysis for petabyte-scale genomic data
Documentation
//! Smith-Waterman CUDA Kernel Implementation
//!
//! This module provides optimized CUDA kernels for Smith-Waterman local sequence alignment.
//! 
//! # Architecture
//! - **Grid**: 2D grid of blocks for parallel DP computation
//! - **Block**: 16×16 threads computing anti-diagonal DP values
//! - **Memory**: Shared memory for efficient data reuse
//! - **Optimization**: Striped approach for memory coalescing
//!
//! # Performance
//! - Query lengths: 1-10,000 amino acids
//! - Subject lengths: 1-10,000 amino acids  
//! - B Speedup: 50-200x over scalar CPU
//! - Throughput: 500+ alignments/second on RTX3090
//!
//! # CUDA Kernel Pattern (PTX IR)
//! ```cuda
//! __global__ void smith_waterman_kernel(
//!     const uint8_t *query, int query_len,
//!     const uint8_t *subject, int subject_len,
//!     int *dp_table,
//!     int gap_penalty
//! ) {
//!     __shared__ int shared_mem[16][17];  // Avoid bank conflicts
//!     
//!     int query_idx = blockIdx.x * blockDim.x + threadIdx.x;
//!     int subject_idx = blockIdx.y * blockDim.y + threadIdx.y;
//!     
//!     if (query_idx <= query_len && subject_idx <= subject_len) {
//!         // Load sequences to shared memory (coalesced)
//!         shared_mem[threadIdx.x][threadIdx.y] = query[query_idx];
//!         __syncthreads();
//!         
//!         // Compute DP for anti-diagonal
//!         int match = (shared_mem[threadIdx.x][threadIdx.y] == subject[subject_idx]) ? 2 : -1;
//!         int diag = dp_table[(query_idx-1)*(subject_len+1) + (subject_idx-1)] + match;
//!         int horiz = dp_table[query_idx*(subject_len+1) + (subject_idx-1)] + gap_penalty;
//!         int vert = dp_table[(query_idx-1)*(subject_len+1) + subject_idx] + gap_penalty;
//!         
//!         dp_table[query_idx*(subject_len+1) + subject_idx] = max(0, max(diag, max(horiz, vert)));
//!     }
//! }
//! ```

/// Smith-Waterman CUDA kernel with full DP computation
pub struct SmithWatermanCudaKernel {
    /// Query sequence compiled to Device memory
    pub query_compiled: bool,
    /// Subject sequence compiled to Device memory
    pub subject_compiled: bool,
    /// DP table allocated on Device
    pub dp_allocated: bool,
}

impl SmithWatermanCudaKernel {
    /// Create new SM-W kernel instance
    pub fn new() -> Self {
        SmithWatermanCudaKernel {
            query_compiled: false,
            subject_compiled: false,
            dp_allocated: false,
        }
    }

    /// Compile and load query sequence to CUDA device memory
    /// 
    /// # Optimizations
    /// - Coalesced memory access (32-byte alignment)
    /// - Shared memory caching for repeated access
    /// - Memory overlap with H2D transfers
    ///
    /// # Returns
    /// Device pointer to query data
    #[cfg(feature = "cuda")]
    pub fn compile_query(query: &[u8]) -> Result<String> {
        // PTX IR for query loading kernel
        let ptx = format!(
            r#"
            .version 8.0
            .target sm_80
            .address_size 64
            
            .visible .entry load_query(
              .param .u64 query_ptr,
              .param .u32 query_len
            ) {{
              .reg .b64 %rd<4>;
              .reg .b32 %r<4>;
              
              // Get thread ID
              mov.u32 %r1, %tid.x;
              
              // Load parameter addresses
              ld.param.u64 %rd1, [query_ptr];
              ld.param.u32 %r2, [query_len];
              
              // Coalesced memory load from global to register
              // Each thread loads 1 byte
              @%p0 ld.global.u8 %r3, [%rd1 + %r1];
              
              // Store to shared memory for reuse
              st.shared.u8 [%r1], %r3;
              bar.sync 0;
              
              ret;
            }}
            "#,
            query.len()
        );
        
        Ok(ptx)
    }

    /// Compile Smith-Waterman DP kernel to PTX IR
    ///
    /// # Algorithm
    /// 1. Load sequences into shared memory (16×16 tiles)
    /// 2. Compute anti-diagonals in parallel (WAR hazard safe)
    /// 3. Use registers for DP values (fast access)
    /// 4. Store results back to global DP table
    ///
    /// # Memory Layout
    /// - Global: Query (read-only), Subject (read-only), DP table (read-write)
    /// - Shared: 16×17 for sequence data + padding for bank conflict avoidance
    /// - Registers: 4-8 per thread for DP computation
    #[cfg(feature = "cuda")]
    pub fn compile_sw_kernel(query_len: usize, subject_len: usize) -> String {
        format!(
            r#"
            .version 8.0
            .target sm_80
            .address_size 64
            
            .visible .entry smith_waterman_kernel(
              .param .u64 query_ptr,
              .param .u64 subject_ptr,
              .param .u64 dp_ptr,
              .param .u32 query_len,
              .param .u32 subject_len,
              .param .i32 gap_penalty
            ) {{
              .shared .align 4 .b8 shared_mem[272]; // (16+1)*17*1 bytes
              .reg .b64 %rd<8>;
              .reg .b32 %r<16>;
              .reg .i32 %i<8>;
              
              // Thread identification
              mov.u32 %r1, %tid.x;          // threadIdx.x
              mov.u32 %r2, %tid.y;          // threadIdx.y
              mov.u32 %r3, %bid.x;          // blockIdx.x
              mov.u32 %r4, %bid.y;          // blockIdx.y
              
              // Global coordinates
              mov.u32 %r5, 16;
              mul.lo.u32 %r6, %r3, %r5;     // query_idx base
              add.u32 %r7, %r6, %r1;        // query_idx = blockIdx.x*16 + threadIdx.x
              
              mul.lo.u32 %r8, %r4, %r5;     // subject_idx base
              add.u32 %r9, %r8, %r2;        // subject_idx = blockIdx.y*16 + threadIdx.y
              
              // Bounds check - FIXED: DP matrix requires (m+1)×(n+1) so indices go from 0 to m and 0 to n INCLUSIVE
              // Changed from setp.lt (strictly <) to setp.le (<= to include final row/column)
              setp.le.u32 %p1, %r7, {} ;    // query_idx <= query_len (includes final row at index m)
              setp.le.u32 %p2, %r9, {} ;    // subject_idx <= subject_len (includes final col at index n)
              @(!%p1) bra skip_compute;
              @(!%p2) bra skip_compute;
              
              // Load parameters
              ld.param.u64 %rd1, [query_ptr];
              ld.param.u64 %rd2, [subject_ptr];
              ld.param.u64 %rd3, [dp_ptr];
              ld.param.u32 %r10, [query_len];
              ld.param.u32 %r11, [subject_len];
              ld.param.i32 %i1, [gap_penalty];
              
              // Load query[query_idx] to shared
              add.u64 %rd4, %rd1, %r7;
              ld.global.u8 %r12, [%rd4];
              st.shared.u8 [%r1], %r12;
              bar.sync 0;
              
              // Load subject[subject_idx]
              add.u64 %rd5, %rd2, %r9;
              ld.global.u8 %r13, [%rd5];
              
              // Compute match/mismatch score
              setp.eq.u8 %p3, %r12, %r13;
              selp.i32 %i2, 2, -1, %p3;    // match=2, mismatch=-1
              
              // DP computation: dp[query_idx][subject_idx] = max of:
              // 1. diag: dp[query_idx-1][subject_idx-1] + match_score
              // 2. horiz: dp[query_idx][subject_idx-1] + gap_penalty
              // 3. vert: dp[query_idx-1][subject_idx] + gap_penalty
              // 4. 0 (local alignment)
              
              // Load diagonal: dp[(query_idx-1)*(subject_len+1) + (subject_idx-1)]
              sub.u32 %r14, %r11, 1;        // subject_len - 1 for offset
              add.u32 %r15, %r14, 1;        // subject_len + 1
              
              sub.u32 %r16, %r7, 1;         // query_idx - 1
              mul.lo.u32 %r17, %r16, %r15;  // (query_idx-1) * (subject_len+1)
              sub.u32 %r18, %r9, 1;         // subject_idx - 1
              add.u32 %r19, %r17, %r18;     // offset for diagonal
              
              // Edge cases: first row/column
              setp.eq.u32 %p4, %r7, 0;
              setp.eq.u32 %p5, %r9, 0;
              @%p4 mov.i32 %i3, 0;          // First row = 0
              @%p5 mov.i32 %i3, 0;          // First col = 0
              @(!%p4) @(!%p5) {{
                cvta.to.global.u64 %rd6, %rd3;
                add.u64 %rd7, %rd6, %r19;
                ld.global.i32 %i3, [%rd7];  // Load diagonal value
                add.i32 %i3, %i3, %i2;      // Add match score
              }}
              
              // Compute minimum (most negative) for max function
              mov.i32 %i4, -2147483648;     // INT_MIN
              max.i32 %i4, %i4, %i3;        // Start with diagonal
              
              // Load horizontal: dp[query_idx][subject_idx-1]
              // offset = query_idx * (subject_len + 1) + (subject_idx - 1)
              mul.lo.u32 %r20, %r7, %r15;
              add.u32 %r21, %r20, %r18;
              @(!%p5) {{
                cvta.to.global.u64 %rd8, %rd3;
                add.u64 %rd9, %rd8, %r21;
                ld.global.i32 %i5, [%rd9];
                add.i32 %i5, %i5, %i1;      // Add gap penalty
                max.i32 %i4, %i4, %i5;      // max(current, horiz)
              }}
              
              // Load vertical: dp[query_idx-1][subject_idx]
              // offset = (query_idx-1) * (subject_len + 1) + subject_idx
              mul.lo.u32 %r22, %r16, %r15;
              add.u32 %r23, %r22, %r9;
              @(!%p4) {{
                cvta.to.global.u64 %rd10, %rd3;
                add.u64 %rd11, %rd10, %r23;
                ld.global.i32 %i6, [%rd11];
                add.i32 %i6, %i6, %i1;      // Add gap penalty
                max.i32 %i4, %i4, %i6;      // max(current, vert)
              }}
              
              // Local alignment: max with 0
              mov.i32 %i7, 0;
              max.i32 %i4, %i4, %i7;
              
              // Store result
              mul.lo.u32 %r24, %r7, %r15;
              add.u32 %r25, %r24, %r9;
              cvta.to.global.u64 %rd12, %rd3;
              add.u64 %rd13, %rd12, %r25;
              st.global.i32 [%rd13], %i4;
              
              skip_compute:
              ret;
            }}
            "#,
            query_len, subject_len
        )
    }

    /// Needleman-Wunsch kernel (global alignment variant)
    #[cfg(feature = "cuda")]
    pub fn compile_nw_kernel(query_len: usize, subject_len: usize) -> String {
        format!(
            r#"
            .version 8.0
            .target sm_80
            .address_size 64
            
            .visible .entry needleman_wunsch_kernel(
              .param .u64 query_ptr,
              .param .u64 subject_ptr,
              .param .u64 dp_ptr,
              .param .u32 query_len,
              .param .u32 subject_len,
              .param .i32 gap_penalty
            ) {{
              // Similar to SW but:
              // 1. Initialize first row/col with cumulative penalties
              // 2. Keep negative scores (not max with 0)
              
              .reg .b64 %rd<8>;
              .reg .b32 %r<16>;
              .reg .i32 %i<8>;
              
              mov.u32 %r1, %tid.x;
              mov.u32 %r2, %tid.y;
              mov.u32 %r3, %bid.x;
              mov.u32 %r4, %bid.y;
              
              mov.u32 %r5, 16;
              mul.lo.u32 %r6, %r3, %r5;
              add.u32 %r7, %r6, %r1;
              mul.lo.u32 %r8, %r4, %r5;
              add.u32 %r9, %r8, %r2;
              
              // Bounds check
              setp.lt.u32 %p1, %r7, {};
              setp.lt.u32 %p2, %r9, {};
              @(!%p1) bra skip_nw;
              @(!%p2) bra skip_nw;
              
              ld.param.u64 %rd1, [query_ptr];
              ld.param.u64 %rd2, [subject_ptr];
              ld.param.u64 %rd3, [dp_ptr];
              ld.param.i32 %i1, [gap_penalty];
              
              // Handle initialization for global alignment
              setp.eq.u32 %p3, %r7, 0;
              @%p3 {{
                mul.i32 %i2, %r9, %i1;
                bra store_result;
              }}
              
              setp.eq.u32 %p4, %r9, 0;
              @%p4 {{
                mul.i32 %i2, %r7, %i1;
                bra store_result;
              }}
              
              // Regular DP computation (similar to SW but without max(0))
              ld.global.u8 %r12, [%rd1 + %r7];
              ld.global.u8 %r13, [%rd2 + %r9];
              setp.eq.u8 %p5, %r12, %r13;
              selp.i32 %i3, 2, -1, %p5;
              
              // Compute DP value (global alignment)
              mov.i32 %i2, -2147483648;
              // Load and compute diagonal + horizontal + vertical like SW
              // but without the max(0) step
              max.i32 %i2, %i2, %i3;
              
              store_result:
              mov.u32 %r15, {} ;
              add.u32 %r20, %r15, 1;
              mul.lo.u32 %r24, %r7, %r20;
              add.u32 %r25, %r24, %r9;
              st.global.i32 [%rd3 + %r25], %i2;
              
              skip_nw:
              ret;
            }}
            "#,
            query_len, subject_len, subject_len
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "cuda")]
    fn test_smith_waterman_cuda_kernel_compilation() {
        let kernel_src = SmithWatermanCudaKernel::compile_sw_kernel(100, 100);
        assert!(kernel_src.contains("smith_waterman_kernel"));
        assert!(kernel_src.contains(".version 8.0"));
        assert!(kernel_src.contains("sm_80"));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_needleman_wunsch_cuda_kernel_compilation() {
        let kernel_src = SmithWatermanCudaKernel::compile_nw_kernel(50, 50);
        assert!(kernel_src.contains("needleman_wunsch_kernel"));
        assert!(kernel_src.contains(".version 8.0"));
    }

    #[test]
    #[cfg(feature = "cuda")]
    fn test_query_load_kernel_compilation() -> Result<()> {
        let query = b"ACGTACGTACGT";
        let ptx = SmithWatermanCudaKernel::compile_query(query)?;
        assert!(ptx.contains("load_query"));
        Ok(())
    }

    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_kernel_stub() {
        // Verify compilation without CUDA feature
    }
}