#![allow(clippy::too_many_arguments)]
use std::time::Instant;
use super::{AssertionKind, BrickAssertion, BrickError, ComputeBrick, TokenBudget, TokenResult};
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
use crate::cuda::CudaExecutor;
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
use crate::error::RealizarError;
/// Dimensions and latency budget for a quantized matrix-vector product.
///
/// `forward` multiplies an int8 activation vector of length `k` by `n` rows of
/// 4-bit weights packed two per byte, producing `n` `f32` outputs.
#[derive(Debug, Clone)]
pub struct CoalescedDp4aBrick {
/// Reduction length: number of int8 input elements `forward` expects.
pub k: usize,
/// Number of output elements (and of per-row weight scales) `forward` expects.
pub n: usize,
/// Per-token latency budget; derived in `new`, replaceable via `with_budget`.
budget: TokenBudget,
}
impl CoalescedDp4aBrick {
    /// Builds a brick for a quantized mat-vec product with `k` inputs and `n` outputs.
    ///
    /// The default budget is the `k * n * 4.5 / 8` weight bytes divided by an
    /// assumed 700 GB/s of memory bandwidth, floored at `1.0` before being handed
    /// to `TokenBudget::from_latency`. Use [`Self::with_budget`] to override it.
    #[must_use]
    pub fn new(k: usize, n: usize) -> Self {
        // NOTE(review): 4.5 bits/weight presumably covers the 4-bit payload plus
        // amortized scale metadata — confirm against the quantization format.
        const BITS_PER_WEIGHT: f64 = 4.5;
        // NOTE(review): assumed device memory bandwidth — confirm the target GPU.
        const BANDWIDTH_GB_S: f64 = 700.0;

        let weight_bytes = (k as f64 * n as f64 * BITS_PER_WEIGHT) / 8.0;
        // bytes / (GB/s * 1e3) yields microseconds.
        let budget_us = weight_bytes / (BANDWIDTH_GB_S * 1e3);
        Self {
            k,
            n,
            budget: TokenBudget::from_latency(budget_us.max(1.0)),
        }
    }

    /// Returns the brick with its latency budget replaced by `budget`.
    #[must_use]
    pub fn with_budget(mut self, budget: TokenBudget) -> Self {
        self.budget = budget;
        self
    }

    /// Operation count of one product: a multiply and an add per weight (`2 * k * n`).
    #[must_use]
    pub fn flops(&self) -> u64 {
        2 * self.k as u64 * self.n as u64
    }

    /// Operations per byte of memory traffic for one product.
    ///
    /// Traffic is the full `k * n` weight matrix at 4.5 bits per weight (the same
    /// figure `new` uses for the budget), the `k` int8 inputs and the `n` f32
    /// outputs. The result is NaN when both dimensions are zero.
    #[must_use]
    pub fn arithmetic_intensity(&self) -> f64 {
        let weight_bytes = (self.k as f64 * self.n as f64 * 4.5) / 8.0;
        // One byte per int8 activation read by `forward`.
        let input_bytes = self.k as f64;
        // Four bytes per f32 result written by `forward`.
        let output_bytes = self.n as f64 * 4.0;
        self.flops() as f64 / (weight_bytes + input_bytes + output_bytes)
    }

