use crate::protein::AminoAcid;
use crate::scoring::ScoringMatrix;
use crate::error::Result;
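/// A single GPU visible to the process. `compute_capability` is the
/// backend's version string (e.g. "8.6" for CUDA) and `total_memory` is in
/// bytes.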
#[derive(Debug, Clone)]
pub struct GpuDevice {
pub id: u32,
pub name: String,
pub compute_capability: String,
pub total_memory: u64,
pub backend: GpuBackend,
}
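/// The compute API used to drive a device.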
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GpuBackend {
Cuda,
Hip,
Vulkan,
}
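/// Tuning knobs for GPU alignment: which device to use, how many pairs per
/// batch, a device-memory budget in bytes, and toggles for pooling,
/// prefetching, and (off by default) payload compression.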
#[derive(Debug, Clone)]
pub struct GpuAlignConfig {
pub device: u32,
pub batch_size: usize,
pub max_memory: u64,
pub enable_memory_pool: bool,
pub enable_prefetch: bool,
pub enable_compression: bool,
}
impl Default for GpuAlignConfig {
fn default() -> Self {
Self {
device: 0,
batch_size: 128,
            max_memory: 8 * 1024 * 1024 * 1024,
            enable_memory_pool: true,
enable_prefetch: true,
enable_compression: false,
}
}
}
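/// Output of one pairwise alignment: the full DP matrix, the coordinates and
/// value of the best-scoring cell, and timings that separate kernel time
/// from host/device transfers.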
#[derive(Debug, Clone)]
pub struct GpuAlignmentResult {
pub matrix: Vec<Vec<i32>>,
pub max_i: usize,
pub max_j: usize,
pub max_score: i32,
pub gpu_time_ms: f32,
pub transfer_time_ms: f32,
}
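/// Backend-agnostic kernel interface, implemented once per runtime (CUDA,
/// HIP, Vulkan). `allocate`/`free`/`h2d`/`d2h` expose the raw
/// allocate -> copy-in -> launch -> copy-out lifecycle; callers must uphold
/// the usual raw-pointer invariants (valid, correctly sized, not aliased).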
pub trait GpuAlignmentKernel {
fn init(&mut self, config: &GpuAlignConfig) -> Result<()>;
fn allocate(&self, size: usize) -> Result<*mut u8>;
fn free(&self, ptr: *mut u8) -> Result<()>;
fn h2d(&self, host_ptr: *const u8, device_ptr: *mut u8, size: usize) -> Result<()>;
fn d2h(&self, device_ptr: *const u8, host_ptr: *mut u8, size: usize) -> Result<()>;
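    /// Local (Smith-Waterman) alignment of `seq1` against `seq2` with affine
    /// gaps (`open_penalty` to start a gap, `extend_penalty` per extension).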
fn smith_waterman(
&self,
seq1: &[AminoAcid],
seq2: &[AminoAcid],
matrix: &ScoringMatrix,
open_penalty: i32,
extend_penalty: i32,
) -> Result<GpuAlignmentResult>;
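    /// Global (Needleman-Wunsch) alignment with the same affine-gap scheme.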
fn needleman_wunsch(
&self,
seq1: &[AminoAcid],
seq2: &[AminoAcid],
matrix: &ScoringMatrix,
open_penalty: i32,
extend_penalty: i32,
) -> Result<GpuAlignmentResult>;
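    /// Aligns every query against one shared subject in a single batched
    /// launch, returning one result per query.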
fn batch_align(
&self,
queries: &[Vec<AminoAcid>],
subject: &[AminoAcid],
matrix: &ScoringMatrix,
open_penalty: i32,
extend_penalty: i32,
) -> Result<Vec<GpuAlignmentResult>>;
}
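/// Size-bucketed buffer recycler: released buffers go into a per-size free
/// list and are handed back on the next `acquire` of the same size, skipping
/// allocator round-trips. This version sits on the host allocator; a device
/// backend would route `acquire`/`release` through its own alloc/free calls.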
#[derive(Debug)]
#[allow(dead_code)]
pub struct GpuMemoryPool {
pools: std::collections::HashMap<usize, Vec<*mut u8>>,
total_allocated: u64,
max_size: u64,
}
impl GpuMemoryPool {
pub fn new(max_size: u64) -> Self {
Self {
pools: std::collections::HashMap::new(),
total_allocated: 0,
max_size,
}
}
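    /// Hands out a 64-byte-aligned buffer of `size` bytes, reusing a pooled
    /// one when the size class has a free entry.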
    pub fn acquire(&mut self, size: usize) -> Result<*mut u8> {
        // Reuse a pooled buffer from the matching size class when available.
        if let Some(buffer) = self.pools.get_mut(&size).and_then(Vec::pop) {
            return Ok(buffer);
        }
        if size == 0 || self.total_allocated + size as u64 > self.max_size {
            return Err(crate::error::Error::Custom(format!(
                "memory pool cannot satisfy an allocation of {size} bytes"
            )));
        }
        let layout = std::alloc::Layout::from_size_align(size, 64)
            .map_err(|e| crate::error::Error::Custom(e.to_string()))?;
        // `alloc` returns null on failure; surface that as an error instead
        // of handing a null pointer back to the caller.
        let ptr = unsafe { std::alloc::alloc(layout) };
        if ptr.is_null() {
            return Err(crate::error::Error::Custom("allocation failed".into()));
        }
        self.total_allocated += size as u64;
        Ok(ptr)
    }
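    /// Returns `buffer` to its size class for reuse; nothing is deallocated
    /// until `clear` runs.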
    pub fn release(&mut self, size: usize, buffer: *mut u8) {
        self.pools.entry(size).or_default().push(buffer);
    }
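    /// Deallocates every buffer currently sitting in the free lists.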
    pub fn clear(&mut self) {
        // Only buffers that have been returned to the pool are freed here;
        // buffers still held by callers stay counted in `total_allocated`.
        for (size, buffers) in self.pools.iter_mut() {
            for buffer in buffers.drain(..) {
                unsafe {
                    std::alloc::dealloc(
                        buffer,
                        std::alloc::Layout::from_size_align_unchecked(*size, 64),
                    );
                }
                self.total_allocated = self.total_allocated.saturating_sub(*size as u64);
            }
        }
    }
}
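/// Detected devices plus the currently selected one. `kernels` holds
/// type-erased per-backend kernel handles.
///
/// A minimal usage sketch (assumes a backend feature is enabled and at least
/// one device is present):
///
/// ```ignore
/// let mut ctx = MultiGpuContext::detect()?;
/// ctx.select_device(0)?;
/// for (device, offset, len) in ctx.distribute_batch(queries.len()) {
///     // dispatch queries[offset..offset + len] to `device`
/// }
/// ```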
#[derive(Debug)]
pub struct MultiGpuContext {
pub devices: Vec<GpuDevice>,
pub active_device: u32,
pub kernels: Vec<Box<dyn std::any::Any>>,
}
impl MultiGpuContext {
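    /// Probes every backend compiled in via Cargo features and merges all
    /// devices found into one list.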
    pub fn detect() -> Result<Self> {
        // Accumulate devices across backends rather than rebinding `devices`
        // per feature, which let the last enabled backend shadow the others.
        #[allow(unused_mut)]
        let mut devices: Vec<GpuDevice> = Vec::new();
        #[cfg(feature = "cuda")]
        devices.extend(Self::detect_cuda()?);
        #[cfg(feature = "hip")]
        devices.extend(Self::detect_hip()?);
        #[cfg(feature = "vulkan")]
        devices.extend(Self::detect_vulkan()?);
        Ok(Self {
            devices,
            active_device: 0,
            kernels: vec![],
        })
    }
#[cfg(feature = "cuda")]
fn detect_cuda() -> Result<Vec<GpuDevice>> {
Ok(vec![])
}
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
fn detect_cuda() -> Result<Vec<GpuDevice>> {
Ok(vec![])
}
#[cfg(feature = "hip")]
fn detect_hip() -> Result<Vec<GpuDevice>> {
Ok(vec![])
}
#[cfg(not(feature = "hip"))]
#[allow(dead_code)]
fn detect_hip() -> Result<Vec<GpuDevice>> {
Ok(vec![])
}
#[cfg(feature = "vulkan")]
fn detect_vulkan() -> Result<Vec<GpuDevice>> {
Ok(vec![])
}
#[cfg(not(feature = "vulkan"))]
#[allow(dead_code)]
fn detect_vulkan() -> Result<Vec<GpuDevice>> {
Ok(vec![])
}
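    /// Makes `device_id` the active target for subsequent launches, erroring
    /// if it is out of range.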
pub fn select_device(&mut self, device_id: u32) -> Result<()> {
if (device_id as usize) >= self.devices.len() {
return Err(crate::error::Error::Custom(format!(
"Device {} not found (only {} available)",
device_id,
self.devices.len()
)));
}
self.active_device = device_id;
Ok(())
}
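    /// Lists devices as `(id, name, total_memory_in_bytes)` tuples.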
pub fn list_devices(&self) -> Vec<(u32, String, u64)> {
self.devices
.iter()
.enumerate()
.map(|(id, dev)| (id as u32, dev.name.clone(), dev.total_memory))
.collect()
}
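    /// Splits `batch_size` items into contiguous `(device, offset, len)`
    /// chunks via ceiling division, so earlier devices absorb any remainder:
    /// 10 items over 2 devices gives `(0, 0, 5)` and `(1, 5, 5)`.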
    pub fn distribute_batch(&self, batch_size: usize) -> Vec<(u32, usize, usize)> {
        if self.devices.is_empty() {
            return vec![];
        }
        let mut distribution = vec![];
        let mut offset = 0;
        let items_per_device = batch_size.div_ceil(self.devices.len());
        for (device_idx, _) in self.devices.iter().enumerate() {
            let size = std::cmp::min(items_per_device, batch_size - offset);
            if size > 0 {
                distribution.push((device_idx as u32, offset, size));
                offset += size;
            }
        }
        distribution
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gpu_config_default() {
let config = GpuAlignConfig::default();
assert_eq!(config.device, 0);
assert_eq!(config.batch_size, 128);
assert!(config.enable_memory_pool);
assert!(config.enable_prefetch);
}
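    // Sketch of the intended configuration pattern: override one field and
    // keep the remaining defaults with struct-update syntax.
    #[test]
    fn test_gpu_config_override() {
        let config = GpuAlignConfig {
            batch_size: 256,
            ..GpuAlignConfig::default()
        };
        assert_eq!(config.batch_size, 256);
        assert_eq!(config.device, 0);
    }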
    #[test]
    fn test_memory_pool() {
        let mut pool = GpuMemoryPool::new(1024 * 1024);
        let size = 1024;
        // A fresh acquisition bypasses the (still empty) free list.
        let ptr0 = pool.acquire(size).unwrap();
        let buf_count_1 = pool.pools.get(&size).map(|v| v.len()).unwrap_or(0);
        assert_eq!(buf_count_1, 0);
        let ptr1 = pool.acquire(size).unwrap();
        pool.release(size, ptr1);
        let buf_count_2 = pool.pools.get(&size).map(|v| v.len()).unwrap_or(0);
        assert_eq!(buf_count_2, 1);
        // Return the first buffer too, so `clear` can deallocate both
        // instead of leaking it.
        pool.release(size, ptr0);
        pool.clear();
    }
#[test]
fn test_multi_gpu_distribution() {
let context = MultiGpuContext {
devices: vec![
GpuDevice {
id: 0,
name: "GPU-0".to_string(),
compute_capability: "8.6".to_string(),
total_memory: 24 * 1024 * 1024 * 1024,
backend: GpuBackend::Cuda,
},
GpuDevice {
id: 1,
name: "GPU-1".to_string(),
compute_capability: "8.6".to_string(),
total_memory: 24 * 1024 * 1024 * 1024,
backend: GpuBackend::Cuda,
},
],
active_device: 0,
kernels: vec![],
};
let dist = context.distribute_batch(10);
assert_eq!(dist.len(), 2);
        assert_eq!(dist[0].0, 0);
        assert_eq!(dist[1].0, 1);
    }
}