use trueno_gpu::driver::{CudaContext, CudaStream};
use trueno_gpu::GpuError;
use crate::cuda::kernels::KernelType;
/// Dual-stream pipeline that lets kernel execution on one CUDA stream
/// overlap with host<->device transfers issued on a second stream.
pub struct AsyncPipeline {
/// Stream onto which compute kernels are launched.
compute_stream: CudaStream,
/// Stream onto which memory transfers are issued, enabling copy/compute overlap.
transfer_stream: CudaStream,
/// Number of layers enqueued since the last `begin()`.
layers_queued: usize,
/// True between `begin()` and a successful `end()`.
active: bool,
}
impl AsyncPipeline {
    /// Creates a pipeline with two freshly allocated streams on `context`.
    ///
    /// # Errors
    /// Returns a [`GpuError`] if either stream cannot be created.
    pub fn new(context: &CudaContext) -> Result<Self, GpuError> {
        Ok(Self {
            compute_stream: CudaStream::new(context)?,
            transfer_stream: CudaStream::new(context)?,
            layers_queued: 0,
            active: false,
        })
    }

    /// Marks the pipeline active and resets the per-pass layer counter.
    pub fn begin(&mut self) {
        self.layers_queued = 0;
        self.active = true;
    }

    /// Records one more queued layer and returns its zero-based index.
    pub fn enqueue_layer(&mut self) -> usize {
        let idx = self.layers_queued;
        self.layers_queued = idx + 1;
        idx
    }

    /// Stream used for kernel launches.
    #[must_use]
    pub fn compute_stream(&self) -> &CudaStream {
        &self.compute_stream
    }

    /// Stream used for memory transfers.
    #[must_use]
    pub fn transfer_stream(&self) -> &CudaStream {
        &self.transfer_stream
    }

    /// Blocks until both streams have drained.
    ///
    /// # Errors
    /// Propagates the first synchronization failure (compute stream is
    /// synchronized first, then the transfer stream).
    pub fn sync(&self) -> Result<(), GpuError> {
        self.compute_stream.synchronize()?;
        self.transfer_stream.synchronize()?;
        Ok(())
    }

    /// Synchronizes both streams, then marks the pipeline inactive.
    ///
    /// # Errors
    /// Propagates any error from [`Self::sync`]; on failure the pipeline
    /// remains marked active.
    pub fn end(&mut self) -> Result<(), GpuError> {
        self.sync()?;
        self.active = false;
        Ok(())
    }

    /// Whether `begin()` has been called without a matching successful `end()`.
    #[must_use]
    pub fn is_active(&self) -> bool {
        self.active
    }

    /// Layers enqueued since the last `begin()`.
    #[must_use]
    pub fn layers_queued(&self) -> usize {
        self.layers_queued
    }
}
/// Width of each memory access used by generated kernels
/// (see `PtxOptimizationHints::vector_width`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryPattern {
/// One element per access (default).
#[default]
Scalar,
/// Two elements per vectorized access.
Vector2,
/// Four elements per vectorized access.
Vector4,
}
/// Per-thread register-tile dimensions; one thread accumulates a
/// `width x height` patch of outputs in registers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RegisterTiling {
/// Tile width in elements.
pub width: u32,
/// Tile height in elements.
pub height: u32,
}
impl Default for RegisterTiling {
    /// Defaults to the medium 4x4 tile.
    ///
    /// Delegates to [`RegisterTiling::medium`] instead of repeating the
    /// constants, so the default can never drift out of sync with that preset.
    fn default() -> Self {
        Self::medium()
    }
}
impl RegisterTiling {
    /// 8x8 tile: maximum per-thread data reuse, highest register cost.
    #[must_use]
    pub const fn large() -> Self {
        Self { width: 8, height: 8 }
    }

    /// 4x4 tile: balanced reuse versus register pressure.
    #[must_use]
    pub const fn medium() -> Self {
        Self { width: 4, height: 4 }
    }

    /// 2x2 tile: minimal register footprint.
    #[must_use]
    pub const fn small() -> Self {
        Self { width: 2, height: 2 }
    }