    /// Computes `output[row] = dot(input_q8, weights[row]) * input_scale * weight_scales[row]`.
    ///
    /// `weights_q4` is a dense, row-major stream of `n * k` 4-bit values packed
    /// two per byte, low nibble first; each nibble is stored with a +8 offset and
    /// decodes to `-8..=7`. When `n * k` is odd the final byte carries a single
    /// weight in its low nibble, so the expected length is `n * k / 2` rounded up.
    ///
    /// # Errors
    ///
    /// Returns `BrickError::InvalidInput` when `input_q8` is not `k` long,
    /// `weights_q4` does not hold exactly the packed `n * k` weights, or
    /// `weight_scales` is not `n` long.
    pub fn forward(
        &self,
        input_q8: &[i8],
        input_scale: f32,
        weights_q4: &[u8],
        weight_scales: &[f32],
    ) -> Result<Vec<f32>, BrickError> {
        /// Decodes the signed 4-bit weight stored at nibble position `linear_idx`.
        fn unpack_q4(packed: &[u8], linear_idx: usize) -> i32 {
            let byte = packed[linear_idx / 2];
            // The nibble half must follow the same linear index that selected the
            // byte; keying it on the column alone misreads rows of odd length.
            let nibble = if linear_idx.is_multiple_of(2) {
                byte & 0x0F
            } else {
                byte >> 4
            };
            i32::from(nibble) - 8
        }

        if input_q8.len() != self.k {
            return Err(BrickError::InvalidInput(format!(
                "Input length {} != k {}",
                input_q8.len(),
                self.k
            )));
        }
        // Round up so an odd weight count still has a byte for its last nibble.
        let expected_weight_bytes = (self.n * self.k).div_ceil(2);
        if weights_q4.len() != expected_weight_bytes {
            return Err(BrickError::InvalidInput(format!(
                "Weights length {} != n * k / 2 = {}",
                weights_q4.len(),
                expected_weight_bytes
            )));
        }
        if weight_scales.len() != self.n {
            return Err(BrickError::InvalidInput(format!(
                "Weight scales length {} != n {}",
                weight_scales.len(),
                self.n
            )));
        }

        let output = weight_scales
            .iter()
            .enumerate()
            .map(|(row, &weight_scale)| {
                let row_start = row * self.k;
                // Integer accumulation, so the products can be summed in plain
                // column order.
                let acc: i32 = input_q8
                    .iter()
                    .enumerate()
                    .map(|(col, &activation)| {
                        i32::from(activation) * unpack_q4(weights_q4, row_start + col)
                    })
                    .sum();
                acc as f32 * input_scale * weight_scale
            })
            .collect();
        Ok(output)
    }

    /// Runs [`Self::forward`] and reports its wall-clock time as a single-token result.
    ///
    /// `budget_met` compares the elapsed microseconds with the brick's budget. If
    /// the clock reports zero elapsed time, `tokens_per_sec` is infinite.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::forward`].
    pub fn forward_timed(
        &self,
        input_q8: &[i8],
        input_scale: f32,
        weights_q4: &[u8],
        weight_scales: &[f32],
    ) -> Result<TokenResult<Vec<f32>>, BrickError> {
        let started = Instant::now();
        let output = self.forward(input_q8, input_scale, weights_q4, weight_scales)?;
        let us_per_token = started.elapsed().as_secs_f64() * 1_000_000.0;
        Ok(TokenResult {
            output,
            tokens_processed: 1,
            us_per_token,
            tokens_per_sec: 1_000_000.0 / us_per_token,
            budget_met: us_per_token <= self.budget.us_per_token,
        })
    }

    /// Placeholder kept for compatibility: validates the dimensions and returns `n` zeros.
    ///
    /// # Errors
    ///
    /// Returns `BrickError::InvalidInput` when `k` is zero or not a multiple of
    /// 256, or when `n` is zero.
    #[deprecated(note = "Use forward() for real implementation")]
    pub fn execute(&self) -> Result<Vec<f32>, BrickError> {
        // Zero is a multiple of every number, hence the explicit `k == 0` test.
        if !self.k.is_multiple_of(256) || self.k == 0 || self.n == 0 {
            return Err(BrickError::InvalidInput(format!(
                "Invalid dimensions: k={} (must be multiple of 256), n={}",
                self.k, self.n
            )));
        }
        Ok(vec![0.0; self.n])
    }
}
impl ComputeBrick for CoalescedDp4aBrick {
    type Output = Vec<f32>;

