use crate::error::KkError;
use crate::kk_mix::{KkSponge, KkState, KDF_SQUEEZE_ROUNDS, RATE_BYTES, RATE_WORDS, STATE_WORDS};
use wgpu::util::DeviceExt;
use zeroize::Zeroize;
/// GPU-backed executor for the `kk_permute` compute shader.
///
/// Holds the wgpu device/queue pair together with the compute pipeline and the
/// bind group layout that every dispatch reuses.
pub struct GpuAccelerator {
// Logical device used to create buffers, bind groups and command encoders.
device: wgpu::Device,
// Queue that the encoded compute work is submitted to.
queue: wgpu::Queue,
// Compute pipeline built from `kk_permute.wgsl` (entry point `kk_permute_kernel`).
pipeline: wgpu::ComputePipeline,
// Layout with three buffer bindings: 0 = read/write storage, 1 = read-only
// storage, 2 = uniform.
bind_group_layout: wgpu::BindGroupLayout,
// Name reported by the adapter that was selected at construction time.
adapter_name: String,
}
/// Per-dispatch parameters uploaded as the uniform buffer at binding 2.
///
/// `#[repr(C)]` plus the `Pod`/`Zeroable` derives allow the struct to be viewed
/// as raw bytes for upload.
// NOTE(review): the field order presumably has to match the uniform struct
// declared in `kk_permute.wgsl` — confirm against the shader.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuParams {
// Round count passed through to the shader.
rounds: u32,
// Number of states contained in the batch.
num_states: u32,
}
impl GpuAccelerator {
/// Creates an accelerator, blocking the calling thread until GPU setup completes.
///
/// # Errors
///
/// Returns the [`KkError`] produced by the asynchronous initialisation, i.e.
/// when no adapter is found or the device request fails.
pub fn new() -> Result<Self, KkError> {
    let setup = Self::new_async();
    pollster::block_on(setup)
}
/// Acquires a GPU device and builds the `kk_permute` compute pipeline.
///
/// Steps, in order: create an instance over all backends, request a
/// high-performance adapter, request a device with 256 MiB buffer limits,
/// compile `kk_permute.wgsl`, and assemble the bind group layout, pipeline
/// layout and compute pipeline.
///
/// # Errors
///
/// Returns [`KkError::GpuError`] when no adapter is available or when the
/// device request is rejected.
async fn new_async() -> Result<Self, KkError> {
    let instance_desc = wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    };
    let instance = wgpu::Instance::new(&instance_desc);

