use nalgebra::{linalg::SVD, DMatrix};
use num_complex::Complex64;
use runmat_accelerate_api::{
AccelProvider, GpuTensorHandle, HostTensorView, ProviderLinsolveOptions, ProviderLinsolveResult,
};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;
use crate::builtins::common::spec::{
BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{
gpu_helpers,
linalg::{diagonal_rcond, singular_value_rcond},
tensor,
};
use crate::builtins::math::linalg::type_resolvers::left_divide_type;
use crate::{build_runtime_error, BuiltinResult, RuntimeError};
/// Builtin name attached to every error raised by this module.
const NAME: &str = "linsolve";
/// GPU acceleration metadata for `linsolve`: advertised precisions, the
/// provider hook used for dispatch, and how results are materialized.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::solve::linsolve")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "linsolve",
    op_kind: GpuOpKind::Custom("solve"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    // Solves never broadcast their operands.
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("linsolve")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    // The provider returns the solution in a freshly allocated handle.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider linsolve hook; WGPU currently supports triangular solves, real F32 TRANSA='T'/'C' variants, a dedicated real F32 POSDEF/Cholesky path, and selected real F32 QR-backed square and rectangular solves, otherwise it gathers to the host solver and re-uploads the result.",
};
/// Build a `RuntimeError` tagged with this builtin's name.
fn builtin_error(message: impl Into<String>) -> RuntimeError {
    build_runtime_error(message).with_builtin(NAME).build()
}
/// Re-wrap an error from a lower layer (dispatcher/gather) so it reports as a
/// `linsolve` failure while preserving the identifier, task id, call stack,
/// and phase context, chaining the original error as the source.
fn map_control_flow(err: RuntimeError) -> RuntimeError {
    let mut builder = build_runtime_error(err.message()).with_builtin(NAME);
    if let Some(identifier) = err.identifier() {
        builder = builder.with_identifier(identifier.to_string());
    }
    if let Some(task_id) = err.context.task_id.clone() {
        builder = builder.with_task_id(task_id);
    }
    if !err.context.call_stack.is_empty() {
        builder = builder.with_call_stack(err.context.call_stack.clone());
    }
    if let Some(phase) = err.context.phase.clone() {
        builder = builder.with_phase(phase);
    }
    builder.with_source(err).build()
}
/// Fusion metadata: `linsolve` is a terminal operation and never fuses with
/// surrounding element-wise kernels.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "linsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Linear solves are terminal operations and do not fuse with surrounding kernels.",
};
#[runtime_builtin(
    name = "linsolve",
    category = "math/linalg/solve",
    summary = "Solve A * X = B with structural hints such as LT, UT, POSDEF, or TRANSA.",
    keywords = "linsolve,linear system,triangular,gpu",
    accel = "linsolve",
    type_resolver(left_divide_type),
    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
)]
async fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
    // Evaluate once, then shape the result according to the caller's
    // requested output count (MATLAB nargout semantics).
    let eval = evaluate_args(lhs, rhs, &rest).await?;
    match crate::output_count::current_output_count() {
        // No explicit output count: return the bare solution.
        None => Ok(eval.solution()),
        Some(0) => Ok(Value::OutputList(Vec::new())),
        Some(1) => Ok(Value::OutputList(vec![eval.solution()])),
        Some(2) => Ok(Value::OutputList(vec![
            eval.solution(),
            eval.reciprocal_condition(),
        ])),
        Some(_) => Err(builtin_error(
            "linsolve currently supports at most two outputs",
        )),
    }
}
/// Core evaluation pipeline: try the GPU provider first, otherwise gather
/// both operands to the host and run the real or complex host solver.
pub async fn evaluate(
    lhs: Value,
    rhs: Value,
    options: SolveOptions,
) -> BuiltinResult<LinsolveEval> {
    if let Some(gpu_eval) = try_gpu_linsolve(&lhs, &rhs, &options).await? {
        return Ok(gpu_eval);
    }
    let host_lhs = crate::dispatcher::gather_if_needed_async(&lhs)
        .await
        .map_err(map_control_flow)?;
    let host_rhs = crate::dispatcher::gather_if_needed_async(&rhs)
        .await
        .map_err(map_control_flow)?;
    match coerce_numeric_pair(host_lhs, host_rhs).await? {
        NumericPair::Real(a, b) => {
            let (solution, rcond) = solve_real(a, b, &options)?;
            Ok(LinsolveEval::new(
                tensor::tensor_into_value(solution),
                Some(rcond),
            ))
        }
        NumericPair::Complex(a, b) => {
            let (solution, rcond) = solve_complex(a, b, &options)?;
            Ok(LinsolveEval::new(
                Value::ComplexTensor(solution),
                Some(rcond),
            ))
        }
    }
}
/// Entry point for acceleration providers that want the host real solver:
/// translates provider-level options and runs `solve_real` on clones.
pub fn linsolve_host_real_for_provider(
    lhs: &Tensor,
    rhs: &Tensor,
    options: &ProviderLinsolveOptions,
) -> BuiltinResult<(Tensor, f64)> {
    solve_real(lhs.clone(), rhs.clone(), &SolveOptions::from(options))
}
/// Result of a `linsolve` evaluation: the solution plus an optional
/// reciprocal-condition estimate used for the second output.
#[derive(Clone)]
pub struct LinsolveEval {
    // Solved X; may be a host value or a GPU tensor handle.
    solution: Value,
    // Reciprocal condition estimate; `None` when it was not computed.
    rcond: Option<f64>,
}
impl LinsolveEval {
    /// Bundle a solution with an optional reciprocal-condition estimate.
    fn new(solution: Value, rcond: Option<f64>) -> Self {
        Self { solution, rcond }
    }

    /// The solved `X`, cloned so the eval can be queried repeatedly.
    pub fn solution(&self) -> Value {
        self.solution.clone()
    }

