mod activations;
mod backward;
mod eigen;
pub(crate) mod linalg;
mod reductions;
#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use super::runtime;
/// Handle pairing a wgpu logical device with its submission queue.
///
/// All GPU operations in this module are issued through these two handles.
// NOTE(review): `Clone` derives compile because wgpu's `Device`/`Queue` are
// clonable handles — presumably internally reference-counted, so clones are
// cheap; confirm against the wgpu version in Cargo.toml.
#[derive(Clone)]
pub struct GpuDevice {
    // Logical device: creates buffers, shaders, bind groups, and pipelines.
    pub device: wgpu::Device,
    // Command queue: submits encoded work and writes buffer contents.
    pub queue: wgpu::Queue,
}
impl GpuDevice {
/// Blocking constructor: acquires the highest-performance adapter available.
///
/// Native-only convenience wrapper around [`GpuDevice::new_async`].
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
pub fn new() -> Result<Self, String> {
    // Drive the async constructor to completion on the blocking runtime.
    runtime::block_on(Self::new_async())
}
/// Asynchronously creates a GPU device on the highest-performance adapter.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when no suitable adapter is
/// found or when device creation fails.
pub async fn new_async() -> Result<Self, String> {
    let instance = wgpu::Instance::default();
    let adapter = instance
        .request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        })
        .await
        .map_err(|e| format!("Failed to find GPU adapter: {}", e))?;
    // Query the adapter's limits once (previously queried twice) and raise
    // only the buffer-size limits to the adapter's maximums so larger
    // buffers can be allocated and bound; everything else stays at default.
    let adapter_limits = adapter.limits();
    let limits = wgpu::Limits {
        max_buffer_size: adapter_limits.max_buffer_size,
        max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
        ..wgpu::Limits::default()
    };
    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor {
            label: Some("Trueno GPU Device"),
            required_features: wgpu::Features::empty(),
            required_limits: limits,
            memory_hints: wgpu::MemoryHints::Performance,
            experimental_features: Default::default(),
            trace: Default::default(),
        })
        .await
        .map_err(|e| format!("Failed to create device: {}", e))?;
    Ok(Self { device, queue })
}
/// Blocking constructor that selects the adapter at `index` (see
/// [`GpuDevice::list_adapters`] for the available indices).
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
pub fn new_with_adapter_index(index: u32) -> Result<Self, String> {
    // Drive the async constructor to completion on the blocking runtime.
    runtime::block_on(Self::new_with_adapter_index_async(index))
}
/// Asynchronously creates a GPU device on the adapter at `index` within
/// the enumeration across all backends.
///
/// # Errors
///
/// Returns `Err` when no adapters exist, when `index` is out of range, or
/// when device creation fails.
pub async fn new_with_adapter_index_async(index: u32) -> Result<Self, String> {
    let instance = wgpu::Instance::default();
    let adapters = instance.enumerate_adapters(wgpu::Backends::all());
    // Distinguish "no adapters at all" from "index out of range" so the
    // caller gets a precise diagnostic.
    if adapters.is_empty() {
        return Err("No GPU adapters found".to_string());
    }
    let adapter = adapters
        .into_iter()
        .nth(index as usize)
        .ok_or_else(|| format!("GPU adapter index {} out of range", index))?;
    // Query the adapter's limits once (previously queried twice) and raise
    // only the buffer-size limits to the adapter's maximums; everything
    // else stays at default.
    let adapter_limits = adapter.limits();
    let limits = wgpu::Limits {
        max_buffer_size: adapter_limits.max_buffer_size,
        max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
        ..wgpu::Limits::default()
    };
    let (device, queue) = adapter
        .request_device(&wgpu::DeviceDescriptor {
            label: Some(&format!("Trueno GPU Device [{}]", index)),
            required_features: wgpu::Features::empty(),
            required_limits: limits,
            memory_hints: wgpu::MemoryHints::Performance,
            experimental_features: Default::default(),
            trace: Default::default(),
        })
        .await
        .map_err(|e| format!("Failed to create device at index {}: {}", index, e))?;
    Ok(Self { device, queue })
}
/// Blocking enumeration of available GPU adapters as
/// `(index, name, backend)` tuples.
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
pub fn list_adapters() -> Vec<(u32, String, String)> {
    // Synchronously drive the async enumeration to completion.
    let enumeration = Self::list_adapters_async();
    runtime::block_on(enumeration)
}
/// Enumerates GPU adapters across all backends, yielding one
/// `(index, adapter name, backend name)` tuple per adapter.
pub async fn list_adapters_async() -> Vec<(u32, String, String)> {
    let instance = wgpu::Instance::default();
    instance
        .enumerate_adapters(wgpu::Backends::all())
        .into_iter()
        .enumerate()
        .map(|(position, adapter)| {
            let details = adapter.get_info();
            // Backend has no Display impl here, so render it via Debug.
            (position as u32, details.name, format!("{:?}", details.backend))
        })
        .collect()
}
/// Blocking probe for GPU availability; `true` when an adapter can be
/// acquired.
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
pub fn is_available() -> bool {
    // Synchronously drive the async probe to completion.
    let probe = Self::is_available_async();
    runtime::block_on(probe)
}
/// Asynchronously checks whether a high-performance, non-fallback GPU
/// adapter can be acquired.
pub async fn is_available_async() -> bool {
    // Mirror the adapter options used by `new_async` so the probe result
    // matches what construction would actually attempt.
    let options = wgpu::RequestAdapterOptions {
        power_preference: wgpu::PowerPreference::HighPerformance,
        compatible_surface: None,
        force_fallback_adapter: false,
    };
    let instance = wgpu::Instance::default();
    instance.request_adapter(&options).await.is_ok()
}
/// Runs a single-input element-wise compute shader on the GPU and copies
/// the output back into `result`.
///
/// Builds a one-off pipeline from `shader_source` (WGSL), binds `input`
/// at binding 0 (read-only storage), the output at binding 1 (read-write
/// storage), and — when `uniform_data` is provided — a uniform buffer at
/// binding 2. Dispatches `ceil(input.len() / 256)` workgroups, then reads
/// the output through a MAP_READ staging buffer.
///
/// # Arguments
///
/// * `op_name` - Human-readable label used in all GPU resource labels.
/// * `shader_source` - WGSL source with a compute entry point named `main`.
/// * `input` - Input values uploaded to the read-only storage buffer.
/// * `result` - Output slice; overwritten with the shader's output.
///   NOTE(review): assumes the f32 byte sizes of `input` and `result`
///   line up with the shader's expectations — callers appear responsible
///   for matching lengths; confirm at call sites.
/// * `uniform_data` - Optional raw bytes for a uniform at binding 2.
///
/// # Errors
///
/// Returns `Err` if mapping the staging buffer for readback fails.
pub(super) async fn execute_element_wise_op(
    &self,
    op_name: &str,
    shader_source: &str,
    input: &[f32],
    result: &mut [f32],
    uniform_data: Option<&[u8]>,
) -> Result<(), String> {
    let len = input.len();
    let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some(&format!("{} Shader", op_name)),
        source: wgpu::ShaderSource::Wgsl(shader_source.into()),
    });
    // Input buffer: written once from the CPU, read by the shader.
    let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some(&format!("{} Input", op_name)),
        size: std::mem::size_of_val(input) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    // Output buffer: written by the shader, copied to the staging buffer
    // (COPY_SRC) for readback.
    let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some(&format!("{} Output", op_name)),
        size: std::mem::size_of_val(result) as u64,
        usage: wgpu::BufferUsages::STORAGE
            | wgpu::BufferUsages::COPY_SRC
            | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
    // Optional uniform buffer (e.g. per-op scalar parameters), created and
    // filled only when the caller supplied data.
    let uniform_buffer = uniform_data.map(|data| {
        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(&format!("{} Uniform", op_name)),
            size: data.len() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buffer, 0, data);
        buffer
    });
    // Layout entries must mirror the bind group entries built below:
    // 0 = read-only input, 1 = writable output, optional 2 = uniform.
    let mut bind_group_entries = vec![
        wgpu::BindGroupLayoutEntry {
            binding: 0,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: true },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
        wgpu::BindGroupLayoutEntry {
            binding: 1,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        },
    ];
    if uniform_buffer.is_some() {
        bind_group_entries.push(wgpu::BindGroupLayoutEntry {
            binding: 2,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Uniform,
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        });
    }
    let bind_group_layout =
        self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some(&format!("{} Bind Group Layout", op_name)),
            entries: &bind_group_entries,
        });
    let mut bind_entries = vec![
        wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
        wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
    ];
    if let Some(ref uniform_buf) = uniform_buffer {
        bind_entries.push(wgpu::BindGroupEntry {
            binding: 2,
            resource: uniform_buf.as_entire_binding(),
        });
    }
    let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(&format!("{} Bind Group", op_name)),
        layout: &bind_group_layout,
        entries: &bind_entries,
    });
    let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some(&format!("{} Pipeline Layout", op_name)),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });
    // Pipeline is rebuilt per call; shaders come from `shader_source`, so
    // there is no caching layer here (cache: None).
    let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some(&format!("{} Pipeline", op_name)),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });
    // Staging buffer: the only buffer the CPU can map (MAP_READ); the GPU
    // output is copied into it after the compute pass.
    let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some(&format!("{} Staging Buffer", op_name)),
        size: std::mem::size_of_val(result) as u64,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some(&format!("{} Encoder", op_name)),
    });
    {
        // Scoped so the pass (which borrows the encoder) ends before the
        // buffer-to-buffer copy is recorded.
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some(&format!("{} Pass", op_name)),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        // One invocation per element; 256 must match the shader's
        // @workgroup_size — assumed here, confirm against the WGSL sources.
        let workgroup_size = 256;
        let num_workgroups = (len as u32).div_ceil(workgroup_size);
        compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(
        &output_buffer,
        0,
        &staging_buffer,
        0,
        std::mem::size_of_val(result) as u64,
    );
    self.queue.submit(Some(encoder.finish()));
    // Readback: request the mapping, then poll the device so the submitted
    // work (and the map callback) actually completes before awaiting.
    let buffer_slice = staging_buffer.slice(..);
    let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        sender.send(result).ok();
    });
    // Poll errors are ignored deliberately: the mapping result received
    // below is the authoritative success/failure signal.
    self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
    receiver
        .receive()
        .await
        .ok_or("Failed to receive mapping result")?
        .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
    {
        // Scoped so the mapped-range view is dropped before unmap().
        let data = buffer_slice.get_mapped_range();
        result.copy_from_slice(bytemuck::cast_slice(&data));
    }
    staging_buffer.unmap();
    Ok(())
}
}
#[cfg(all(test, feature = "gpu", not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// `is_available()` promising a GPU must mean `new()` succeeds;
    /// otherwise the probe is lying to callers that gate on it.
    #[test]
    fn test_is_available_consistency() {
        let available = GpuDevice::is_available();
        let device_result = GpuDevice::new();
        if available {
            assert!(
                device_result.is_ok(),
                "is_available() returned true, but GpuDevice::new() failed"
            );
        } else {
            // Fix: previously printed `device_result.is_err()` (a bool)
            // labeled "result", which was misleading; print the actual
            // error message, if any.
            eprintln!(
                "GPU not available (is_available=false), device creation error: {:?}",
                device_result.err()
            );
        }
    }

    /// Guards against a mutation-testing survivor where reduce_sum
    /// returns a hardcoded value instead of computing the sum.
    #[test]
    fn test_reduce_sum_not_hardcoded() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }
        let device = GpuDevice::new().expect("Failed to create GPU device");
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = runtime::block_on(device.reduce_sum(&input)).expect("reduce_sum failed");
        assert_ne!(result, -1.0, "reduce_sum returned hardcoded -1.0 (mutant not killed)");
        let expected: f32 = input.iter().sum();
        // Loose tolerance: GPU float reduction order may differ from the
        // sequential CPU sum.
        assert!(
            (result - expected).abs() < 1e-4,
            "reduce_sum({:?}) = {} (expected {})",
            input,
            result,
            expected
        );
    }
}