    /// Identifier under which this brick is reported.
    fn name(&self) -> &'static str {
        "coalesced_dp4a"
    }

    /// Latency budget currently attached to the brick.
    fn budget(&self) -> TokenBudget {
        self.budget
    }

    /// Checks associated with this brick: the NaN, Inf and budget assertions
    /// followed by a custom `bandwidth_efficient` check.
    fn assertions(&self) -> Vec<BrickAssertion> {
        const BANDWIDTH_CHECK: &str = "bandwidth_efficient";

        let mut checks = Vec::with_capacity(4);
        checks.push(BrickAssertion::no_nan());
        checks.push(BrickAssertion::no_inf());
        checks.push(BrickAssertion::budget_met());
        checks.push(BrickAssertion {
            name: String::from(BANDWIDTH_CHECK),
            description: String::from("Achieves >= 70% of peak memory bandwidth"),
            kind: AssertionKind::Custom {
                check_name: String::from(BANDWIDTH_CHECK),
            },
        });
        checks
    }

    /// Reports whether the dimensions are usable: both non-zero and `k` a
    /// multiple of 256.
    fn can_run(&self) -> bool {
        let non_empty = self.k > 0 && self.n > 0;
        non_empty && self.k.is_multiple_of(256)
    }
}
/// Dimensions, latency budget and kernel flag for a gated feed-forward block.
///
/// `forward` applies gate and up projections to the input, combines them as
/// `silu(gate) * up`, and projects the result back down to `hidden_dim`.
#[derive(Debug, Clone)]
pub struct FusedFfnBrick {
/// Length of the input and output vectors of `forward`.
pub hidden_dim: usize,
/// Length of the gate/up projection outputs inside `forward`.
pub intermediate_dim: usize,
/// Per-token latency budget; set by the constructors, replaceable via `with_budget`.
budget: TokenBudget,
/// Set from the `PACKED_DP4A` environment variable by `new`, or forced on by
/// `with_packed_dp4a`. NOTE(review): not read by the methods visible in this
/// file section — presumably selects a kernel variant elsewhere; confirm.
pub use_packed_dp4a: bool,
}
impl FusedFfnBrick {
    /// Builds a brick with the default budget (`TokenBudget::from_latency(12.2)`).
    ///
    /// `use_packed_dp4a` is enabled only when the `PACKED_DP4A` environment
    /// variable is set to exactly `"1"`; an unset or unreadable variable leaves
    /// it off.
    #[must_use]
    pub fn new(hidden_dim: usize, intermediate_dim: usize) -> Self {
        let use_packed_dp4a = matches!(std::env::var("PACKED_DP4A").as_deref(), Ok("1"));
        Self {
            use_packed_dp4a,
            ..Self::with_packed_dp4a(hidden_dim, intermediate_dim)
        }
    }

    /// Builds a brick with the default budget and `use_packed_dp4a` forced on,
    /// regardless of the environment.
    #[must_use]
    pub fn with_packed_dp4a(hidden_dim: usize, intermediate_dim: usize) -> Self {
        Self {
            hidden_dim,
            intermediate_dim,
            budget: TokenBudget::from_latency(12.2),
            use_packed_dp4a: true,
        }
    }

    /// Returns the brick with its latency budget replaced by `budget`.
    #[must_use]
    pub fn with_budget(mut self, budget: TokenBudget) -> Self {
        self.budget = budget;
        self
    }

    /// Operation count of one pass: three `hidden_dim x intermediate_dim`
    /// projections at a multiply and an add per weight.
    #[must_use]
    pub fn flops(&self) -> u64 {
        6 * self.hidden_dim as u64 * self.intermediate_dim as u64
    }

    /// Operations per byte of memory traffic for one pass.
    ///
    /// Weights are costed at 4.5 bits each across the three projections;
    /// activations at 4 bytes per element for the input, the two intermediate
    /// vectors and the output.
    #[must_use]
    pub fn arithmetic_intensity(&self) -> f64 {
        let weights_per_matrix = self.hidden_dim as f64 * self.intermediate_dim as f64;
        let weight_bytes = 3.0 * weights_per_matrix * 4.5 / 8.0;
        let activation_bytes =
            (self.hidden_dim * 4 + self.intermediate_dim * 8 + self.hidden_dim * 4) as f64;
        self.flops() as f64 / (weight_bytes + activation_bytes)
    }