    /// Accumulator registers this tile requires (width times height).
    #[must_use]
    pub const fn registers_needed(&self) -> u32 {
        self.height * self.width
    }
}
/// Mitigation strategy for shared-memory bank conflicts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BankConflictStrategy {
/// No mitigation (default).
#[default]
None,
/// Pad each shared-memory row by one element (see `shared_mem_padding`).
Padding,
/// XOR-based index swizzling.
Xor,
}
/// Tunable codegen hints consumed by `PtxOptimizer`.
///
/// NOTE(review): the derived `Default` zeroes every field (scalar loads,
/// 0.0 target occupancy, ILP off); prefer the named presets
/// (`max_throughput`, `low_latency`, `balanced`) for meaningful configs.
#[derive(Debug, Clone, Default)]
pub struct PtxOptimizationHints {
/// Width of each memory access (scalar / 2-wide / 4-wide).
pub memory_pattern: MemoryPattern,
/// Per-thread register tile dimensions.
pub register_tiling: RegisterTiling,
/// How shared-memory bank conflicts are mitigated.
pub bank_conflict_strategy: BankConflictStrategy,
/// Desired occupancy fraction; presets use values in [0.5, 1.0].
pub target_occupancy: f32,
/// Whether to schedule for instruction-level parallelism.
pub enable_ilp: bool,
/// Shared-memory carve-out preference; presets all use 0 — presumably
/// bytes with 0 meaning "no preference", TODO confirm units at the consumer.
pub shared_mem_preference: u32,
}
impl PtxOptimizationHints {
    /// Preset tuned for sustained throughput: 4-wide loads, 8x8 register
    /// tiles, padded shared memory, ILP on, 75% target occupancy.
    #[must_use]
    pub fn max_throughput() -> Self {
        Self {
            memory_pattern: MemoryPattern::Vector4,
            register_tiling: RegisterTiling::large(),
            bank_conflict_strategy: BankConflictStrategy::Padding,
            target_occupancy: 0.75,
            enable_ilp: true,
            shared_mem_preference: 0,
        }
    }

    /// Preset tuned for latency: scalar loads, 2x2 tiles, no conflict
    /// mitigation, ILP off, full occupancy.
    #[must_use]
    pub fn low_latency() -> Self {
        Self {
            memory_pattern: MemoryPattern::Scalar,
            register_tiling: RegisterTiling::small(),
            bank_conflict_strategy: BankConflictStrategy::None,
            target_occupancy: 1.0,
            enable_ilp: false,
            shared_mem_preference: 0,
        }
    }

    /// Middle-ground preset: 2-wide loads, 4x4 tiles, padded shared
    /// memory, ILP on, 50% target occupancy.
    #[must_use]
    pub fn balanced() -> Self {
        Self {
            memory_pattern: MemoryPattern::Vector2,
            register_tiling: RegisterTiling::medium(),
            bank_conflict_strategy: BankConflictStrategy::Padding,
            target_occupancy: 0.5,
            enable_ilp: true,
            shared_mem_preference: 0,
        }
    }

    /// True when the memory pattern is vectorized (2- or 4-wide).
    #[must_use]
    pub const fn uses_vectorized_loads(&self) -> bool {
        match self.memory_pattern {
            MemoryPattern::Scalar => false,
            MemoryPattern::Vector2 | MemoryPattern::Vector4 => true,
        }
    }

    /// Elements moved per memory access for the configured pattern.
    #[must_use]
    pub const fn vector_width(&self) -> u32 {
        match self.memory_pattern {
            MemoryPattern::Scalar => 1,
            MemoryPattern::Vector2 => 2,
            MemoryPattern::Vector4 => 4,
        }
    }

    /// Extra elements to pad each shared-memory row with: 1 under the
    /// `Padding` strategy, 0 otherwise (`None` and `Xor` add no padding).
    #[must_use]
    pub const fn shared_mem_padding(&self) -> u32 {
        if matches!(self.bank_conflict_strategy, BankConflictStrategy::Padding) {
            1
        } else {
            0
        }
    }
}
/// Wraps a fixed [`PtxOptimizationHints`] set and derives codegen
/// quantities (padding, register estimates) from it.
pub struct PtxOptimizer {
/// Hint set this optimizer was constructed with.
hints: PtxOptimizationHints,
}
impl PtxOptimizer {
    /// Builds an optimizer around the given hint set.
    #[must_use]
    pub const fn new(hints: PtxOptimizationHints) -> Self {
        Self { hints }
    }

    /// Borrows the hint set this optimizer was built with.
    #[must_use]
    pub const fn hints(&self) -> &PtxOptimizationHints {
        &self.hints
    }