    /// Reciprocal condition as a numeric value; NaN when it was not computed.
    pub fn reciprocal_condition(&self) -> Value {
        Value::Num(self.rcond.unwrap_or(f64::NAN))
    }
}
/// Parsed MATLAB-style `linsolve` options struct.
#[derive(Clone, Default)]
pub struct SolveOptions {
    // LT: coefficient matrix is lower triangular.
    lower: bool,
    // UT: coefficient matrix is upper triangular.
    upper: bool,
    // RECT: coefficient matrix may be rectangular.
    rectangular: bool,
    // TRANSA != 'N': solve with the (conjugate) transpose of A.
    transposed: bool,
    // TRANSA == 'C': conjugate in addition to transposing.
    conjugate: bool,
    // SYM: coefficient matrix is symmetric.
    symmetric: bool,
    // POSDEF: coefficient matrix is symmetric positive definite.
    posdef: bool,
    // RCOND: fail when the estimated reciprocal condition drops below this.
    rcond: Option<f64>,
}
impl From<&SolveOptions> for ProviderLinsolveOptions {
    /// Translate parsed host options into the provider ABI. `need_rcond`
    /// starts false; the GPU dispatch path flips it when the caller can
    /// actually observe the rcond output.
    fn from(opts: &SolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            need_rcond: false,
            rcond: opts.rcond,
        }
    }
}
impl From<&ProviderLinsolveOptions> for SolveOptions {
    /// Recover host-side options from the provider ABI. `need_rcond` has no
    /// host counterpart and is intentionally dropped.
    fn from(opts: &ProviderLinsolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            rcond: opts.rcond,
        }
    }
}
fn options_from_rest(rest: &[Value]) -> BuiltinResult<SolveOptions> {
match rest.len() {
0 => Ok(SolveOptions::default()),
1 => parse_options(&rest[0]),
_ => Err(builtin_error("linsolve: too many input arguments")),
}
}
/// Parse trailing options, then run the shared evaluation pipeline.
pub async fn evaluate_args(lhs: Value, rhs: Value, rest: &[Value]) -> BuiltinResult<LinsolveEval> {
    evaluate(lhs, rhs, options_from_rest(rest)?).await
}
/// Attempt to run the solve on the registered acceleration provider.
///
/// Returns `Ok(None)` whenever the GPU path is not applicable (no provider,
/// complex inputs, scalar operands, more than two requested outputs) or the
/// provider declines; the caller then falls back to the host solver. Any
/// tensors this function uploads are freed before returning on every path.
async fn try_gpu_linsolve(
    lhs: &Value,
    rhs: &Value,
    options: &SolveOptions,
) -> BuiltinResult<Option<LinsolveEval>> {
    if matches!(crate::output_count::current_output_count(), Some(n) if n > 2) {
        return Ok(None);
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };
    // Complex data is host-only today.
    if contains_complex(lhs) || contains_complex(rhs) {
        return Ok(None);
    }
    let mut lhs_operand = match prepare_gpu_operand(lhs, provider)? {
        Some(op) => op,
        None => return Ok(None),
    };
    let mut rhs_operand = match prepare_gpu_operand(rhs, provider)? {
        Some(op) => op,
        None => {
            // The LHS may already have been uploaded; free it before bailing.
            release_operand(provider, &mut lhs_operand);
            return Ok(None);
        }
    };
    if is_scalar_handle(lhs_operand.handle()) || is_scalar_handle(rhs_operand.handle()) {
        release_operand(provider, &mut lhs_operand);
        release_operand(provider, &mut rhs_operand);
        return Ok(None);
    }
    let mut provider_opts: ProviderLinsolveOptions = options.into();
    // Only ask the provider for rcond when the caller can observe it (two
    // outputs requested) or an RCOND threshold must be enforced.
    provider_opts.need_rcond =
        matches!(crate::output_count::current_output_count(), Some(2)) || options.rcond.is_some();
    // Treat provider failure as "not supported" and fall back to the host.
    let result = provider
        .linsolve(lhs_operand.handle(), rhs_operand.handle(), &provider_opts)
        .await
        .ok();
    release_operand(provider, &mut lhs_operand);
    release_operand(provider, &mut rhs_operand);
    if let Some(ProviderLinsolveResult {
        solution,
        reciprocal_condition,
    }) = result
    {
        let eval = LinsolveEval::new(Value::GpuTensor(solution), Some(reciprocal_condition));
        return Ok(Some(eval));
    }
    Ok(None)
}
/// Parse a MATLAB-style options struct (`struct('LT', true, ...)`) into
/// `SolveOptions`. Unknown fields and conflicting LT/UT hints are errors.
fn parse_options(value: &Value) -> BuiltinResult<SolveOptions> {
    let struct_val = match value {
        Value::Struct(s) => s,
        other => {
            return Err(builtin_error(format!(
                "linsolve: opts must be a struct, got {other:?}"
            )))
        }
    };
    let mut parsed = SolveOptions::default();
    for (key, field) in &struct_val.fields {
        // Field names are case-insensitive, as in MATLAB.
        match key.to_ascii_uppercase().as_str() {
            "LT" => parsed.lower = parse_bool_field("LT", field)?,
            "UT" => parsed.upper = parse_bool_field("UT", field)?,
            "RECT" => parsed.rectangular = parse_bool_field("RECT", field)?,
            "SYM" => parsed.symmetric = parse_bool_field("SYM", field)?,
            "POSDEF" => parsed.posdef = parse_bool_field("POSDEF", field)?,
            "TRANSA" => {
                let mode = parse_transa(field)?;
                parsed.transposed = mode != TransposeMode::None;
                parsed.conjugate = mode == TransposeMode::Conjugate;
            }
            "RCOND" => {
                let threshold = parse_scalar_f64("RCOND", field)?;
                if threshold < 0.0 {
                    return Err(builtin_error("linsolve: RCOND must be non-negative"));
                }
                parsed.rcond = Some(threshold);
            }
            other => return Err(builtin_error(format!("linsolve: unknown option '{other}'"))),
        }
    }
    if parsed.lower && parsed.upper {
        return Err(builtin_error("linsolve: LT and UT are mutually exclusive."));
    }
    Ok(parsed)
}
/// Interpret an option value as a logical flag: logical scalars and numeric
/// scalars are accepted, with non-zero meaning true.
fn parse_bool_field(name: &str, value: &Value) -> BuiltinResult<bool> {
    let truthy = match value {
        Value::Bool(b) => *b,
        Value::Int(i) => !i.is_zero(),
        Value::Num(n) => *n != 0.0,
        Value::Tensor(t) if tensor::is_scalar_tensor(t) => t.data[0] != 0.0,
        Value::LogicalArray(arr) if arr.len() == 1 => arr.data[0] != 0,
        other => {
            return Err(builtin_error(format!(
                "linsolve: option '{name}' must be logical or numeric, got {other:?}"
            )))
        }
    };
    Ok(truthy)
}
/// Interpret an option value as a scalar double (plain number, integer, or
/// 1x1 numeric tensor).
fn parse_scalar_f64(name: &str, value: &Value) -> BuiltinResult<f64> {
    let scalar = match value {
        Value::Num(n) => *n,
        Value::Int(i) => i.to_f64(),
        Value::Tensor(t) if tensor::is_scalar_tensor(t) => t.data[0],
        other => {
            return Err(builtin_error(format!(
                "linsolve: option '{name}' must be a scalar numeric value, got {other:?}"
            )))
        }
    };
    Ok(scalar)
}
/// Parsed TRANSA option value.
#[derive(Copy, Clone, PartialEq, Eq)]
enum TransposeMode {
    // TRANSA == 'N': solve with A as given.
    None,
    // TRANSA == 'T': solve with A transposed.
    Transpose,
    // TRANSA == 'C': solve with A conjugate-transposed.
    Conjugate,
}
/// Parse the TRANSA option: a one-character string 'N', 'T', or 'C'
/// (case-insensitive, surrounding whitespace ignored).
fn parse_transa(value: &Value) -> BuiltinResult<TransposeMode> {
    let text = tensor::value_to_string(value).ok_or_else(|| {
        builtin_error("linsolve: TRANSA must be a character vector or string scalar")
    })?;
    if text.is_empty() {
        return Err(builtin_error("linsolve: TRANSA cannot be empty"));
    }
    let normalized = text.trim().to_ascii_uppercase();
    match normalized.as_str() {
        "N" => Ok(TransposeMode::None),
        "T" => Ok(TransposeMode::Transpose),
        "C" => Ok(TransposeMode::Conjugate),
        other => Err(builtin_error(format!(
            "linsolve: TRANSA must be 'N', 'T', or 'C', got '{other}'"
        ))),
    }
}
/// A single operand coerced to a concrete numeric tensor.
enum NumericInput {
    Real(Tensor),
    Complex(ComplexTensor),
}
/// Both operands after mutual promotion: either both real or both complex.
enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}
/// Coerce each side independently, then promote the real side to complex
/// whenever the two operands disagree.
async fn coerce_numeric_pair(lhs: Value, rhs: Value) -> BuiltinResult<NumericPair> {
    let pair = (coerce_numeric(lhs).await?, coerce_numeric(rhs).await?);
    match pair {
        (NumericInput::Real(a), NumericInput::Real(b)) => Ok(NumericPair::Real(a, b)),
        (NumericInput::Complex(a), NumericInput::Complex(b)) => Ok(NumericPair::Complex(a, b)),
        (NumericInput::Complex(a), NumericInput::Real(b)) => {
            Ok(NumericPair::Complex(a, promote_real_tensor(&b)?))
        }
        (NumericInput::Real(a), NumericInput::Complex(b)) => {
            Ok(NumericPair::Complex(promote_real_tensor(&a)?, b))
        }
    }
}
async fn coerce_numeric(value: Value) -> BuiltinResult<NumericInput> {
match value {
Value::Tensor(tensor) => {
ensure_matrix_shape(NAME, &tensor.shape)?;
Ok(NumericInput::Real(tensor))
}
Value::LogicalArray(logical) => {
let tensor = tensor::logical_to_tensor(&logical).map_err(builtin_error)?;
ensure_matrix_shape(NAME, &tensor.shape)?;
Ok(NumericInput::Real(tensor))
}
Value::Num(n) => {
let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(builtin_error)?;
Ok(NumericInput::Real(tensor))
}
Value::Int(i) => {
let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(builtin_error)?;
Ok(NumericInput::Real(tensor))
}
Value::Bool(b) => {
let tensor =
Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1]).map_err(builtin_error)?;
Ok(NumericInput::Real(tensor))
}
Value::Complex(re, im) => {
let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(builtin_error)?;
Ok(NumericInput::Complex(tensor))
}
Value::ComplexTensor(ct) => {
ensure_matrix_shape(NAME, &ct.shape)?;
Ok(NumericInput::Complex(ct))
}
Value::GpuTensor(handle) => {
let tensor = gpu_helpers::gather_tensor_async(&handle)
.await
.map_err(map_control_flow)?;
ensure_matrix_shape(NAME, &tensor.shape)?;
Ok(NumericInput::Real(tensor))
}
other => Err(builtin_error(format!(
"{NAME}: unsupported input type {:?}; convert to numeric values first",
other
))),
}
}
/// True when the value is inherently complex (the GPU path is real-only).
fn contains_complex(value: &Value) -> bool {
    matches!(value, Value::Complex(_, _) | Value::ComplexTensor(_))
}
/// True when the device tensor holds a single scalar value.
fn is_scalar_handle(handle: &GpuTensorHandle) -> bool {
    crate::builtins::common::shape::is_scalar_shape(&handle.shape)
}
/// A device operand plus a flag recording whether this module uploaded it
/// (and therefore must free it via `release_operand`).
struct PreparedOperand {
    handle: GpuTensorHandle,
    // True when we uploaded the tensor and own the device allocation.
    owned: bool,
}
impl PreparedOperand {
    /// Wrap a caller-owned handle; `release_operand` will not free it.
    fn borrowed(existing: &GpuTensorHandle) -> Self {
        Self {
            handle: existing.clone(),
            owned: false,
        }
    }

