use std::collections::VecDeque;
use std::sync::{Arc, Mutex, OnceLock};
use cudarc::cublas::sys::{cublasOperation_t, cublasSideMode_t, cublasStatus_t};
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig};
use cudarc::driver::{CudaContext, CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
use ndarray::{Array2, ArrayBase, Data, Ix1};
use super::device::GpuDeviceInfo;
use super::diagnostics;
use super::driver::{from_col_major, to_col_major, to_i32};
use super::runtime::{GpuRuntime, cuda_context_for};
/// A matrix `X` that has been uploaded to a GPU once, together with the
/// cuBLAS state and scratch buffers needed to run repeated products against
/// it (`xtwx`, `xv`) without re-uploading `X`.
pub struct DeviceXSession {
/// Number of rows of the uploaded matrix (taken from `x.dim()` at upload).
rows: usize,
/// Number of columns of the uploaded matrix (taken from `x.dim()` at upload).
cols: usize,
/// Description of the device the session was created on; exposed via `device()`.
device: GpuDeviceInfo,
/// Stream, cuBLAS handle and device buffers. Every operation locks this
/// mutex, so at most one operation uses the buffers at a time.
inner: Mutex<SessionInner>,
}
/// Device-side state owned by a `DeviceXSession` and guarded by its mutex.
struct SessionInner {
/// Stream used for the copies and for creating the cuBLAS handle.
stream: Arc<CudaStream>,
/// cuBLAS handle created from `stream`.
blas: CudaBlas,
/// The matrix `X`, uploaded from the buffer produced by `to_col_major`.
x_dev: CudaSlice<f64>,
/// Scratch buffer of `rows * cols` elements; `xtwx` writes `diag(w) * X` here.
wy_dev: CudaSlice<f64>,
/// Buffer of `rows` elements; `xtwx` copies the weight vector into it.
w_dev: CudaSlice<f64>,
/// Buffer of `cols * cols` elements; receives the GEMM result in `xtwx`.
out_pp_dev: CudaSlice<f64>,
/// Buffer of `rows` elements; receives the GEMV result in `xv`.
out_n_dev: CudaSlice<f64>,
/// Buffer of `cols` elements; `xv` copies its input vector into it.
v_p_dev: CudaSlice<f64>,
}
// SAFETY: NOTE(review): `SessionInner` is only stored inside the
// `Mutex` of `DeviceXSession`, so it is used by one thread at a time. This
// impl additionally assumes the cuBLAS handle, stream and device buffers may
// be moved to and used from a different thread than the one that created
// them — confirm against cudarc's own `Send`/`Sync` guarantees (if cudarc
// already marks these types `Send`, this impl may be redundant).
unsafe impl Send for SessionInner {}
impl DeviceXSession {
pub fn xtwx<S: Data<Elem = f64>>(&self, w: &ArrayBase<S, Ix1>) -> Option<Array2<f64>> {
let n = self.rows;
let p = self.cols;
if w.len() != n {
return None;
}
let w_owned_storage: Option<Vec<f64>> = match w.as_slice() {
Some(_) => None,
None => Some(w.iter().copied().collect()),
};
let w_slice: &[f64] = match w_owned_storage.as_ref() {
Some(buf) => buf.as_slice(),
None => w.as_slice().expect("contiguous slice"),
};
let mut inner = self.inner.lock().ok()?;
let stream = inner.stream.clone();
stream.memcpy_htod(w_slice, &mut inner.w_dev).ok()?;
let n_i = to_i32(n)?;
let p_i = to_i32(p)?;
let handle = *inner.blas.handle();
let ddgmm_status = {
let SessionInner {
ref x_dev,
ref w_dev,
ref mut wy_dev,
..
} = *inner;
unsafe {
let (x_ptr, _record_x) = x_dev.device_ptr(&stream);
let (w_ptr, _record_w) = w_dev.device_ptr(&stream);
let (wy_ptr, _record_wy) = wy_dev.device_ptr_mut(&stream);
cudarc::cublas::sys::cublasDdgmm(
handle,
cublasSideMode_t::CUBLAS_SIDE_LEFT,
n_i,
p_i,
x_ptr as *const f64,
n_i,
w_ptr as *const f64,
1,
wy_ptr as *mut f64,
n_i,
)
}
};
if ddgmm_status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return None;
}
let cfg = GemmConfig::<f64> {
transa: cublasOperation_t::CUBLAS_OP_T,
transb: cublasOperation_t::CUBLAS_OP_N,
m: p_i,
n: p_i,
k: n_i,
alpha: 1.0,
lda: n_i,
ldb: n_i,
beta: 0.0,
ldc: p_i,
};
let gemm_ok = {
let SessionInner {
ref blas,
ref x_dev,
ref wy_dev,
ref mut out_pp_dev,
..
} = *inner;
unsafe { blas.gemm(cfg, x_dev, wy_dev, out_pp_dev) }.is_ok()
};
if !gemm_ok {
return None;
}
let out_host: Vec<f64> = stream.clone_dtoh(&inner.out_pp_dev).ok()?;
Some(from_col_major(&out_host, p, p))
}
pub fn xv<S: Data<Elem = f64>>(&self, v: &ArrayBase<S, Ix1>) -> Option<ndarray::Array1<f64>> {
let n = self.rows;
let p = self.cols;
if v.len() != p {
return None;
}
let v_owned: Option<Vec<f64>> = match v.as_slice() {
Some(_) => None,
None => Some(v.iter().copied().collect()),
};
let v_slice: &[f64] = match v_owned.as_ref() {
Some(buf) => buf.as_slice(),
None => v.as_slice().expect("contiguous slice"),
};
let mut inner = self.inner.lock().ok()?;
let stream = inner.stream.clone();
stream.memcpy_htod(v_slice, &mut inner.v_p_dev).ok()?;
let n_i = to_i32(n)?;
let p_i = to_i32(p)?;
let cfg = GemvConfig::<f64> {
trans: cublasOperation_t::CUBLAS_OP_N,
m: n_i,
n: p_i,
alpha: 1.0,
lda: n_i,
incx: 1,
beta: 0.0,
incy: 1,
};
let gemv_ok = {
let SessionInner {
ref blas,
ref x_dev,
ref v_p_dev,
ref mut out_n_dev,
..
} = *inner;
unsafe { blas.gemv(cfg, x_dev, v_p_dev, out_n_dev) }.is_ok()
};
if !gemv_ok {
return None;
}
let y_host: Vec<f64> = stream.clone_dtoh(&inner.out_n_dev).ok()?;
Some(ndarray::Array1::from_vec(y_host))
}
#[inline]
pub fn device(&self) -> &GpuDeviceInfo {
&self.device
}
}
/// Attempts to compute `Xᵀ · diag(w) · X` on the GPU with `X` kept resident
/// on the device between calls.
///
/// The session for `x` is obtained from the process-wide cache (uploading
/// `x` on first use) and the product is delegated to `DeviceXSession::xtwx`.
///
/// Returns `None` — after logging where a diagnostic hook exists — when the
/// GPU runtime is unavailable, when the routing policy declines this shape,
/// when no session could be obtained for `x`, or when the device computation
/// fails.
///
/// In debug builds, asserts that `w` has one entry per row of `x`.
pub fn try_fast_xt_diag_x_arc<S: Data<Elem = f64>>(
    x: &Arc<Array2<f64>>,
    w: &ArrayBase<S, Ix1>,
) -> Option<Array2<f64>> {
    const OP: &str = "xt_diag_x_resident";
    const BACKEND: &str = "cuBLAS";

    let (rows, cols) = x.dim();
    debug_assert_eq!(rows, w.len(), "X rows must match W length");
    // Shape description shared by all diagnostics; built only when logged.
    let shape = || format!("rows={rows} cols={cols}");

    let runtime = GpuRuntime::global();
    if !runtime.is_available() {
        return None;
    }

    let policy = runtime.policy();
    if !policy.route_xt_diag_y(rows, cols, cols) {
        diagnostics::log_policy_cpu(
            OP,
            shape(),
            format!(
                "below cuBLAS policy threshold rows>={} and gemm_flops>={}",
                policy.xtwx_min_rows, policy.gemm_min_flops
            ),
        );
        return None;
    }

    let session = cache().get_or_upload(x)?;
    let started = std::time::Instant::now();
    let Some(product) = session.xtwx(w) else {
        diagnostics::log_runtime_cpu(OP, BACKEND, shape());
        return None;
    };

    diagnostics::log_gpu_success(
        OP,
        BACKEND,
        session.device(),
        shape(),
        diagnostics::gemm_flops(cols, cols, rows),
        diagnostics::bytes_for_f64(rows),
        diagnostics::bytes_for_f64(cols.saturating_mul(cols)),
        started.elapsed().as_secs_f64(),
    );
    Some(product)
}
/// Upper bound on the number of entries the session cache retains; once an
/// insertion exceeds it, entries are evicted from the front of the queue.
const MAX_CACHE_ENTRIES: usize = 4;
/// Bounded cache of upload outcomes, keyed by the address of the host matrix.
struct SessionCache {
/// Entries ordered from least recently used (front) to most recently used
/// (back): hits are moved to the back and eviction pops from the front.
entries: Mutex<VecDeque<CacheEntry>>,
}
/// One cached upload outcome.
struct CacheEntry {
/// Address of the matrix allocation (`Arc::as_ptr(x) as usize`).
key: usize,
/// Clone of the keyed `Arc`. Holding it keeps the allocation alive, so the
/// address stored in `key` cannot be reused by a different matrix while
/// this entry exists.
_arc_keepalive: Arc<Array2<f64>>,
/// Result of the upload attempt for this matrix.
outcome: CachedOutcome,
}
/// Result of trying to upload a matrix, as remembered by the cache.
#[derive(Clone)]
enum CachedOutcome {
/// The upload succeeded; the session can be shared between callers.
Ready(Arc<DeviceXSession>),
/// `upload_x` returned `None`; remembered so the upload is not retried
/// while the entry stays in the cache.
Failed,
}
impl SessionCache {
fn new() -> Self {
Self {
entries: Mutex::new(VecDeque::with_capacity(MAX_CACHE_ENTRIES + 1)),
}
}
fn get_or_upload(&self, x: &Arc<Array2<f64>>) -> Option<Arc<DeviceXSession>> {
let key = Arc::as_ptr(x) as usize;
if let Some(outcome) = self.lookup_and_promote(key) {
return match outcome {
CachedOutcome::Ready(session) => Some(session),
CachedOutcome::Failed => None,
};
}
let our_outcome = match upload_x(x) {
Some(session) => CachedOutcome::Ready(Arc::new(session)),
None => CachedOutcome::Failed,
};
let final_outcome = {
let mut guard = self.entries.lock().ok()?;
if let Some(pos) = guard.iter().position(|e| e.key == key) {
let entry = guard.remove(pos).expect("position just queried");
let peer = entry.outcome.clone();
guard.push_back(entry);
peer
} else {
let entry = CacheEntry {
key,
_arc_keepalive: x.clone(),
outcome: our_outcome.clone(),
};
guard.push_back(entry);
while guard.len() > MAX_CACHE_ENTRIES {
guard.pop_front();
}
our_outcome
}
};
match final_outcome {
CachedOutcome::Ready(session) => Some(session),
CachedOutcome::Failed => None,
}
}
fn lookup_and_promote(&self, key: usize) -> Option<CachedOutcome> {
let mut guard = self.entries.lock().ok()?;
let pos = guard.iter().position(|e| e.key == key)?;
let entry = guard.remove(pos)?;
let outcome = entry.outcome.clone();
guard.push_back(entry);
Some(outcome)
}
}
/// Returns the process-wide session cache, creating it on first use.
fn cache() -> &'static SessionCache {
    static SESSIONS: OnceLock<SessionCache> = OnceLock::new();
    SESSIONS.get_or_init(SessionCache::new)
}
fn upload_x(x: &Arc<Array2<f64>>) -> Option<DeviceXSession> {
let runtime = GpuRuntime::global();
let device = runtime.selected_device()?.clone();
let (rows, cols) = x.dim();
if rows == 0 || cols == 0 {
return None;
}
let ctx = match cuda_context_for(device.ordinal) {
Some(ctx) => ctx,
None => CudaContext::new(device.ordinal).ok()?,
};
let stream = ctx.new_stream().ok()?;
let blas = CudaBlas::new(stream.clone()).ok()?;
let upload_start = std::time::Instant::now();
let host_col_major: Vec<f64> = to_col_major(&x.view());
let x_dev: CudaSlice<f64> = stream.clone_htod(&host_col_major).ok()?;
let wy_dev: CudaSlice<f64> = stream.alloc_zeros::<f64>(rows.checked_mul(cols)?).ok()?;
let w_dev: CudaSlice<f64> = stream.alloc_zeros::<f64>(rows).ok()?;
let out_pp_dev: CudaSlice<f64> = stream.alloc_zeros::<f64>(cols.checked_mul(cols)?).ok()?;
let out_n_dev: CudaSlice<f64> = stream.alloc_zeros::<f64>(rows).ok()?;
let v_p_dev: CudaSlice<f64> = stream.alloc_zeros::<f64>(cols).ok()?;
let upload_elapsed = upload_start.elapsed().as_secs_f64();
log::info!(
"[GPU] xt_diag_x_resident upload | device={} '{}' | shape=rows={rows} cols={cols} | bytes={} | elapsed={upload_elapsed:.3}s",
device.ordinal,
device.name,
format_bytes_for_log(
rows.saturating_mul(cols)
.saturating_mul(std::mem::size_of::<f64>())
),
);
Some(DeviceXSession {
rows,
cols,
device,
inner: Mutex::new(SessionInner {
stream,
blas,
x_dev,
wy_dev,
w_dev,
out_pp_dev,
out_n_dev,
v_p_dev,
}),
})
}
/// Formats a byte count for log output using binary units.
///
/// Picks the largest of GiB, MiB and KiB that does not exceed `bytes` and
/// prints the scaled value with two decimals; counts below 1 KiB are printed
/// as an exact integer followed by `B`.
fn format_bytes_for_log(bytes: usize) -> String {
    // Units from largest to smallest, so the first match is the right one.
    const UNITS: [(f64, &str); 3] = [
        (1024.0 * 1024.0 * 1024.0, "GiB"),
        (1024.0 * 1024.0, "MiB"),
        (1024.0, "KiB"),
    ];
    let size = bytes as f64;
    UNITS
        .iter()
        .find(|(scale, _)| size >= *scale)
        .map(|(scale, suffix)| format!("{:.2}{suffix}", size / scale))
        .unwrap_or_else(|| format!("{bytes}B"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a GPU the fast path must decline. With a GPU it may still
    /// decline (for example through the routing policy), but any result it
    /// does return for an all-zero `X` must be the 8×8 zero matrix.
    #[test]
    fn session_returns_none_without_gpu() {
        let x = Arc::new(Array2::<f64>::zeros((512, 8)));
        let w = ndarray::Array1::<f64>::from_elem(512, 1.0);
        let result = try_fast_xt_diag_x_arc(&x, &w);
        if GpuRuntime::global().is_available() {
            if let Some(out) = result {
                assert_eq!(out.dim(), (8, 8));
                assert!(out.iter().all(|&value| value == 0.0));
            }
        } else {
            assert!(result.is_none());
        }
    }

    /// Inserting more distinct matrices than the limit must not leave more
    /// than `MAX_CACHE_ENTRIES` entries in the cache.
    #[test]
    fn cache_does_not_grow_unboundedly() {
        if !GpuRuntime::global().is_available() {
            return;
        }
        let cache = cache();
        // Keep every matrix alive so each one has a distinct address (key).
        let mut keepalives = Vec::new();
        for _ in 0..(MAX_CACHE_ENTRIES + 2) {
            let x = Arc::new(Array2::<f64>::zeros((1024, 16)));
            let _ = cache.get_or_upload(&x);
            keepalives.push(x);
        }
        let guard = cache.entries.lock().unwrap();
        assert!(guard.len() <= MAX_CACHE_ENTRIES);
    }
}