    /// One-line human-readable description of the active configuration.
    #[must_use]
    pub fn summary(&self) -> String {
        let h = &self.hints;
        format!(
            "PtxOptimizer[vec={}, tile={}x{}, bank={:?}, ilp={}]",
            h.vector_width(),
            h.register_tiling.width,
            h.register_tiling.height,
            h.bank_conflict_strategy,
            h.enable_ilp
        )
    }

    /// Row length after applying the configured bank-conflict padding.
    #[must_use]
    pub const fn padded_shared_mem_row(&self, row_elements: u32) -> u32 {
        row_elements + self.hints.shared_mem_padding()
    }

    /// Rough per-thread register estimate: a fixed base of 16, plus one
    /// register per tile accumulator, doubled when ILP is enabled.
    #[must_use]
    pub const fn estimated_registers(&self) -> u32 {
        let accum = self.hints.register_tiling.registers_needed();
        let tiled = if self.hints.enable_ilp { 2 * accum } else { accum };
        16 + tiled
    }

    /// True when the estimate exceeds 64 registers per thread.
    #[must_use]
    pub const fn is_high_register_pressure(&self) -> bool {
        self.estimated_registers() > 64
    }
}
pub mod presets {
    use super::KernelType;

    /// Causal single-head attention kernel.
    pub fn llama_attention(seq_len: u32, head_dim: u32) -> KernelType {
        KernelType::Attention {
            seq_len,
            head_dim,
            causal: true,
        }
    }

    /// Tiled GEMM for a feed-forward projection:
    /// `[batch x hidden] * [hidden x intermediate]`, 32x32 tiles.
    pub fn ffn_gemm(batch: u32, hidden: u32, intermediate: u32) -> KernelType {
        KernelType::GemmTiled {
            m: batch,
            n: intermediate,
            k: hidden,
            tile_size: 32,
        }
    }

    /// Quantized GEMM for Q4_K inference.
    ///
    /// NOTE(review): unlike `q4k_ggml_inference`, this does not assert that
    /// `k` is a multiple of 256 — confirm whether this variant needs it.
    pub fn q4k_inference(batch: u32, hidden: u32, k: u32) -> KernelType {
        KernelType::QuantizedGemm {
            m: batch,
            n: hidden,
            k,
        }
    }

    /// Quantized GEMM over GGML Q4_K super-blocks; `k` must be a
    /// multiple of 256 (the GGML super-block size).
    pub fn q4k_ggml_inference(batch: u32, hidden: u32, k: u32) -> KernelType {
        debug_assert!(
            k % 256 == 0,
            "k must be divisible by 256 for GGML super-blocks"
        );
        KernelType::QuantizedGemmGgml {
            m: batch,
            n: hidden,
            k,
        }
    }

    /// RMSNorm expressed as a non-affine LayerNorm with epsilon 1e-6.
    pub fn rmsnorm(hidden_size: u32) -> KernelType {
        KernelType::LayerNorm {
            hidden_size,
            epsilon: 1e-6,
            affine: false,
        }
    }

    /// Causal multi-head attention with explicit head geometry.
    pub fn multi_head_attention(seq_len: u32, head_dim: u32, n_heads: u32) -> KernelType {
        KernelType::MultiHeadAttention {
            seq_len,
            head_dim,
            n_heads,
            causal: true,
        }
    }

    /// Causal multi-head attention with Phi-2 geometry (32 heads of dim 80).
    pub fn phi2_multi_head_attention(seq_len: u32) -> KernelType {
        KernelType::MultiHeadAttention {
            seq_len,
            head_dim: 80,
            n_heads: 32,
            causal: true,
        }
    }

    /// Causal attention on tensor cores with explicit head geometry.
    pub fn tensor_core_attention(seq_len: u32, head_dim: u32, n_heads: u32) -> KernelType {
        KernelType::AttentionTensorCore {
            seq_len,
            head_dim,
            n_heads,
            causal: true,
        }
    }

    /// Causal tensor-core attention with LLaMA geometry (32 heads of dim 128).
    pub fn llama_tensor_core_attention(seq_len: u32) -> KernelType {
        KernelType::AttentionTensorCore {
            seq_len,
            head_dim: 128,
            n_heads: 32,
            causal: true,
        }
    }
}
// Textually compiles `pipeline_memory_pattern.rs` into this module scope.
include!("pipeline_memory_pattern.rs");