use super::super::CudaRuntime;
use super::super::client::CudaClient;
use super::super::kernels;
use crate::algorithm::linalg::{validate_linalg_dtype, validate_matrix_2d};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::{AllocGuard, RuntimeClient};
use crate::tensor::Tensor;
fn validate_banded(
ab_shape: &[usize],
b_shape: &[usize],
kl: usize,
ku: usize,
) -> Result<(usize, usize)> {
let (ab_rows, n) = validate_matrix_2d(ab_shape)?;
let expected_rows = kl + ku + 1;
if ab_rows != expected_rows {
return Err(Error::ShapeMismatch {
expected: vec![expected_rows, n],
got: ab_shape.to_vec(),
});
}
if n == 0 {
return Err(Error::InvalidArgument {
arg: "ab",
reason: "banded system size n must be > 0".to_string(),
});
}
let nrhs = match b_shape.len() {
1 => {
if b_shape[0] != n {
return Err(Error::ShapeMismatch {
expected: vec![n],
got: b_shape.to_vec(),
});
}
1
}
2 => {
if b_shape[0] != n {
return Err(Error::ShapeMismatch {
expected: vec![n, b_shape[1]],
got: b_shape.to_vec(),
});
}
b_shape[1]
}
_ => {
return Err(Error::InvalidArgument {
arg: "b",
reason: format!("expected 1D or 2D tensor, got {}D", b_shape.len()),
});
}
};
Ok((n, nrhs))
}
/// Solves the banded linear system `A x = b` on the CUDA device, where `A` is
/// supplied in banded storage `ab` of shape `(kl + ku + 1, n)` and `b` is
/// either a length-`n` vector or an `(n, nrhs)` matrix.
///
/// Returns `x` with the same shape as `b`. Only `F32`/`F64` are supported.
///
/// # Errors
/// Returns an error on dtype mismatch between `ab` and `b`, unsupported
/// dtypes, invalid shapes, allocation failure, or kernel-launch failure.
pub fn solve_banded_impl(
    client: &CudaClient,
    ab: &Tensor<CudaRuntime>,
    b: &Tensor<CudaRuntime>,
    kl: usize,
    ku: usize,
) -> Result<Tensor<CudaRuntime>> {
    validate_linalg_dtype(ab.dtype())?;
    // Both operands must share one dtype; `ab` is the reference.
    if ab.dtype() != b.dtype() {
        return Err(Error::DTypeMismatch {
            lhs: ab.dtype(),
            rhs: b.dtype(),
        });
    }
    let dtype = ab.dtype();
    // The banded-solve kernels exist only for f32/f64; reject anything else
    // even if `validate_linalg_dtype` admits it.
    match dtype {
        DType::F32 | DType::F64 => {}
        _ => {
            return Err(Error::UnsupportedDType {
                dtype,
                op: "solve_banded",
            });
        }
    }
    let (n, nrhs) = validate_banded(ab.shape(), b.shape(), kl, ku)?;
    let device = client.device();
    let elem_size = dtype.size_in_bytes();
    // Factorization scratch: 2*kl + ku + 1 rows — presumably the extra `kl`
    // rows hold fill-in from partial pivoting, mirroring LAPACK's gbsv
    // LDAB = 2*KL+KU+1 convention — TODO confirm against the kernel source.
    let work_rows = 2 * kl + ku + 1;
    let work_size = work_rows * n * elem_size;
    let x_total_size = n * nrhs * elem_size;
    // Size of one right-hand-side / solution column (used in the multi-RHS path).
    let col_size = n * elem_size;
    // AllocGuard frees the device buffer on drop, covering early `?` returns.
    let work_guard = AllocGuard::new(client.allocator(), work_size)?;
    let x_guard = AllocGuard::new(client.allocator(), x_total_size)?;
    let work_ptr = work_guard.ptr();
    let x_ptr = x_guard.ptr();
    // Kernels expect densely packed storage.
    let ab_contig = ab.contiguous();
    let b_contig = b.contiguous();
    // Remember the caller's rank before solving so `x` can match it.
    let b_is_1d = b.ndim() == 1;
    if nrhs == 1 {
        // Single right-hand side: solve straight into the output buffer.
        let result = unsafe {
            kernels::launch_banded_solve(
                client.context(),
                client.stream(),
                device.index,
                dtype,
                ab_contig.ptr(),
                b_contig.ptr(),
                x_ptr,
                work_ptr,
                n,
                kl,
                ku,
            )
        };
        result?
    } else {
        // Multiple right-hand sides: the solver kernel handles one column at a
        // time, so for each `rhs` extract that column of `b`, solve it, and
        // scatter the solution into column `rhs` of `x`.
        // NOTE(review): these column guards drop at the end of this block,
        // BEFORE the `client.synchronize()` below — safe only if the allocator
        // is stream-ordered (or frees synchronously); confirm.
        let b_col_guard = AllocGuard::new(client.allocator(), col_size)?;
        let x_col_guard = AllocGuard::new(client.allocator(), col_size)?;
        let b_col_ptr = b_col_guard.ptr();
        let x_col_ptr = x_col_guard.ptr();
        for rhs in 0..nrhs {
            // Gather column `rhs` of the (n, nrhs) matrix `b` into a dense vector.
            let result = unsafe {
                kernels::launch_extract_column(
                    client.context(),
                    client.stream(),
                    device.index,
                    dtype,
                    b_contig.ptr(),
                    b_col_ptr,
                    n,
                    nrhs,
                    rhs,
                )
            };
            result?;
            let result = unsafe {
                kernels::launch_banded_solve(
                    client.context(),
                    client.stream(),
                    device.index,
                    dtype,
                    ab_contig.ptr(),
                    b_col_ptr,
                    x_col_ptr,
                    work_ptr,
                    n,
                    kl,
                    ku,
                )
            };
            result?;
            // NOTE(review): `launch_extract_column` receives `nrhs` but
            // `launch_scatter_column` does not — verify the scatter kernel's
            // column-stride handling for nrhs > 1.
            let result = unsafe {
                kernels::launch_scatter_column(
                    client.context(),
                    client.stream(),
                    device.index,
                    dtype,
                    x_col_ptr,
                    x_ptr,
                    n,
                    rhs,
                )
            };
            result?
        }
    }
    // Ensure all launches have completed before handing the buffer to a tensor.
    client.synchronize();
    // `release` transfers ownership of the solution buffer to the new tensor,
    // so `x_guard` will no longer free it on drop.
    let released_ptr = x_guard.release();
    let x = if b_is_1d {
        // SAFETY: `released_ptr` holds `n` (resp. `n * nrhs`) elements of
        // `dtype` on `device`, fully written by the synchronized solve above.
        unsafe { CudaClient::tensor_from_raw(released_ptr, &[n], dtype, device) }
    } else {
        unsafe { CudaClient::tensor_from_raw(released_ptr, &[n, nrhs], dtype, device) }
    };
    Ok(x)
}