#[allow(dead_code, clippy::too_many_arguments)]
pub fn prefill_extract_all_logits(
&mut self,
s: usize,
hidden_dim: u32,
vocab_size: u32,
epsilon: f32,
) -> Result<Vec<f32>, GpuError> {
if s == 0 {
return Ok(Vec::new());
}
let output_norm_buf = self.rmsnorm_cache.get("output_norm.gamma").ok_or_else(|| {
GpuError::InvalidLaunchConfig("realizr#203: output_norm not cached".to_string())
})?;
let output_norm_ptr = output_norm_buf.as_ptr();
let output_norm_len = hidden_dim as usize;
let hidden_buf2_ptr = self.workspace.hidden_buf2.as_ref()
.ok_or_else(|| GpuError::InvalidLaunchConfig("realizr#203: hidden_buf2 missing".to_string()))?
.as_ptr();
let hidden_buf2_len = s * hidden_dim as usize;
let normed_buf = if let Some(ref buf) = self.workspace.normed_hidden_buf {
if buf.len() >= hidden_buf2_len { buf.as_ptr() }
else {
let tmp = GpuBuffer::<f32>::new(&self.context, hidden_buf2_len)?;
let ptr = tmp.as_ptr();
self.workspace.normed_hidden_buf = Some(tmp);
ptr
}
} else {
let tmp = GpuBuffer::<f32>::new(&self.context, hidden_buf2_len)?;
let ptr = tmp.as_ptr();
self.workspace.normed_hidden_buf = Some(tmp);
ptr
};
let hidden_wrapper = unsafe { GpuBuffer::<f32>::from_raw_parts(hidden_buf2_ptr, hidden_buf2_len) };
let normed_wrapper = unsafe { GpuBuffer::<f32>::from_raw_parts(normed_buf, hidden_buf2_len) };
self.batched_rmsnorm_ptr_into(
&hidden_wrapper, output_norm_ptr, output_norm_len,
&normed_wrapper, hidden_dim, s as u32, epsilon,
)?;
std::mem::forget(hidden_wrapper);
std::mem::forget(normed_wrapper);
if self.lm_head_ptr == 0 {
return Err(GpuError::InvalidLaunchConfig("realizr#203: LM head not indexed".to_string()));
}
let lm_head_ptr = self.lm_head_ptr;
let lm_head_qtype = self.lm_head_qtype;
let logits_size = s * vocab_size as usize;
let logits_gpu = GpuBuffer::<f32>::new(&self.context, logits_size)?;
let normed_wrapper2 = unsafe { GpuBuffer::<f32>::from_raw_parts(normed_buf, hidden_buf2_len) };
self.batched_gemv_or_gemm(
lm_head_qtype, lm_head_ptr,
&normed_wrapper2, &logits_gpu,
normed_buf, logits_gpu.as_ptr(),
s as u32, vocab_size, hidden_dim,
)?;
std::mem::forget(normed_wrapper2);
self.stream.synchronize()?;
let mut logits = vec![0.0f32; logits_size];
logits_gpu.copy_to_host(&mut logits)?;
Ok(logits)
}