    let adapter_options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    };
    let Some(adapter) = instance.request_adapter(&adapter_options).await else {
        return Err(KkError::GpuError("no GPU adapter found".into()));
    };
    let adapter_name = adapter.get_info().name;

    // Both buffer-related limits are raised to 256 MiB; everything else stays
    // at the wgpu defaults.
    let device_desc = wgpu::DeviceDescriptor {
        label: Some("kk-crypto-gpu"),
        required_features: wgpu::Features::empty(),
        required_limits: wgpu::Limits {
            max_storage_buffer_binding_size: 256 * 1024 * 1024,
            max_buffer_size: 256 * 1024 * 1024,
            ..wgpu::Limits::default()
        },
        memory_hints: wgpu::MemoryHints::Performance,
    };
    let (device, queue) = match adapter.request_device(&device_desc, None).await {
        Ok(pair) => pair,
        Err(e) => {
            return Err(KkError::GpuError(format!("device request failed: {e}")));
        }
    };

    let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("kk_permute"),
        source: wgpu::ShaderSource::Wgsl(include_str!("kk_permute.wgsl").into()),
    });

    // All three bindings are plain compute-stage buffers that differ only in
    // their binding index and buffer kind.
    let buffer_entry = |binding: u32, ty: wgpu::BufferBindingType| wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    };
    let layout_entries = [
        buffer_entry(0, wgpu::BufferBindingType::Storage { read_only: false }),
        buffer_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
        buffer_entry(2, wgpu::BufferBindingType::Uniform),
    ];
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("kk_permute_bgl"),
        entries: &layout_entries,
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("kk_permute_pl"),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    let pipeline_desc = wgpu::ComputePipelineDescriptor {
        label: Some("kk_permute_pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader_module,
        entry_point: Some("kk_permute_kernel"),
        compilation_options: Default::default(),
        cache: None,
    };
    let pipeline = device.create_compute_pipeline(&pipeline_desc);

    Ok(Self {
        device,
        queue,
        pipeline,
        bind_group_layout,
        adapter_name,
    })
}
/// Returns the name reported by the GPU adapter chosen at construction time.
pub fn device_name(&self) -> &str {
    self.adapter_name.as_str()
}
/// Runs the permutation on every state in `states`, in place, on the GPU.
///
/// Blocks the calling thread until the results have been read back. An empty
/// batch returns immediately without touching the GPU.
///
/// # Panics
///
/// Panics if the GPU readback performed by the asynchronous implementation fails.
pub fn permute_batch(&self, states: &mut [KkState], rotations: &[[u32; 2]; 15], rounds: usize) {
    if !states.is_empty() {
        let job = self.permute_batch_async(states, rotations, rounds);
        pollster::block_on(job);
    }
}
/// Uploads `states` to the GPU, runs the permutation kernel and writes the
/// results back into `states`.
///
/// Each 64-bit state word travels as two `u32` values (low half first). The
/// rotation table is flattened into 30 `u32` values and bound as read-only
/// storage; `rounds` and the batch size go into the uniform buffer.
///
/// # Panics
///
/// Panics if `rounds` or the batch size does not fit in a `u32`, if the
/// readback channel is closed, or if mapping the readback buffer fails.
async fn permute_batch_async(
    &self,
    states: &mut [KkState],
    rotations: &[[u32; 2]; 15],
    rounds: usize,
) {
    // NOTE(review): presumably has to equal the `@workgroup_size` declared in
    // `kk_permute.wgsl` — confirm against the shader.
    const WORKGROUP_SIZE: u32 = 64;

    // Nothing to permute.
    if states.is_empty() {
        return;
    }
    let n = states.len();

    // Checked conversions: a silently truncated batch size would leave part
    // of the batch unpermuted without any error.
    let params = GpuParams {
        rounds: u32::try_from(rounds).expect("round count must fit in u32"),
        num_states: u32::try_from(n).expect("batch size must fit in u32"),
    };

    // Split every 64-bit word into (low, high) u32 halves for the shader.
    let mut state_data: Vec<u32> = Vec::with_capacity(n * STATE_WORDS * 2);
    for state in states.iter() {
        for &word in state.iter() {
            state_data.push(word as u32);
            state_data.push((word >> 32) as u32);
        }
    }
    // Captured now because zeroizing the staging vector below also clears it.
    let byte_len = (state_data.len() * std::mem::size_of::<u32>()) as u64;

    // Flatten the 15 rotation pairs into 30 consecutive u32 values.
    let mut rot_data = [0u32; 30];
    for (slot, pair) in rot_data.chunks_exact_mut(2).zip(rotations.iter()) {
        slot.copy_from_slice(pair);
    }

    let state_buf = self
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("states"),
            contents: bytemuck::cast_slice(state_data.as_slice()),
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        });
    // The buffer was initialised from `state_data`, so the CPU-side staging
    // copy is no longer needed. Wipe it now so that a panic further down
    // cannot leave it behind.
    state_data.zeroize();

    let rot_buf = self
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("rotations"),
            contents: bytemuck::cast_slice(&rot_data),
            usage: wgpu::BufferUsages::STORAGE,
        });
    let params_buf = self
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });
    let readback_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("readback"),
        size: byte_len,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("kk_bg"),
        layout: &self.bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: state_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: rot_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: params_buf.as_entire_binding(),
            },
        ],
    });

    // Enough workgroups to cover the whole batch.
    let num_workgroups = params.num_states.div_ceil(WORKGROUP_SIZE);

    let mut encoder = self
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("kk_enc"),
        });
    {
        // Scoped so the pass ends before the encoder is used again.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("kk_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&self.pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(num_workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&state_buf, 0, &readback_buf, 0, byte_len);
    self.queue.submit(std::iter::once(encoder.finish()));

    // Request the mapping, then block on the device until the callback fires.
    let (tx, rx) = std::sync::mpsc::channel();
    readback_buf
        .slice(..)
        .map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });
    self.device.poll(wgpu::Maintain::Wait);
    rx.recv()
        .expect("GPU readback channel closed")
        .expect("GPU readback mapping failed");

    {
        // The mapped view must be dropped before `unmap`, hence the scope.
        let mapped = readback_buf.slice(..).get_mapped_range();
        let result_u32s: &[u32] = bytemuck::cast_slice(&mapped);
        let per_state = result_u32s.chunks_exact(STATE_WORDS * 2);
        for (state, halves) in states.iter_mut().zip(per_state) {
            for (w, pair) in halves.chunks_exact(2).enumerate() {
                // Reassemble each 64-bit word from its (low, high) halves.
                state[w] = u64::from(pair[0]) | (u64::from(pair[1]) << 32);
            }
        }
    }
    readback_buf.unmap();
}
/// Derives one `output_len`-byte output per entry of `infos`, sharing `key`
/// and `salt` across the whole batch.
///
/// A single sponge absorbs `key`, the salt length (little-endian `u64`) and
/// `salt`. It is then cloned once per info string, and each clone absorbs its
/// info length and info before being finalised. The squeeze phase reads the
/// rate portion of each state on the CPU and permutes all states together on
/// the GPU between blocks.
///
/// Outputs are returned in the same order as `infos`; an empty `infos` yields
/// an empty vector.
///
/// # Panics
///
/// Panics if the GPU permutation fails (see [`GpuAccelerator::permute_batch`]).
pub fn kk_kdf_batch(
    &self,
    key: &[u8],
    salt: &[u8],
    infos: &[&[u8]],
    output_len: usize,
) -> Vec<Vec<u8>> {
    let batch = infos.len();
    if batch == 0 {
        return Vec::new();
    }

    // Absorb the material common to every derivation exactly once.
    let mut prefix = KkSponge::with_entropy_rotations(salt);
    prefix.absorb(key);
    prefix.absorb(&(salt.len() as u64).to_le_bytes());
    prefix.absorb(salt);

    // Fork the shared prefix and finish each sponge with its own info string.
    let sponges: Vec<KkSponge> = infos
        .iter()
        .map(|&info| {
            let mut sponge = prefix.clone();
            sponge.absorb(&(info.len() as u64).to_le_bytes());
            sponge.absorb(info);
            sponge.finalize_absorb_kdf();
            sponge
        })
        .collect();
    drop(prefix);

    // The rotation table is taken from the first sponge and used for the
    // whole batch.
    let rotations = sponges[0].rotations();
    let mut raw_states: Vec<KkState> = sponges.iter().map(|sponge| sponge.state()).collect();
    drop(sponges);

    let mut outputs: Vec<Vec<u8>> = std::iter::repeat_with(|| Vec::with_capacity(output_len))
        .take(batch)
        .collect();

    // Every output grows in lockstep, so one counter tracks them all.
    let mut produced = 0usize;
    loop {
        let take = (output_len - produced).min(RATE_BYTES);
        for (out, state) in outputs.iter_mut().zip(raw_states.iter()) {
            let rate = rate_bytes_from_state(state);
            out.extend_from_slice(&rate[..take]);
        }
        produced += take;
        if produced >= output_len {
            break;
        }
        self.permute_batch(&mut raw_states, &rotations, KDF_SQUEEZE_ROUNDS);
    }

    raw_states.zeroize();
    outputs
}
}
/// Serialises the first `RATE_WORDS` words of `state` as little-endian bytes.
///
/// Word `i` occupies bytes `8 * i .. 8 * (i + 1)` of the returned array.
fn rate_bytes_from_state(state: &KkState) -> [u8; RATE_BYTES] {
    let mut out = [0u8; RATE_BYTES];
    let words = state.iter().take(RATE_WORDS);
    for (chunk, word) in out.chunks_exact_mut(8).zip(words) {
        chunk.copy_from_slice(&word.to_le_bytes());
    }
    out
}