    /// Wrap a handle uploaded by us; `release_operand` frees it afterwards.
    fn owned(uploaded: GpuTensorHandle) -> Self {
        Self {
            handle: uploaded,
            owned: true,
        }
    }

    /// Borrow the underlying device handle.
    fn handle(&self) -> &GpuTensorHandle {
        &self.handle
    }
}
/// Stage one operand for the provider path. Scalars and unsupported value
/// kinds yield `None`, which makes the caller fall back to the host solver.
fn prepare_gpu_operand(
    value: &Value,
    provider: &'static dyn AccelProvider,
) -> BuiltinResult<Option<PreparedOperand>> {
    let prepared = match value {
        // Already resident on the device: just borrow the handle.
        Value::GpuTensor(handle) if !is_scalar_handle(handle) => {
            Some(PreparedOperand::borrowed(handle))
        }
        // Host tensor: upload and remember to free it afterwards.
        Value::Tensor(t) if !tensor::is_scalar_tensor(t) => {
            Some(PreparedOperand::owned(upload_tensor(provider, t)?))
        }
        // Logical array: widen to doubles, then upload.
        Value::LogicalArray(logical) if logical.data.len() != 1 => {
            let as_real = tensor::logical_to_tensor(logical).map_err(builtin_error)?;
            Some(PreparedOperand::owned(upload_tensor(provider, &as_real)?))
        }
        _ => None,
    };
    Ok(prepared)
}
/// Upload a host tensor to the provider; the borrowed view is copied to
/// device memory by the provider.
fn upload_tensor(
    provider: &'static dyn AccelProvider,
    tensor: &Tensor,
) -> BuiltinResult<GpuTensorHandle> {
    provider
        .upload(&HostTensorView {
            data: &tensor.data,
            shape: &tensor.shape,
        })
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
}
/// Free a device handle only when we uploaded it ourselves; clearing the
/// `owned` flag makes repeated calls harmless.
fn release_operand(provider: &'static dyn AccelProvider, operand: &mut PreparedOperand) {
    if !operand.owned {
        return;
    }
    operand.owned = false;
    // Free failures are intentionally ignored: best-effort cleanup.
    let _ = provider.free(&operand.handle);
}
/// Host real-valued solve honoring the structural hints in `options`.
///
/// TRANSA is applied eagerly by materializing the (conjugate) transpose; the
/// LT/UT flags are then swapped because transposition flips triangularity.
/// Returns the solution plus a reciprocal-condition estimate.
fn solve_real(lhs: Tensor, rhs: Tensor, options: &SolveOptions) -> BuiltinResult<(Tensor, f64)> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;
    if options.transposed {
        lhs_effective = transpose_tensor(&lhs_effective);
        if options.conjugate {
            // No-op for real data; kept for parity with the complex path.
            conjugate_in_place(&mut lhs_effective);
        }
        if lower || upper {
            // A transposed lower-triangular matrix is upper-triangular and
            // vice versa, so the hints must follow the data.
            std::mem::swap(&mut lower, &mut upper);
        }
    }
    rhs_effective = normalize_rhs_tensor(rhs_effective, lhs_effective.rows())?;
    if lower {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = forward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }
    if upper {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = backward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }
    // SYM/POSDEF/RECT hints currently fall through to the general SVD solver.
    let (solution, rcond) = solve_general_real(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}
