use super::*;
impl WgpuProvider {
    /// Dispatches a `pagefun` request to the handler for its operation.
    ///
    /// `PagefunOp::Mtimes` is the only operation matched here.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the selected handler returns.
    pub(crate) fn pagefun_exec(&self, request: &PagefunRequest) -> Result<GpuTensorHandle> {
        match request.op {
            PagefunOp::Mtimes => self.pagefun_mtimes_exec(request),
        }
    }

    /// Runs a paged matrix multiply (`@mtimes`) over the two inputs of `request`.
    ///
    /// One matmul dispatch is issued per output page. All dispatches read from the
    /// two input buffers and write into a single output buffer of
    /// `rows * cols * product(request.page_dims)` elements, which is finally
    /// registered under `request.output_shape`.
    ///
    /// For each output page the matching input page is looked up through
    /// `broadcast_linear_index`, using each input's page dims padded to
    /// `request.page_dims.len()` entries.
    /// NOTE(review): presumably singleton page dims are broadcast across the
    /// output page grid — confirm against `broadcast_linear_index`.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the request does not carry exactly two inputs, or the per-input page
    ///   metadata count differs from the input count;
    /// - an input handle cannot be resolved or an operand view cannot be built;
    /// - either operand's canonical shape has fewer than two dimensions;
    /// - the inner matrix dimensions disagree;
    /// - any size or offset computation overflows `usize`, or a matrix dimension
    ///   or page offset does not fit in `u32`;
    /// - a computed page range falls outside its buffer;
    /// - the output buffer cannot be created.
    fn pagefun_mtimes_exec(&self, request: &PagefunRequest) -> Result<GpuTensorHandle> {
        ensure!(
            request.inputs.len() == 2,
            "pagefun: @mtimes expects exactly two inputs"
        );
        ensure!(
            request.input_page_dims.len() == request.inputs.len(),
            "pagefun: input metadata mismatch"
        );

        let lhs = &request.inputs[0];
        let rhs = &request.inputs[1];
        let entry_a = self.get_entry(lhs)?;
        let entry_b = self.get_entry(rhs)?;
        let view_a = build_matrix_operand_view(lhs, &entry_a)
            .map_err(|e| anyhow!("pagefun @mtimes: {e}"))?;
        let view_b = build_matrix_operand_view(rhs, &entry_b)
            .map_err(|e| anyhow!("pagefun @mtimes: {e}"))?;
        let canonical_a = canonical_matrix_shape(&entry_a.shape);
        let canonical_b = canonical_matrix_shape(&entry_b.shape);
        ensure!(
            canonical_a.len() >= 2 && canonical_b.len() >= 2,
            "pagefun: @mtimes operands must be at least 2-D"
        );

        // Per-page problem size: (rows x k_a) * (k_b x cols) -> (rows x cols).
        let rows = view_a.rows;
        let k_a = view_a.cols;
        let k_b = view_b.rows;
        let cols = view_b.cols;
        ensure!(
            k_a == k_b,
            "pagefun: inner matrix dimensions must agree ({} vs {})",
            k_a,
            k_b
        );

        // Pad each input's page dims to the output page rank so that a single
        // multi-index can address both inputs.
        let rank = request.page_dims.len();
        let lhs_dims = pad_dims(request.input_page_dims[0].clone(), rank);
        let rhs_dims = pad_dims(request.input_page_dims[1].clone(), rank);
        let lhs_strides = compute_page_strides(&lhs_dims);
        let rhs_strides = compute_page_strides(&rhs_dims);

        // Element counts of one page of each operand and of the output.
        let lhs_page_size = rows
            .checked_mul(k_a)
            .ok_or_else(|| anyhow!("pagefun: lhs page size overflow"))?;
        let rhs_page_size = k_b
            .checked_mul(cols)
            .ok_or_else(|| anyhow!("pagefun: rhs page size overflow"))?;
        let out_page_size = rows
            .checked_mul(cols)
            .ok_or_else(|| anyhow!("pagefun: output page size overflow"))?;

        // With no page dimensions there is exactly one page.
        let page_volume = if rank == 0 {
            1
        } else {
            product_checked(&request.page_dims)
                .ok_or_else(|| anyhow!("pagefun: page dimensions overflow"))?
        };
        let total_len = out_page_size
            .checked_mul(page_volume)
            .ok_or_else(|| anyhow!("pagefun: output size overflow"))?;

        let out_buffer =
            self.create_storage_buffer_checked(total_len, "runmat-pagefun-mtimes-out")?;
        if total_len == 0 {
            // Nothing to compute: register the buffer without dispatching.
            return Ok(self.register_existing_buffer(
                out_buffer,
                request.output_shape.clone(),
                total_len,
            ));
        }

        // The matmul parameter block carries dimensions as u32.
        let m_u32 = u32::try_from(rows)
            .map_err(|_| anyhow!("pagefun: matrix row count exceeds GPU limits"))?;
        let n_u32 = u32::try_from(cols)
            .map_err(|_| anyhow!("pagefun: matrix column count exceeds GPU limits"))?;
        let k_u32 = u32::try_from(k_a)
            .map_err(|_| anyhow!("pagefun: shared dimension exceeds GPU limits"))?;

        // Each operand view exposes its own leading dimension through the `lda`
        // field, so the rhs leading dimension is `view_b.lda`.
        let lda = view_a.lda;
        let ldb = view_b.lda;
        let ldc = m_u32;

        let tile = crate::backend::wgpu::config::effective_matmul_tile();
        let groups_x = crate::backend::wgpu::dispatch::common::dispatch_size_dim(n_u32, tile);
        let groups_y = crate::backend::wgpu::dispatch::common::dispatch_size_dim(m_u32, tile);

        // The transpose flags depend only on the operand views, so they are
        // computed once here rather than once per page.
        let mut flags = 0u32;
        if view_a.transpose {
            flags |= crate::backend::wgpu::params::MATMUL_FLAG_TRANSPOSE_A;
        }
        if view_b.transpose {
            flags |= crate::backend::wgpu::params::MATMUL_FLAG_TRANSPOSE_B;
        }

        self.prepare_matmul_pipeline();
        self.device_ref().poll(wgpu::Maintain::Poll);

        // Measures the host-side time spent in the dispatch loop below.
        let start = Instant::now();
        let mut multi_index = vec![0usize; rank];
        for page_idx in 0..page_volume {
            if rank > 0 {
                decode_multi_index(page_idx, &request.page_dims, &mut multi_index);
            }
            let lhs_linear = broadcast_linear_index(&lhs_dims, &lhs_strides, &multi_index);
            let rhs_linear = broadcast_linear_index(&rhs_dims, &rhs_strides, &multi_index);

            // Element offsets of this page within each buffer.
            let lhs_offset_elements = lhs_linear
                .checked_mul(lhs_page_size)
                .ok_or_else(|| anyhow!("pagefun: lhs offset overflow"))?;
            let rhs_offset_elements = rhs_linear
                .checked_mul(rhs_page_size)
                .ok_or_else(|| anyhow!("pagefun: rhs offset overflow"))?;
            let out_offset_elements = page_idx
                .checked_mul(out_page_size)
                .ok_or_else(|| anyhow!("pagefun: output offset overflow"))?;

            // One-past-the-end positions, used only for the bounds checks below.
            let lhs_end = lhs_offset_elements
                .checked_add(lhs_page_size)
                .ok_or_else(|| anyhow!("pagefun: lhs offset overflow"))?;
            let rhs_end = rhs_offset_elements
                .checked_add(rhs_page_size)
                .ok_or_else(|| anyhow!("pagefun: rhs offset overflow"))?;
            let out_end = out_offset_elements
                .checked_add(out_page_size)
                .ok_or_else(|| anyhow!("pagefun: output offset overflow"))?;
            ensure!(
                lhs_end <= entry_a.len,
                "pagefun: lhs page out of bounds (page {})",
                page_idx
            );
            ensure!(
                rhs_end <= entry_b.len,
                "pagefun: rhs page out of bounds (page {})",
                page_idx
            );
            ensure!(
                out_end <= total_len,
                "pagefun: output page out of bounds (page {})",
                page_idx
            );

            let offset_a_u32 = u32::try_from(lhs_offset_elements)
                .map_err(|_| anyhow!("pagefun: lhs offset exceeds GPU limits"))?;
            let offset_b_u32 = u32::try_from(rhs_offset_elements)
                .map_err(|_| anyhow!("pagefun: rhs offset exceeds GPU limits"))?;
            let offset_out_u32 = u32::try_from(out_offset_elements)
                .map_err(|_| anyhow!("pagefun: output offset exceeds GPU limits"))?;

            let params = crate::backend::wgpu::params::MatmulParams {
                m: m_u32,
                n: n_u32,
                k: k_u32,
                lda,
                ldb,
                ldc,
                offset_a: offset_a_u32,
                offset_b: offset_b_u32,
                offset_out: offset_out_u32,
                flags,
            };
            // Only the offsets change between pages, so each page gets its own
            // uniform buffer and bind group over the same three storage buffers.
            let params_buffer = self.uniform_buffer(&params, "runmat-pagefun-mtimes-params");
            let bind_group = self
                .device_ref()
                .create_bind_group(&wgpu::BindGroupDescriptor {
                    label: Some("runmat-pagefun-mtimes-bind"),
                    layout: &self.pipelines.matmul.layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: entry_a.buffer.as_ref().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: entry_b.buffer.as_ref().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 2,
                            resource: out_buffer.as_ref().as_entire_binding(),
                        },
                        wgpu::BindGroupEntry {
                            binding: 3,
                            resource: params_buffer.as_entire_binding(),
                        },
                    ],
                });
            crate::backend::wgpu::dispatch::matmul::run(
                self.device_ref(),
                self.queue_ref(),
                &self.pipelines.matmul.pipeline,
                &bind_group,
                groups_x,
                groups_y,
            );
        }
        self.telemetry.record_matmul_duration(start.elapsed());

        let handle =
            self.register_existing_buffer(out_buffer, request.output_shape.clone(), total_len);
        Ok(handle)
    }
}