/// Bookkeeping flags for the CUDA alignment kernels generated in this file.
///
/// Within the visible code these flags are only ever initialised (all `false`,
/// by [`SmithWatermanCudaKernel::new`]); nothing shown here reads or updates
/// them, so their exact life cycle is owned by callers.
pub struct SmithWatermanCudaKernel {
    /// Presumably set once the query-side kernel has been built — TODO confirm against callers.
    pub query_compiled: bool,
    /// Presumably set once the subject-side kernel has been built — TODO confirm against callers.
    pub subject_compiled: bool,
    /// Presumably set once the DP matrix buffer has been allocated — TODO confirm against callers.
    pub dp_allocated: bool,
}
impl SmithWatermanCudaKernel {
pub fn new() -> Self {
SmithWatermanCudaKernel {
query_compiled: false,
subject_compiled: false,
dp_allocated: false,
}
}
#[cfg(feature = "cuda")]
/// Builds the PTX source of the `load_query` staging kernel for `query`.
///
/// Only the *length* of `query` is used: it is baked into the kernel text as
/// the upper bound of the per-thread load. The query bytes themselves are not
/// embedded; the kernel reads them at run time through its `query_ptr`
/// parameter.
///
/// In the generated kernel every thread copies `query_ptr[tid.x]` into a
/// block-local shared tile (threads at or beyond the baked-in length store
/// `0`), then all threads meet at a block-wide barrier. The tile holds 1024
/// bytes, the largest block size on `sm_80`, so `tid.x` always fits.
///
/// Compared with the previous template this version
/// * actually consumes `query.len()` (the old `format!` call passed it with no
///   placeholder, which is a compile error),
/// * declares and sets the guard predicate instead of using an undefined `%p0`,
/// * declares the shared tile it writes to, and
/// * uses 64-bit address arithmetic instead of the invalid `[reg + reg]` form.
///
/// # Errors
///
/// The body has no failing path and always returns `Ok`; the `Result` return
/// type is kept for interface compatibility.
pub fn compile_query(query: &[u8]) -> Result<String> {
    let ptx = format!(
        r#"
.version 8.0
.target sm_80
.address_size 64

.visible .entry load_query(
    .param .u64 query_ptr,
    .param .u32 query_len
)
{{
    .shared .align 1 .b8 query_tile[1024];
    .reg .pred %p<2>;
    .reg .b64 %rd<8>;
    .reg .b32 %r<4>;

    // Get thread ID
    mov.u32 %r1, %tid.x;

    // Address of query[tid.x] in global memory
    ld.param.u64 %rd1, [query_ptr];
    cvta.to.global.u64 %rd2, %rd1;
    cvt.u64.u32 %rd3, %r1;
    add.u64 %rd4, %rd2, %rd3;

    // Each in-range thread loads 1 byte; out-of-range threads keep 0
    setp.lt.u32 %p1, %r1, {query_len};
    mov.u32 %r3, 0;
    @%p1 ld.global.u8 %r3, [%rd4];

    // Store to shared memory for reuse (every thread stores, so the barrier is uniform)
    mov.u64 %rd5, query_tile;
    add.u64 %rd6, %rd5, %rd3;
    st.shared.u8 [%rd6], %r3;
    bar.sync 0;
    ret;
}}
"#,
        query_len = query.len()
    );
    Ok(ptx)
}
#[cfg(feature = "cuda")]
/// Builds the PTX source of `smith_waterman_kernel` (local alignment).
///
/// `query_len` (m) and `subject_len` (n) are baked into the kernel text as the
/// inclusive bounds of the thread-to-cell mapping. The row pitch of the DP
/// matrix is taken from the kernel's run-time `subject_len` parameter, as in
/// the previous template, so the host must pass the same value it compiled
/// with. Both lengths are emitted verbatim into 32-bit comparisons, so values
/// above `u32::MAX` do not fit.
///
/// Generated kernel, one thread per DP cell in 16x16 tiles:
/// * `row = ctaid.x * 16 + tid.x`, `col = ctaid.y * 16 + tid.y`; threads with
///   `row > m` or `col > n` exit.
/// * The DP matrix is `(m + 1) x (n + 1)` signed 32-bit cells, row-major.
/// * Row 0 and column 0 are written as `0`.
/// * Interior cells store
///   `max(0, diag + s, up + gap_penalty, left + gap_penalty)` where `s` is `2`
///   when `query[row - 1] == subject[col - 1]` and `-1` otherwise.
///
/// Fixes relative to the previous template, which could not be assembled:
/// * `%bid` -> `%ctaid`, `.i32` -> `.s32`, `@(!%p)` -> `@!%p`, and the
///   predicated `{ ... }` blocks (not a PTX construct) became branches.
/// * Predicate registers are declared and the `%r` / `%rd` banks cover every
///   register that is used.
/// * Byte comparison uses a legal 32-bit `setp`; pointer arithmetic is 64-bit.
/// * DP offsets are scaled by the 4-byte cell size and the row pitch is really
///   `subject_len + 1` (it was computed as `subject_len - 1 + 1`).
/// * Sequence bytes are read at `row - 1` / `col - 1`, so row m and column n
///   no longer read one element past the end of the sequences.
/// * The never-read shared-memory store and its barrier (which sat after a
///   divergent early exit) were removed.
///
/// NOTE(review): each thread reads its three neighbour cells without any
/// ordering between threads. The result is only meaningful if the launcher
/// guarantees those neighbours are already final (for example by relaunching
/// per anti-diagonal) — confirm against the host code.
pub fn compile_sw_kernel(query_len: usize, subject_len: usize) -> String {
    format!(
        r#"
.version 8.0
.target sm_80
.address_size 64

.visible .entry smith_waterman_kernel(
    .param .u64 query_ptr,
    .param .u64 subject_ptr,
    .param .u64 dp_ptr,
    .param .u32 query_len,
    .param .u32 subject_len,
    .param .s32 gap_penalty
)
{{
    .reg .pred %p<8>;
    .reg .b64 %rd<16>;
    .reg .b32 %r<24>;

    // Thread identification
    mov.u32 %r1, %tid.x;
    mov.u32 %r2, %tid.y;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ctaid.y;

    // Global cell coordinates (16x16 tiles)
    mov.u32 %r5, 16;
    mul.lo.u32 %r6, %r3, %r5;
    add.u32 %r7, %r6, %r1;            // row = ctaid.x*16 + tid.x
    mul.lo.u32 %r8, %r4, %r5;
    add.u32 %r9, %r8, %r2;            // col = ctaid.y*16 + tid.y

    // Bounds check: the DP matrix is (m+1) x (n+1), so row m and column n are valid
    setp.le.u32 %p1, %r7, {max_row};
    setp.le.u32 %p2, %r9, {max_col};
    @!%p1 bra skip_compute;
    @!%p2 bra skip_compute;

    // Load parameters
    ld.param.u64 %rd1, [query_ptr];
    ld.param.u64 %rd2, [subject_ptr];
    ld.param.u64 %rd3, [dp_ptr];
    ld.param.u32 %r10, [subject_len];
    ld.param.s32 %r11, [gap_penalty];
    cvta.to.global.u64 %rd4, %rd1;
    cvta.to.global.u64 %rd5, %rd2;
    cvta.to.global.u64 %rd6, %rd3;

    // Address of dp[row][col]: 4-byte cells, row pitch of subject_len + 1 cells
    add.u32 %r12, %r10, 1;
    mul.wide.u32 %rd7, %r7, %r12;     // row * pitch, 64-bit
    cvt.u64.u32 %rd8, %r9;
    add.u64 %rd7, %rd7, %rd8;         // linear cell index
    shl.b64 %rd7, %rd7, 2;            // byte offset
    add.u64 %rd9, %rd6, %rd7;

    // Local alignment border: row 0 and column 0 are 0
    mov.s32 %r21, 0;
    setp.eq.u32 %p3, %r7, 0;
    setp.eq.u32 %p4, %r9, 0;
    or.pred %p5, %p3, %p4;
    @%p5 bra store_score;

    // Substitution score for query[row-1] vs subject[col-1]: match=2, mismatch=-1
    sub.u32 %r13, %r7, 1;
    sub.u32 %r14, %r9, 1;
    cvt.u64.u32 %rd10, %r13;
    add.u64 %rd10, %rd4, %rd10;
    ld.global.u8 %r15, [%rd10];
    cvt.u64.u32 %rd11, %r14;
    add.u64 %rd11, %rd5, %rd11;
    ld.global.u8 %r16, [%rd11];
    setp.eq.u32 %p6, %r15, %r16;
    selp.s32 %r17, 2, -1, %p6;

    // Neighbour cells, addressed relative to dp[row][col]
    mul.wide.u32 %rd12, %r12, 4;      // row pitch in bytes
    sub.u64 %rd13, %rd9, %rd12;       // address of dp[row-1][col]
    ld.global.s32 %r18, [%rd13+-4];   // diag:  dp[row-1][col-1]
    ld.global.s32 %r19, [%rd13];      // vert:  dp[row-1][col]
    ld.global.s32 %r20, [%rd9+-4];    // horiz: dp[row][col-1]

    // dp[row][col] = max(0, diag + score, vert + gap, horiz + gap)
    add.s32 %r18, %r18, %r17;
    add.s32 %r19, %r19, %r11;
    add.s32 %r20, %r20, %r11;
    max.s32 %r21, %r18, %r19;
    max.s32 %r21, %r21, %r20;
    max.s32 %r21, %r21, 0;

store_score:
    st.global.s32 [%rd9], %r21;
skip_compute:
    ret;
}}
"#,
        max_row = query_len,
        max_col = subject_len
    )
}
#[cfg(feature = "cuda")]
/// Builds the PTX source of `needleman_wunsch_kernel` (global alignment).
///
/// `query_len` (m) and `subject_len` (n) are baked into the kernel text: both
/// as the inclusive bounds of the thread-to-cell mapping and, for
/// `subject_len`, as the row pitch (`n + 1` cells) of the DP matrix. They are
/// emitted verbatim into 32-bit instructions, so values above `u32::MAX` do
/// not fit.
///
/// Generated kernel, one thread per DP cell in 16x16 tiles:
/// * `row = ctaid.x * 16 + tid.x`, `col = ctaid.y * 16 + tid.y`; threads with
///   `row > m` or `col > n` exit.
/// * `dp[0][col] = col * gap_penalty` and `dp[row][0] = row * gap_penalty`.
/// * Interior cells store `max(diag + s, up + gap_penalty, left + gap_penalty)`
///   with `s = 2` when `query[row - 1] == subject[col - 1]` and `-1`
///   otherwise. Unlike the Smith-Waterman kernel there is no clamp at zero.
///
/// Fixes relative to the previous template, which could not be assembled:
/// * `%bid` -> `%ctaid`, `.i32` -> `.s32`, `mul.i32` -> `mul.lo.s32`,
///   `@(!%p)` -> `@!%p`; predicated `{ ... }` blocks became branches.
/// * Predicate registers are declared and the `%r` / `%rd` banks cover every
///   register that is used.
/// * Loads and stores use 64-bit addresses; DP offsets are scaled by the
///   4-byte cell size.
/// * The bounds check is inclusive, so row m and column n (which hold the
///   final score of a global alignment) are written.
/// * The interior recurrence is implemented; it previously stored only the
///   match/mismatch score.
///
/// NOTE(review): each thread reads its three neighbour cells without any
/// ordering between threads. The result is only meaningful if the launcher
/// guarantees those neighbours are already final (for example by relaunching
/// per anti-diagonal) — confirm against the host code.
pub fn compile_nw_kernel(query_len: usize, subject_len: usize) -> String {
    format!(
        r#"
.version 8.0
.target sm_80
.address_size 64

.visible .entry needleman_wunsch_kernel(
    .param .u64 query_ptr,
    .param .u64 subject_ptr,
    .param .u64 dp_ptr,
    .param .u32 query_len,
    .param .u32 subject_len,
    .param .s32 gap_penalty
)
{{
    // Similar to SW but:
    // 1. Initialize first row/col with cumulative penalties
    // 2. Keep negative scores (not max with 0)
    .reg .pred %p<8>;
    .reg .b64 %rd<16>;
    .reg .b32 %r<24>;

    mov.u32 %r1, %tid.x;
    mov.u32 %r2, %tid.y;
    mov.u32 %r3, %ctaid.x;
    mov.u32 %r4, %ctaid.y;
    mov.u32 %r5, 16;
    mul.lo.u32 %r6, %r3, %r5;
    add.u32 %r7, %r6, %r1;            // row = ctaid.x*16 + tid.x
    mul.lo.u32 %r8, %r4, %r5;
    add.u32 %r9, %r8, %r2;            // col = ctaid.y*16 + tid.y

    // Bounds check: the DP matrix is (m+1) x (n+1), so row m and column n are valid
    setp.le.u32 %p1, %r7, {max_row};
    setp.le.u32 %p2, %r9, {max_col};
    @!%p1 bra skip_nw;
    @!%p2 bra skip_nw;

    ld.param.u64 %rd1, [query_ptr];
    ld.param.u64 %rd2, [subject_ptr];
    ld.param.u64 %rd3, [dp_ptr];
    ld.param.s32 %r11, [gap_penalty];
    cvta.to.global.u64 %rd4, %rd1;
    cvta.to.global.u64 %rd5, %rd2;
    cvta.to.global.u64 %rd6, %rd3;

    // Address of dp[row][col]: 4-byte cells, row pitch of subject_len + 1 cells
    mov.u32 %r10, {max_col};
    add.u32 %r12, %r10, 1;
    mul.wide.u32 %rd7, %r7, %r12;     // row * pitch, 64-bit
    cvt.u64.u32 %rd8, %r9;
    add.u64 %rd7, %rd7, %rd8;         // linear cell index
    shl.b64 %rd7, %rd7, 2;            // byte offset
    add.u64 %rd9, %rd6, %rd7;

    // Global alignment border: cumulative gap penalties
    setp.eq.u32 %p3, %r7, 0;
    mul.lo.s32 %r21, %r9, %r11;       // dp[0][col] = col * gap_penalty
    @%p3 bra store_result;
    setp.eq.u32 %p4, %r9, 0;
    mul.lo.s32 %r21, %r7, %r11;       // dp[row][0] = row * gap_penalty
    @%p4 bra store_result;

    // Substitution score for query[row-1] vs subject[col-1]: match=2, mismatch=-1
    sub.u32 %r13, %r7, 1;
    sub.u32 %r14, %r9, 1;
    cvt.u64.u32 %rd10, %r13;
    add.u64 %rd10, %rd4, %rd10;
    ld.global.u8 %r15, [%rd10];
    cvt.u64.u32 %rd11, %r14;
    add.u64 %rd11, %rd5, %rd11;
    ld.global.u8 %r16, [%rd11];
    setp.eq.u32 %p5, %r15, %r16;
    selp.s32 %r17, 2, -1, %p5;

    // Neighbour cells, addressed relative to dp[row][col]
    mul.wide.u32 %rd12, %r12, 4;      // row pitch in bytes
    sub.u64 %rd13, %rd9, %rd12;       // address of dp[row-1][col]
    ld.global.s32 %r18, [%rd13+-4];   // diag:  dp[row-1][col-1]
    ld.global.s32 %r19, [%rd13];      // vert:  dp[row-1][col]
    ld.global.s32 %r20, [%rd9+-4];    // horiz: dp[row][col-1]

    // dp[row][col] = max(diag + score, vert + gap, horiz + gap), no clamp at 0
    add.s32 %r18, %r18, %r17;
    add.s32 %r19, %r19, %r11;
    add.s32 %r20, %r20, %r11;
    max.s32 %r21, %r18, %r19;
    max.s32 %r21, %r21, %r20;

store_result:
    st.global.s32 [%rd9], %r21;
skip_nw:
    ret;
}}
"#,
        max_row = query_len,
        max_col = subject_len
    )
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Smith-Waterman PTX carries its entry-point name, ISA version and target.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_smith_waterman_cuda_kernel_compilation() {
        let ptx = SmithWatermanCudaKernel::compile_sw_kernel(100, 100);
        let expected = ["smith_waterman_kernel", ".version 8.0", "sm_80"];
        assert!(expected.iter().all(|marker| ptx.contains(*marker)));
    }

    /// The Needleman-Wunsch PTX carries its entry-point name and ISA version.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_needleman_wunsch_cuda_kernel_compilation() {
        let ptx = SmithWatermanCudaKernel::compile_nw_kernel(50, 50);
        let expected = ["needleman_wunsch_kernel", ".version 8.0"];
        assert!(expected.iter().all(|marker| ptx.contains(*marker)));
    }

    /// The query-staging PTX is produced without error and names its entry point.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_query_load_kernel_compilation() -> Result<()> {
        let ptx = SmithWatermanCudaKernel::compile_query(b"ACGTACGTACGT")?;
        assert!(ptx.contains("load_query"));
        Ok(())
    }

    /// Placeholder so the test module is non-empty when the `cuda` feature is off.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_kernel_stub() {}
}