/// Host complex-valued solve honoring the structural hints in `options`.
///
/// Mirrors `solve_real`: TRANSA materializes the (conjugate) transpose and
/// the LT/UT flags are swapped because transposition flips triangularity.
fn solve_complex(
    lhs: ComplexTensor,
    rhs: ComplexTensor,
    options: &SolveOptions,
) -> BuiltinResult<(ComplexTensor, f64)> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;
    if options.transposed {
        lhs_effective = transpose_complex(&lhs_effective);
        if options.conjugate {
            conjugate_complex_in_place(&mut lhs_effective);
        }
        if lower || upper {
            // Transposing swaps lower- and upper-triangular structure.
            std::mem::swap(&mut lower, &mut upper);
        }
    }
    rhs_effective = normalize_rhs_complex(rhs_effective, lhs_effective.rows)?;
    if lower {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = forward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }
    if upper {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = backward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }
    // SYM/POSDEF/RECT hints currently fall through to the general SVD solver.
    let (solution, rcond) = solve_general_complex(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}
/// Solve `L * X = B` for a lower-triangular `L` by forward substitution.
///
/// Tracks the smallest/largest diagonal magnitudes to derive a cheap
/// reciprocal-condition estimate; errors if a diagonal entry is exactly zero.
fn forward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
    let n = lhs.rows();
    // Empty-system guard: `rhs.data.len() / n` below would divide by zero and
    // panic for a 0x0 coefficient matrix. The empty RHS is its own solution,
    // and the system is vacuously well-conditioned (MATLAB: rcond([]) == Inf).
    if n == 0 {
        let tensor = Tensor::new(rhs.data.clone(), rhs.shape.clone())
            .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
        return Ok((tensor, f64::INFINITY));
    }
    let nrhs = rhs.data.len() / n;
    let mut solution = rhs.data.clone();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for col in 0..nrhs {
        // Walk rows top-down; row i depends only on rows 0..i already solved.
        for i in 0..n {
            let diag = lhs.data[i + i * n];
            let diag_abs = diag.abs();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err(builtin_error(
                    "linsolve: matrix is singular to working precision.",
                ));
            }
            let mut accum = 0.0;
            for j in 0..i {
                accum += lhs.data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }
    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = Tensor::new(solution, rhs.shape.clone())
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
    Ok((tensor, rcond))
}
/// Solve `U * X = B` for an upper-triangular `U` by backward substitution.
///
/// Tracks diagonal magnitudes for a cheap reciprocal-condition estimate and
/// errors if any diagonal entry is exactly zero.
fn backward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
    let n = lhs.rows();
    // Empty-system guard: `rhs.data.len() / n` below would divide by zero and
    // panic for a 0x0 coefficient matrix. The empty RHS is its own solution,
    // and the system is vacuously well-conditioned (MATLAB: rcond([]) == Inf).
    if n == 0 {
        let tensor = Tensor::new(rhs.data.clone(), rhs.shape.clone())
            .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
        return Ok((tensor, f64::INFINITY));
    }
    let nrhs = rhs.data.len() / n;
    let mut solution = rhs.data.clone();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for col in 0..nrhs {
        // Walk rows bottom-up; row i depends on rows i+1..n already solved.
        for row_rev in 0..n {
            let i = n - 1 - row_rev;
            let diag = lhs.data[i + i * n];
            let diag_abs = diag.abs();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err(builtin_error(
                    "linsolve: matrix is singular to working precision.",
                ));
            }
            let mut accum = 0.0;
            for j in (i + 1)..n {
                accum += lhs.data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }
    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = Tensor::new(solution, rhs.shape.clone())
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
    Ok((tensor, rcond))
}
/// Solve `L * X = B` for a lower-triangular complex `L` by forward
/// substitution, mirroring `forward_substitution_real`.
fn forward_substitution_complex(
    lhs: &ComplexTensor,
    rhs: &ComplexTensor,
) -> BuiltinResult<(ComplexTensor, f64)> {
    let n = lhs.rows;
    // Empty-system guard: `rhs.data.len() / n` below would divide by zero and
    // panic for a 0x0 coefficient matrix (MATLAB: rcond([]) == Inf).
    if n == 0 {
        let tensor = ComplexTensor::new(rhs.data.clone(), rhs.shape.clone())
            .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
        return Ok((tensor, f64::INFINITY));
    }
    let nrhs = rhs.data.len() / n;
    let lhs_data: Vec<Complex64> = lhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut solution: Vec<Complex64> = rhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for col in 0..nrhs {
        for i in 0..n {
            let diag = lhs_data[i + i * n];
            // The complex magnitude plays the role |.| has for reals.
            let diag_abs = diag.norm();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err(builtin_error(
                    "linsolve: matrix is singular to working precision.",
                ));
            }
            let mut accum = Complex64::new(0.0, 0.0);
            for j in 0..i {
                accum += lhs_data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }
    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = ComplexTensor::new(
        solution.iter().map(|c| (c.re, c.im)).collect(),
        rhs.shape.clone(),
    )
    .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
    Ok((tensor, rcond))
}
/// Solve `U * X = B` for an upper-triangular complex `U` by backward
/// substitution, mirroring `backward_substitution_real`.
fn backward_substitution_complex(
    lhs: &ComplexTensor,
    rhs: &ComplexTensor,
) -> BuiltinResult<(ComplexTensor, f64)> {
    let n = lhs.rows;
    // Empty-system guard: `rhs.data.len() / n` below would divide by zero and
    // panic for a 0x0 coefficient matrix (MATLAB: rcond([]) == Inf).
    if n == 0 {
        let tensor = ComplexTensor::new(rhs.data.clone(), rhs.shape.clone())
            .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
        return Ok((tensor, f64::INFINITY));
    }
    let nrhs = rhs.data.len() / n;
    let lhs_data: Vec<Complex64> = lhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut solution: Vec<Complex64> = rhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;
    for col in 0..nrhs {
        // Walk rows bottom-up; row i depends on rows i+1..n already solved.
        for row_rev in 0..n {
            let i = n - 1 - row_rev;
            let diag = lhs_data[i + i * n];
            let diag_abs = diag.norm();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err(builtin_error(
                    "linsolve: matrix is singular to working precision.",
                ));
            }
            let mut accum = Complex64::new(0.0, 0.0);
            for j in (i + 1)..n {
                accum += lhs_data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }
    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = ComplexTensor::new(
        solution.iter().map(|c| (c.re, c.im)).collect(),
        rhs.shape.clone(),
    )
    .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
    Ok((tensor, rcond))
}
/// Solve a general (possibly rectangular) real system via SVD, returning the
/// least-squares/minimum-norm solution and a reciprocal-condition estimate
/// derived from the singular values.
fn solve_general_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
    let a = DMatrix::from_column_slice(lhs.rows(), lhs.cols(), &lhs.data);
    let b = DMatrix::from_column_slice(rhs.rows(), rhs.cols(), &rhs.data);
    // `a` is only needed by the decomposition: move it instead of cloning.
    let svd = SVD::new(a, true, true);
    let rcond = singular_value_rcond(svd.singular_values.as_slice());
    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows(), lhs.cols());
    let solution = svd
        .solve(&b, tol)
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
    let tensor = matrix_real_to_tensor(solution)?;
    Ok((tensor, rcond))
}
/// Solve a general (possibly rectangular) complex system via SVD, returning
/// the least-squares/minimum-norm solution and a reciprocal-condition
/// estimate derived from the singular values.
fn solve_general_complex(
    lhs: &ComplexTensor,
    rhs: &ComplexTensor,
) -> BuiltinResult<(ComplexTensor, f64)> {
    let a_data: Vec<Complex64> = lhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let b_data: Vec<Complex64> = rhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let a = DMatrix::from_column_slice(lhs.rows, lhs.cols, &a_data);
    let b = DMatrix::from_column_slice(rhs.rows, rhs.cols, &b_data);
    // `a` is only needed by the decomposition: move it instead of cloning.
    let svd = SVD::new(a, true, true);
    let rcond = singular_value_rcond(svd.singular_values.as_slice());
    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows, lhs.cols);
    let solution = svd
        .solve(&b, tol)
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
    let tensor = matrix_complex_to_tensor(solution)?;
    Ok((tensor, rcond))
}
/// Conform the right-hand side to the coefficient matrix's row count,
/// promoting a bare length-n vector to an n x 1 column.
fn normalize_rhs_tensor(rhs: Tensor, expected_rows: usize) -> BuiltinResult<Tensor> {
    if rhs.rows() == expected_rows {
        // Already conformant.
        Ok(rhs)
    } else if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
        // Reshape a rank-1 vector into a column.
        Tensor::new(rhs.data, vec![expected_rows, 1])
            .map_err(|e| builtin_error(format!("{NAME}: {e}")))
    } else if rhs.data.is_empty() && expected_rows == 0 {
        // Empty system accepts an empty right-hand side as-is.
        Ok(rhs)
    } else {
        Err(builtin_error("Matrix dimensions must agree."))
    }
}
/// Complex counterpart of `normalize_rhs_tensor`: conform the RHS to the
/// coefficient matrix's row count, promoting a bare vector to a column.
fn normalize_rhs_complex(rhs: ComplexTensor, expected_rows: usize) -> BuiltinResult<ComplexTensor> {
    if rhs.rows == expected_rows {
        // Already conformant.
        Ok(rhs)
    } else if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
        // Reshape a rank-1 vector into a column.
        ComplexTensor::new(rhs.data, vec![expected_rows, 1])
            .map_err(|e| builtin_error(format!("{NAME}: {e}")))
    } else if rhs.data.is_empty() && expected_rows == 0 {
        // Empty system accepts an empty right-hand side as-is.
        Ok(rhs)
    } else {
        Err(builtin_error("Matrix dimensions must agree."))
    }
}
/// Fail the solve when the caller supplied an RCOND threshold and the
/// estimated reciprocal condition falls below it.
fn enforce_rcond(options: &SolveOptions, rcond: f64) -> BuiltinResult<()> {
    match options.rcond {
        Some(threshold) if rcond < threshold => Err(builtin_error(
            "linsolve: matrix is singular to working precision.",
        )),
        _ => Ok(()),
    }
}
/// Tolerance for the SVD pseudo-inverse: eps * max(rows, cols) * largest
/// singular value, with the singular value clamped to at least 1.
fn compute_svd_tolerance(singular_values: &[f64], rows: usize, cols: usize) -> f64 {
    let largest = singular_values
        .iter()
        .fold(0.0_f64, |acc, &sv| acc.max(sv.abs()));
    f64::EPSILON * (rows.max(cols) as f64) * largest.max(1.0)
}
/// Convert a nalgebra matrix back into a runtime tensor; both are
/// column-major, so the slice copies straight through.
fn matrix_real_to_tensor(matrix: DMatrix<f64>) -> BuiltinResult<Tensor> {
    let shape = vec![matrix.nrows(), matrix.ncols()];
    Tensor::new(matrix.as_slice().to_vec(), shape)
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
}
/// Convert a complex nalgebra matrix back into a runtime tensor by
/// flattening the column-major storage into (re, im) pairs.
fn matrix_complex_to_tensor(matrix: DMatrix<Complex64>) -> BuiltinResult<ComplexTensor> {
    let shape = vec![matrix.nrows(), matrix.ncols()];
    let pairs: Vec<(f64, f64)> = matrix.as_slice().iter().map(|c| (c.re, c.im)).collect();
    ComplexTensor::new(pairs, shape).map_err(|e| builtin_error(format!("{NAME}: {e}")))
}
/// Widen a real tensor to complex by pairing each entry with a zero
/// imaginary part; the shape is preserved.
fn promote_real_tensor(tensor: &Tensor) -> BuiltinResult<ComplexTensor> {
    let widened: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
    ComplexTensor::new(widened, tensor.shape.clone())
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
}
/// Reject inputs that are not effectively 2-D (any non-singleton dimension
/// past the second).
fn ensure_matrix_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
    if !is_effectively_matrix(shape) {
        return Err(builtin_error(format!(
            "{name}: inputs must be 2-D matrices or vectors"
        )));
    }
    Ok(())
}
/// True when a shape is at most 2-D, or every dimension past the second is a
/// singleton (so the value behaves like a matrix).
fn is_effectively_matrix(shape: &[usize]) -> bool {
    shape.len() <= 2 || shape[2..].iter().all(|&dim| dim == 1)
}
/// Triangular substitution assumes an n x n coefficient matrix.
fn ensure_square(rows: usize, cols: usize) -> BuiltinResult<()> {
    if rows != cols {
        return Err(builtin_error(
            "linsolve: triangular solves require a square coefficient matrix.",
        ));
    }
    Ok(())
}
/// Materialize the transpose of a column-major matrix: element (r, c) at
/// `r + c * rows` moves to `c + r * cols` in the `cols x rows` result.
fn transpose_tensor(tensor: &Tensor) -> Tensor {
    let (rows, cols) = (tensor.rows(), tensor.cols());
    let mut transposed = Vec::with_capacity(tensor.data.len());
    // Pushing row-by-row of the source yields column-major order of the
    // transposed matrix.
    for r in 0..rows {
        for c in 0..cols {
            transposed.push(tensor.data[r + c * rows]);
        }
    }
    Tensor::new(transposed, vec![cols, rows]).expect("transpose_tensor valid")
}
/// Complex counterpart of `transpose_tensor`; does NOT conjugate — callers
/// apply `conjugate_complex_in_place` separately for TRANSA='C'.
fn transpose_complex(tensor: &ComplexTensor) -> ComplexTensor {
    let (rows, cols) = (tensor.rows, tensor.cols);
    let mut transposed = Vec::with_capacity(tensor.data.len());
    // Row-by-row of the source is column-major order of the transpose.
    for r in 0..rows {
        for c in 0..cols {
            transposed.push(tensor.data[r + c * rows]);
        }
    }
    ComplexTensor::new(transposed, vec![cols, rows]).expect("transpose_complex valid")
}
/// Intentionally a no-op: conjugation does not change real data. It exists so
/// the real TRANSA='C' path stays structurally parallel to the complex one.
fn conjugate_in_place(_tensor: &mut Tensor) {
}
/// Negate the imaginary component of every (re, im) entry in place.
fn conjugate_complex_in_place(tensor: &mut ComplexTensor) {
    for (_, im) in tensor.data.iter_mut() {
        *im = -*im;
    }
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use futures::executor::block_on;
use runmat_accelerate_api::HostTensorView;
use runmat_builtins::{CharArray, ResolveContext, StructValue, Type};
// Identity shim: named so failing call sites read as unwrapping an error.
fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
    err
}
// Assert two doubles agree to an absolute tolerance of 1e-7.
fn approx_eq(actual: f64, expected: f64) {
    assert!((actual - expected).abs() < 1e-7);
}
// Synchronous wrapper over `super::evaluate_args` for tests.
fn evaluate_args(a: Value, b: Value, rest: &[Value]) -> Result<LinsolveEval, RuntimeError> {
    block_on(super::evaluate_args(a, b, rest))
}
#[test]
fn linsolve_type_uses_rhs_columns() {
    // The inferred output shape takes rows from A and columns from B.
    let out = left_divide_type(
        &[
            Type::Tensor {
                shape: Some(vec![Some(2), Some(2)]),
            },
            Type::Tensor {
                shape: Some(vec![Some(2), Some(3)]),
            },
        ],
        &ResolveContext::new(Vec::new()),
    );
    assert_eq!(
        out,
        Type::Tensor {
            shape: Some(vec![Some(2), Some(3)])
        }
    );
}
use crate::builtins::common::test_support;
use runmat_accelerate_api::ProviderTelemetry;
// Synchronous wrapper over the async builtin entry point for tests.
fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
    block_on(super::linsolve_builtin(lhs, rhs, rest))
}
// Synchronous wrapper over `super::evaluate` for tests.
fn evaluate(lhs: Value, rhs: Value, options: SolveOptions) -> BuiltinResult<LinsolveEval> {
    block_on(super::evaluate(lhs, rhs, options))
}
// Count recorded for the first telemetry entry matching `reason`, or 0.
fn fallback_count(telemetry: &ProviderTelemetry, reason: &str) -> u64 {
    telemetry
        .solve_fallbacks
        .iter()
        .find_map(|entry| (entry.reason == reason).then(|| entry.count))
        .unwrap_or(0)
}
// Number of times the provider launched the named kernel (wgpu builds only).
#[cfg(feature = "wgpu")]
fn kernel_launch_count(telemetry: &ProviderTelemetry, kernel: &str) -> usize {
    telemetry
        .kernel_launches
        .iter()
        .filter(|entry| entry.kernel == kernel)
        .count()
}
// Drop both thread-local and global providers so tests exercise the host path.
fn clear_accel_provider_state() {
    runmat_accelerate_api::set_thread_provider(None);
    runmat_accelerate_api::clear_provider();
}
// Run the host real solver with provider-level options, panicking on error.
fn host_linsolve_real(
    a: &Tensor,
    b: &Tensor,
    options: ProviderLinsolveOptions,
) -> (Tensor, f64) {
    super::linsolve_host_real_for_provider(a, b, &options).expect("host linsolve")
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_basic_square() {
    // Plain square solve on the host path (provider cleared first).
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    let a = Tensor::new(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
    let result =
        linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new()).expect("linsolve");
    let t = test_support::gather(result).expect("gather");
    assert_eq!(t.shape, vec![2, 1]);
    // A = [2 1; 1 2], b = [4; 5] => x = [1; 2].
    approx_eq(t.data[0], 1.0);
    approx_eq(t.data[1], 2.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_lower_triangular_hint() {
    // The LT hint routes through the forward-substitution path.
    let a = Tensor::new(
        vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
    let mut opts = StructValue::new();
    opts.fields.insert("LT".to_string(), Value::Bool(true));
    let result = linsolve_builtin(
        Value::Tensor(a),
        Value::Tensor(b),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let tensor = test_support::gather(result).expect("gather");
    assert_eq!(tensor.shape, vec![3, 1]);
    approx_eq(tensor.data[0], 3.0);
    approx_eq(tensor.data[1], 2.0);
    approx_eq(tensor.data[2], 1.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_transposed_triangular_hint() {
    // LT + TRANSA='T' must match solving with the explicitly transposed A.
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    let a = Tensor::new(
        vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
    let mut opts = StructValue::new();
    opts.fields.insert("LT".to_string(), Value::Bool(true));
    opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let result = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let tensor = test_support::gather(result).expect("gather");
    assert_eq!(tensor.shape, vec![3, 1]);
    // Reference: solve A' x = b directly via the host solver.
    let a_transposed = transpose_tensor(&a);
    let (expected_tensor, _) =
        host_linsolve_real(&a_transposed, &b, ProviderLinsolveOptions::default());
    for (actual, expected) in tensor.data.iter().zip(expected_tensor.data.iter()) {
        approx_eq(*actual, *expected);
    }
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_complex_inputs_match_residual() {
    // Verify the complex solution by checking the residual A*X - B is ~0.
    let a = ComplexTensor::new(
        vec![(2.0, 1.0), (-1.0, 0.0), (1.0, -2.0), (3.0, -2.0)],
        vec![2, 2],
    )
    .unwrap();
    let b = ComplexTensor::new(vec![(1.0, 0.0), (4.0, 1.0)], vec![2, 1]).unwrap();
    let result = linsolve_builtin(
        Value::ComplexTensor(a.clone()),
        Value::ComplexTensor(b.clone()),
        Vec::new(),
    )
    .expect("linsolve");
    let Value::ComplexTensor(out) = result else {
        panic!("expected complex tensor result");
    };
    let mat_a: Vec<Complex64> = a
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mat_b: Vec<Complex64> = b
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mat_x: Vec<Complex64> = out
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let a_mat = DMatrix::from_column_slice(a.rows, a.cols, &mat_a);
    let b_mat = DMatrix::from_column_slice(b.rows, b.cols, &mat_b);
    let x_mat = DMatrix::from_column_slice(out.rows, out.cols, &mat_x);
    let residual = a_mat * x_mat - b_mat;
    assert!(residual.norm() < 1e-10, "residual={}", residual.norm());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_complex_conjugate_transpose_matches_explicit_reference() {
    // TRANSA='C' must agree with solving against an explicitly built A^H.
    let a = ComplexTensor::new(
        vec![(2.0, 1.0), (0.0, -1.0), (1.0, 2.0), (3.0, 0.5)],
        vec![2, 2],
    )
    .unwrap();
    let b = ComplexTensor::new(vec![(1.0, -1.0), (2.0, 0.5)], vec![2, 1]).unwrap();
    let mut opts = StructValue::new();
    opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("C")),
    );
    let result = linsolve_builtin(
        Value::ComplexTensor(a.clone()),
        Value::ComplexTensor(b.clone()),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let Value::ComplexTensor(out) = result else {
        panic!("expected complex tensor result");
    };
    // Build the conjugate transpose by hand and solve without hints.
    let mut a_conj_t = transpose_complex(&a);
    conjugate_complex_in_place(&mut a_conj_t);
    let reference = evaluate(
        Value::ComplexTensor(a_conj_t),
        Value::ComplexTensor(b.clone()),
        SolveOptions::default(),
    )
    .expect("reference");
    let Value::ComplexTensor(expected) = reference.solution() else {
        panic!("expected complex tensor reference");
    };
    assert_eq!(out.shape, expected.shape);
    for ((out_re, out_im), (exp_re, exp_im)) in out.data.iter().zip(expected.data.iter()) {
        assert!(
            (out_re - exp_re).abs() < 1e-10,
            "out_re={out_re} exp_re={exp_re}"
        );
        assert!(
            (out_im - exp_im).abs() < 1e-10,
            "out_im={out_im} exp_im={exp_im}"
        );
    }
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_rcond_enforced() {
    // A nearly-singular matrix must trip the user-supplied RCOND threshold.
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    let a = Tensor::new(vec![1.0, 1.0, 1.0, 1.0 + 1e-12], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![2.0, 2.0 + 1e-12], vec![2, 1]).unwrap();
    let mut opts = StructValue::new();
    opts.fields.insert("RCOND".to_string(), Value::Num(1e-3));
    let err = unwrap_error(
        linsolve_builtin(
            Value::Tensor(a),
            Value::Tensor(b),
            vec![Value::Struct(opts)],
        )
        .expect_err("singular matrix must fail"),
    );
    assert!(
        err.message().contains("singular to working precision"),
        "unexpected error message: {err}"
    );
}
// Solving the identity system must return the RHS unchanged and report a
// reciprocal condition number of exactly 1, regardless of whether the active
// provider keeps results on the host or on the device.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_recovers_rcond_output() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
    let eval = evaluate_args(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
        .expect("evaluate");
    // The solution may be host- or device-resident; gather device handles.
    let solution_tensor = match eval.solution() {
        Value::Tensor(sol) => sol.clone(),
        Value::GpuTensor(handle) => {
            test_support::gather(Value::GpuTensor(handle.clone())).expect("gather solution")
        }
        other => panic!("unexpected solution value {other:?}"),
    };
    assert_eq!(solution_tensor.shape, vec![2, 1]);
    approx_eq(solution_tensor.data[0], 1.0);
    approx_eq(solution_tensor.data[1], 2.0);
    // rcond may likewise come back as a scalar or as a device tensor.
    let rcond_value = match eval.reciprocal_condition() {
        Value::Num(r) => r,
        Value::GpuTensor(handle) => {
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather rcond");
            gathered.data[0]
        }
        other => panic!("unexpected rcond value {other:?}"),
    };
    // The identity matrix is perfectly conditioned.
    approx_eq(rcond_value, 1.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn gpu_round_trip_matches_cpu() {
    // Solving with device-resident operands on the test provider must agree
    // with the host solve to within 1e-12.
    test_support::with_test_provider(|provider| {
        let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

        // Host baseline.
        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

        // Same solve with both operands uploaded to the device.
        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");
        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            assert!((gpu - cpu).abs() < 1e-12);
        }
    });
}
#[test]
fn host_inputs_auto_promote_into_provider_solve_path() {
    // Plain host tensors should still be promoted into the provider's solve
    // hook: telemetry must show one dispatch, one host-reupload fallback, and
    // traffic in both directions.
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();
        let lhs = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let _ = linsolve_builtin(Value::Tensor(lhs), Value::Tensor(rhs), Vec::new())
            .expect("host linsolve");

        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);
        assert!(telemetry.upload_bytes > 0);
        assert!(telemetry.download_bytes > 0);
    });
}
// Device-resident inputs on the test provider are expected to route through
// the host solver and re-upload the result: telemetry must record exactly one
// linsolve dispatch, one `linsolve:host_reupload` fallback, and traffic in
// both directions.
#[test]
fn provider_telemetry_records_gpu_host_reupload_path() {
    test_support::with_test_provider(|provider| {
        // Reset so the counters below reflect only this solve.
        provider.reset_telemetry();
        let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");
        let _ = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");
        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert!(telemetry.upload_bytes > 0);
        assert!(telemetry.download_bytes > 0);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);
    });
}
// 1x1 device inputs should bypass the provider's solve hook entirely:
// telemetry must show zero linsolve dispatches and zero host-reupload
// fallbacks, with only the download of the gathered scalar result.
#[test]
fn scalar_gpu_inputs_fall_back_without_provider_solve_dispatch() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();
        let a = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let b = Tensor::new(vec![6.0], vec![1, 1]).unwrap();
        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");
        let result = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("fallback linsolve");
        let gathered = test_support::gather(result).expect("gather fallback");
        // 2 \ 6 == 3 for the scalar case.
        assert_eq!(gathered.data, vec![3.0]);
        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 0);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
        assert!(telemetry.download_bytes > 0);
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);
    });
}
// Square real solve on the wgpu provider: expects an on-device dispatch with
// one `linsolve_tall_qr` kernel launch, no Cholesky launch, and no
// `linsolve:host_reupload` fallback. Skipped unless the provider runs at F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // This device path is only exercised for F32 providers; bail out otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    // Host baseline for comparison.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    // Request a single output (solution only) for this call.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance to accommodate F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4);
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    // Square solves are expected to route through the QR kernel, not Cholesky.
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
// Same square solve as above but WITHOUT pushing an output count: the device
// QR path must still be taken (one `linsolve_tall_qr` launch, no fallback).
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_uses_device_path_without_output_count() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    // Host baseline.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let gathered = test_support::gather(gpu_value).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
// Two-output form on the device path: with an output count of 2, linsolve must
// return [solution, rcond], and the device rcond must match the host solver's
// value while still avoiding the host-reupload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_recovers_rcond_output_on_device() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    // Reference rcond computed by the host solver.
    let (_, cpu_rcond) = host_linsolve_real(&a, &b, ProviderLinsolveOptions::default());
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    // Request both outputs (solution + rcond) for this call.
    let _output_guard = crate::output_count::push_output_count(Some(2));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let outputs = match gpu_value {
        Value::OutputList(outputs) => outputs,
        other => panic!("expected output list, got {other:?}"),
    };
    assert_eq!(outputs.len(), 2);
    let gathered = test_support::gather(outputs[0].clone()).expect("gather");
    let gpu_rcond = match &outputs[1] {
        Value::Num(value) => *value,
        other => panic!("unexpected gpu rcond {other:?}"),
    };
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, vec![2, 1]);
    assert!(
        (gpu_rcond - cpu_rcond).abs() < 1e-4,
        "gpu={gpu_rcond} cpu={cpu_rcond}"
    );
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
// Passing an explicit RCOND option must not force the solve off the device:
// expects the QR kernel path with no host-reupload fallback, matching the CPU
// result computed with the same option.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_with_rcond_option_stays_on_device() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts
        .fields
        .insert("RCOND".to_string(), Value::Num(0.05));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    // Same RCOND option for the GPU call.
    let mut gpu_opts = StructValue::new();
    gpu_opts
        .fields
        .insert("RCOND".to_string(), Value::Num(0.05));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
// Tall (3x2) least-squares solve on the wgpu provider: must dispatch once on
// the device and avoid the `linsolve:host_reupload` fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_tall_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();
    // Host baseline.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu tall linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
// POSDEF solve on the wgpu provider: must take the dedicated Cholesky kernel
// (one `linsolve_posdef_chol` launch, zero QR launches), avoid the
// host-reupload fallback, and return a device rcond matching the host solver.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_posdef_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Symmetric positive-definite 2x2 system.
    let a = Tensor::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reference rcond from the host solver with the same posdef hint.
    let (_, cpu_rcond) = host_linsolve_real(
        &a,
        &b,
        ProviderLinsolveOptions {
            posdef: true,
            ..Default::default()
        },
    );
    // Reset after the CPU runs so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    // Request both outputs (solution + rcond).
    let _output_guard = crate::output_count::push_output_count(Some(2));
    let mut gpu_opts = StructValue::new();
    gpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu posdef linsolve");
    let mut outputs = match gpu_value {
        Value::OutputList(outputs) => outputs,
        other => panic!("expected output list, got {other:?}"),
    };
    // Pull rcond first (index 1), then the solution shifts to index 0.
    let gpu_rcond = match outputs.remove(1) {
        Value::Num(value) => value,
        other => panic!("unexpected rcond value {other:?}"),
    };
    let gpu_solution = outputs.remove(0);
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    assert!(
        (gpu_rcond - cpu_rcond).abs() < 1e-4,
        "gpu={gpu_rcond} cpu={cpu_rcond}"
    );
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
}
// POSDEF combined with TRANSA='T' must still take the Cholesky kernel on the
// device (one `linsolve_posdef_chol` launch, zero QR launches) and avoid the
// host-reupload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_posdef_linsolve_uses_cholesky_path() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Symmetric positive-definite 2x2 system (A == A^T, so TRANSA is benign).
    let a = Tensor::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![8.0, 9.0], vec![2, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu transposed posdef linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
}
// SYM-hinted solve on the wgpu provider: must dispatch once on the device and
// avoid the `linsolve:host_reupload` fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_symmetric_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Symmetric 2x2 system.
    let a = Tensor::new(vec![5.0, 2.0, 2.0, 6.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![9.0, 8.0], vec![2, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu symmetric linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
// TRANSA='T' square solve on the wgpu provider: must dispatch once on the
// device and avoid the `linsolve:host_reupload` fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_square_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu transposed square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
// TRANSA='C' on REAL inputs is equivalent to TRANSA='T'; the wgpu provider
// must still take the device QR path (one `linsolve_tall_qr` launch) with no
// host-reupload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_conjugate_square_linsolve_avoids_host_reupload_fallback_for_real_inputs() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("C")),
    );
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("C")),
    );
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu conjugate square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
// TRANSA='T' + RECT on a wide (2x3) matrix: A^T is tall (3x2), so this is a
// rectangular least-squares solve; it must stay on the device with no
// host-reupload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_rectangular_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exercised for F32 providers.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // A is 2x3; b matches A^T's row count (3).
    let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![2, 3]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    cpu_opts
        .fields
        .insert("RECT".to_string(), Value::Bool(true));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    gpu_opts
        .fields
        .insert("RECT".to_string(), Value::Bool(true));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu transposed rectangular linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose 1e-4 tolerance for F32 device arithmetic.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_triangular_hint_avoids_host_reupload_fallback() {
    // With the LT hint, the lower-triangular solve must dispatch once on the
    // device without taking the `linsolve:host_reupload` fallback. No
    // precision gate here: this test runs for both F32 and F64 providers.
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let a = Tensor::new(
        vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();

    // Both the CPU and GPU calls use an identical LT options struct.
    let lt_opts = || {
        let mut opts = StructValue::new();
        opts.fields.insert("LT".to_string(), Value::Bool(true));
        vec![Value::Struct(opts)]
    };

    // Host baseline.
    let cpu = linsolve_builtin(Value::Tensor(a.clone()), Value::Tensor(b.clone()), lt_opts())
        .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let upload = |tensor: &Tensor, label: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(label)
    };
    let ha = upload(&a, "upload A");
    let hb = upload(&b, "upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        lt_opts(),
    )
    .expect("gpu triangular linsolve");
    let gathered = match gpu_value {
        Value::OutputList(mut outputs) => test_support::gather(outputs.remove(0)),
        other => test_support::gather(other),
    }
    .expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-5);
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
// LT hint combined with TRANSA='T': solving L^T x = b must still dispatch once
// on the device and avoid the `linsolve:host_reupload` fallback. No precision
// gate: runs for both F32 and F64 providers.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_triangular_hint_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Lower-triangular 3x3 matrix (column-major data).
    let a = Tensor::new(
        vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reset after the CPU run so telemetry reflects only the GPU solve.
    provider.reset_telemetry();
    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu transposed triangular linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-5);
    }
    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_round_trip_matches_cpu() {
    // Solving the same dense 2x2 system on the host and through the WGPU
    // provider must produce element-wise identical results within the
    // provider's precision tolerance.
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Tolerance tracks the provider's working precision.
    let tol = match provider.precision() {
        runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
        runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
    };
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    let view_a = HostTensorView {
        data: &a.data,
        shape: &a.shape,
    };
    let view_b = HostTensorView {
        data: &b.data,
        shape: &b.shape,
    };
    let ha = provider.upload(&view_a).expect("upload A");
    let hb = provider.upload(&view_b).expect("upload B");
    // Pin the expected output count to one, matching the sibling wgpu tests;
    // without this the builtin may wrap its result in a Value::OutputList
    // when an ambient output count is in effect.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu linsolve");
    // Unwrap a possible OutputList before gathering, as the other wgpu tests
    // in this module do; a plain value passes through unchanged.
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);
    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < tol);
    }
}
}