use crate::error::KkError;
use crate::kk_mix::{KkSponge, KkState, KDF_SQUEEZE_ROUNDS, RATE_BYTES, RATE_WORDS, STATE_WORDS};
use zeroize::Zeroize;
// Raw bindings to the native CUDA backend.
//
// Only the conventions this file visibly relies on are recorded here; the
// authoritative contract lives with the C/CUDA implementation.
// NOTE(review): buffer ownership, thread-safety and exact return-code meanings
// are assumptions inferred from the call sites below — confirm against the shim.
extern "C" {
/// Availability probe. `CudaAccelerator::new` treats a return value of
/// exactly `1` as "a usable GPU is present" and anything else as absent.
fn kk_cuda_is_available() -> i32;
/// Writes the device name into `buf`, which the caller sizes as `buf_len`
/// bytes. Callers in this file treat `0` as success and expect the name to be
/// NUL-terminated/NUL-padded (presumably a C string — TODO confirm).
fn kk_cuda_get_device_name(buf: *mut u8, buf_len: i32) -> i32;
/// Permutes `num_states` states stored back to back in `host_states`.
/// Callers read the results back from the same buffer, treat `0` as success,
/// and pass `host_rotations` as 30 `u32` values (15 pairs, flattened).
fn kk_cuda_permute_batch(
host_states: *mut u64,
host_rotations: *const u32,
rounds: u32,
num_states: u32,
) -> i32;
/// Same signature and call-site conventions as `kk_cuda_permute_batch`.
/// Presumably keeps device-side buffers alive between calls, to be released
/// by `kk_cuda_free_persistent` — TODO confirm.
fn kk_cuda_permute_batch_persistent(
host_states: *mut u64,
host_rotations: *const u32,
rounds: u32,
num_states: u32,
) -> i32;
/// Release hook for the persistent variant; invoked from
/// `CudaAccelerator::free_persistent` and from the `Drop` implementation.
fn kk_cuda_free_persistent();
}
/// Safe handle to the CUDA permutation backend.
///
/// Constructed via `CudaAccelerator::new`, which fails when the backend does
/// not report a usable GPU.
pub struct CudaAccelerator {
/// Device name captured from `kk_cuda_get_device_name` at construction time.
device_name: String,
}
impl CudaAccelerator {
pub fn new() -> Result<Self, KkError> {
let available = unsafe { kk_cuda_is_available() };
if available != 1 {
return Err(KkError::GpuError("No CUDA-capable GPU found".into()));
}
let mut buf = [0u8; 256];
let rc = unsafe { kk_cuda_get_device_name(buf.as_mut_ptr(), 256) };
if rc != 0 {
return Err(KkError::GpuError("Failed to query CUDA device name".into()));
}
let name = std::str::from_utf8(&buf)
.unwrap_or("unknown")
.trim_end_matches('\0')
.to_string();
Ok(Self { device_name: name })
}
pub fn device_name(&self) -> &str {
&self.device_name
}
pub fn permute_batch(&self, states: &mut [KkState], rotations: &[[u32; 2]; 15], rounds: u32) {
if states.is_empty() {
return;
}
let n = states.len();
let mut flat: Vec<u64> = Vec::with_capacity(n * STATE_WORDS);
for s in states.iter() {
flat.extend_from_slice(s);
}
let mut rot_flat = [0u32; 30];
for (i, pair) in rotations.iter().enumerate() {
rot_flat[i * 2] = pair[0];
rot_flat[i * 2 + 1] = pair[1];
}
let rc = unsafe {
kk_cuda_permute_batch(flat.as_mut_ptr(), rot_flat.as_ptr(), rounds, n as u32)
};
if rc != 0 {
return;
}
for (i, s) in states.iter_mut().enumerate() {
s.copy_from_slice(&flat[i * STATE_WORDS..(i + 1) * STATE_WORDS]);
}
flat.zeroize();
}
pub fn permute_batch_persistent(
&self,
states: &mut [KkState],
rotations: &[[u32; 2]; 15],
rounds: u32,
) {
if states.is_empty() {
return;
}
let n = states.len();
let mut flat: Vec<u64> = Vec::with_capacity(n * STATE_WORDS);
for s in states.iter() {
flat.extend_from_slice(s);
}
let mut rot_flat = [0u32; 30];
for (i, pair) in rotations.iter().enumerate() {
rot_flat[i * 2] = pair[0];
rot_flat[i * 2 + 1] = pair[1];
}
let rc = unsafe {
kk_cuda_permute_batch_persistent(flat.as_mut_ptr(), rot_flat.as_ptr(), rounds, n as u32)
};
if rc != 0 {
return;
}
for (i, s) in states.iter_mut().enumerate() {
s.copy_from_slice(&flat[i * STATE_WORDS..(i + 1) * STATE_WORDS]);
}
flat.zeroize();
}
pub fn free_persistent(&self) {
unsafe { kk_cuda_free_persistent() };
}
pub fn kk_kdf_batch(
&self,
key: &[u8],
salt: &[u8],
infos: &[&[u8]],
output_len: usize,
) -> Vec<Vec<u8>> {
if infos.is_empty() {
return Vec::new();
}
let n = infos.len();
let mut shared = KkSponge::with_entropy_rotations(salt);
shared.absorb(key);
shared.absorb(&(salt.len() as u64).to_le_bytes());
shared.absorb(salt);
let mut sponges: Vec<KkSponge> = (0..n).map(|_| shared.clone()).collect();
drop(shared);
for i in 0..n {
sponges[i].absorb(&(infos[i].len() as u64).to_le_bytes());
sponges[i].absorb(infos[i]);
sponges[i].finalize_absorb_kdf();
}
let rotations = sponges[0].rotations();
let mut raw_states: Vec<KkState> = sponges.iter().map(|s| s.state()).collect();
drop(sponges);
let mut outputs: Vec<Vec<u8>> = (0..n).map(|_| Vec::with_capacity(output_len)).collect();
loop {
for (i, state) in raw_states.iter().enumerate() {
let remaining = output_len - outputs[i].len();
let take = remaining.min(RATE_BYTES);
let rate = rate_bytes_from_state(state);
outputs[i].extend_from_slice(&rate[..take]);
}
if outputs[0].len() >= output_len {
break;
}
self.permute_batch_persistent(&mut raw_states, &rotations, KDF_SQUEEZE_ROUNDS as u32);
}
raw_states.zeroize();
outputs
}
}
impl Drop for CudaAccelerator {
fn drop(&mut self) {
unsafe { kk_cuda_free_persistent() };
}
}
/// Serialises the first `RATE_WORDS` words of `state` as little-endian bytes.
///
/// The result is `RATE_BYTES` long; any bytes beyond `RATE_WORDS * 8` stay
/// zero.
fn rate_bytes_from_state(state: &KkState) -> [u8; RATE_BYTES] {
    let mut bytes = [0u8; RATE_BYTES];
    // Pair each 8-byte output window with the state word it encodes.
    let windows = bytes[..RATE_WORDS * 8].chunks_exact_mut(8);
    for (window, word) in windows.zip(&state[..RATE_WORDS]) {
        window.copy_from_slice(&word.to_le_bytes());
    }
    bytes
}