    /// Computes `down_proj * (silu(gate_proj * input) * (up_proj * input))`.
    ///
    /// All matrices are row-major: `gate_proj` and `up_proj` have
    /// `intermediate_dim` rows of `hidden_dim` values, `down_proj` has
    /// `hidden_dim` rows of `intermediate_dim` values. The result has
    /// `hidden_dim` elements.
    ///
    /// # Errors
    ///
    /// Returns `BrickError::InvalidInput` when `input` is not `hidden_dim` long
    /// or any projection does not hold `hidden_dim * intermediate_dim` values.
    /// For a gate/up mismatch the message reports the length of the slice that
    /// is actually wrong (gate first if both are).
    pub fn forward(
        &self,
        input: &[f32],
        gate_proj: &[f32],
        up_proj: &[f32],
        down_proj: &[f32],
    ) -> Result<Vec<f32>, BrickError> {
        /// Multiplies a row-major matrix of `rows` rows by `vector`.
        fn matvec(matrix: &[f32], vector: &[f32], rows: usize) -> Vec<f32> {
            let cols = vector.len();
            (0..rows)
                .map(|row| {
                    let weights = &matrix[row * cols..(row + 1) * cols];
                    // Sequential accumulation from 0.0 keeps rounding reproducible.
                    vector
                        .iter()
                        .zip(weights)
                        .fold(0.0f32, |acc, (&x, &w)| acc + x * w)
                })
                .collect()
        }

        if input.len() != self.hidden_dim {
            return Err(BrickError::InvalidInput(format!(
                "Input length {} != hidden_dim {}",
                input.len(),
                self.hidden_dim
            )));
        }

        let expected_gate_up = self.intermediate_dim * self.hidden_dim;
        // Report whichever projection is actually mis-sized rather than always
        // printing the gate length.
        let mismatched_len = if gate_proj.len() != expected_gate_up {
            Some(gate_proj.len())
        } else if up_proj.len() != expected_gate_up {
            Some(up_proj.len())
        } else {
            None
        };
        if let Some(actual) = mismatched_len {
            return Err(BrickError::InvalidInput(format!(
                "Gate/Up length {} != intermediate * hidden = {}",
                actual, expected_gate_up
            )));
        }

        let expected_down = self.hidden_dim * self.intermediate_dim;
        if down_proj.len() != expected_down {
            return Err(BrickError::InvalidInput(format!(
                "Down length {} != hidden * intermediate = {}",
                down_proj.len(),
                expected_down
            )));
        }

        let gate = matvec(gate_proj, input, self.intermediate_dim);
        let up = matvec(up_proj, input, self.intermediate_dim);

        // SiLU(g) = g / (1 + e^-g), gated element-wise by the up projection.
        let hidden: Vec<f32> = gate
            .iter()
            .zip(&up)
            .map(|(&g, &u)| {
                let silu_gate = g / (1.0 + (-g).exp());
                silu_gate * u
            })
            .collect();

        Ok(matvec(down_proj, &hidden, self.hidden_dim))
    }

    /// Runs [`Self::forward`] and reports its wall-clock time as a single-token result.
    ///
    /// `budget_met` compares the elapsed microseconds with the brick's budget. If
    /// the clock reports zero elapsed time, `tokens_per_sec` is infinite.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::forward`].
    pub fn forward_timed(
        &self,
        input: &[f32],
        gate_proj: &[f32],
        up_proj: &[f32],
        down_proj: &[f32],
    ) -> Result<TokenResult<Vec<f32>>, BrickError> {
        let started = Instant::now();
        let output = self.forward(input, gate_proj, up_proj, down_proj)?;
        let us_per_token = started.elapsed().as_secs_f64() * 1_000_000.0;
        Ok(TokenResult {
            output,
            tokens_processed: 1,
            us_per_token,
            tokens_per_sec: 1_000_000.0 / us_per_token,
            budget_met: us_per_token <= self.budget.us_per_token,
        })
    }

    /// Placeholder kept for compatibility: validates the dimensions and returns
    /// `hidden_dim` zeros.
    ///
    /// # Errors
    ///
    /// Returns `BrickError::InvalidInput` when either dimension is zero.
    #[deprecated(note = "Use forward() for real implementation")]
    pub fn execute(&self) -> Result<Vec<f32>, BrickError> {
        if self.hidden_dim == 0 || self.intermediate_dim == 0 {
            return Err(BrickError::InvalidInput("Zero dimension".to_string()));
        }
        Ok(vec![0.0; self.hidden_dim])
    }
}
include!("fused_ffn_brick.rs");