// runmat_runtime/builtins/math/linalg/solve/linsolve.rs
1//! MATLAB-compatible `linsolve` builtin with structural hints and GPU-aware fallbacks.
2
3use nalgebra::{linalg::SVD, DMatrix};
4use num_complex::Complex64;
5use runmat_accelerate_api::{
6    AccelProvider, GpuTensorHandle, HostTensorView, ProviderLinsolveOptions, ProviderLinsolveResult,
7};
8use runmat_builtins::{ComplexTensor, Tensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
14};
15use crate::builtins::common::{
16    gpu_helpers,
17    linalg::{diagonal_rcond, singular_value_rcond},
18    tensor,
19};
20use crate::builtins::math::linalg::type_resolvers::left_divide_type;
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
/// Builtin name used to tag every error and diagnostic emitted by this module.
const NAME: &str = "linsolve";
24
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::solve::linsolve")]
/// Static GPU dispatch metadata for `linsolve`, consumed by the accelerate layer.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "linsolve",
    // Linear solves are not a stock elementwise/reduction op; providers key
    // off this custom "solve" tag and the matching custom hook below.
    op_kind: GpuOpKind::Custom("solve"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("linsolve")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    // The solve always materializes a fresh output handle rather than
    // mutating an input in place.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider linsolve hook; WGPU currently supports triangular solves, real F32 TRANSA='T'/'C' variants, a dedicated real F32 POSDEF/Cholesky path, and selected real F32 QR-backed square and rectangular solves, otherwise it gathers to the host solver and re-uploads the result.",
};
40
41fn builtin_error(message: impl Into<String>) -> RuntimeError {
42    build_runtime_error(message).with_builtin(NAME).build()
43}
44
/// Re-tag an error raised by a downstream helper (dispatcher, GPU gather)
/// with this builtin's name while preserving its identifier, task id, call
/// stack, and phase context; the original error is chained as the source.
fn map_control_flow(err: RuntimeError) -> RuntimeError {
    let mut builder = build_runtime_error(err.message()).with_builtin(NAME);
    if let Some(identifier) = err.identifier() {
        builder = builder.with_identifier(identifier.to_string());
    }
    if let Some(task_id) = err.context.task_id.clone() {
        builder = builder.with_task_id(task_id);
    }
    if !err.context.call_stack.is_empty() {
        builder = builder.with_call_stack(err.context.call_stack.clone());
    }
    if let Some(phase) = err.context.phase.clone() {
        builder = builder.with_phase(phase);
    }
    // Attach the original error as the source so the full chain survives.
    builder.with_source(err).build()
}
61
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
)]
/// Kernel-fusion metadata: `linsolve` never participates in fusion, so no
/// elementwise or reduction descriptors are provided.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "linsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Linear solves are terminal operations and do not fuse with surrounding kernels.",
};
74
75#[runtime_builtin(
76    name = "linsolve",
77    category = "math/linalg/solve",
78    summary = "Solve A * X = B with structural hints such as LT, UT, POSDEF, or TRANSA.",
79    keywords = "linsolve,linear system,triangular,gpu",
80    accel = "linsolve",
81    type_resolver(left_divide_type),
82    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
83)]
84async fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
85    let eval = evaluate_args(lhs, rhs, &rest).await?;
86    if let Some(out_count) = crate::output_count::current_output_count() {
87        if out_count == 0 {
88            return Ok(Value::OutputList(Vec::new()));
89        }
90        if out_count == 1 {
91            return Ok(Value::OutputList(vec![eval.solution()]));
92        }
93        if out_count == 2 {
94            return Ok(Value::OutputList(vec![
95                eval.solution(),
96                eval.reciprocal_condition(),
97            ]));
98        }
99        return Err(builtin_error(
100            "linsolve currently supports at most two outputs",
101        ));
102    }
103    Ok(eval.solution())
104}
105
106/// Evaluate `linsolve`, returning both the solution and the estimated reciprocal condition number.
107pub async fn evaluate(
108    lhs: Value,
109    rhs: Value,
110    options: SolveOptions,
111) -> BuiltinResult<LinsolveEval> {
112    if let Some(eval) = try_gpu_linsolve(&lhs, &rhs, &options).await? {
113        return Ok(eval);
114    }
115
116    let lhs_host = crate::dispatcher::gather_if_needed_async(&lhs)
117        .await
118        .map_err(map_control_flow)?;
119    let rhs_host = crate::dispatcher::gather_if_needed_async(&rhs)
120        .await
121        .map_err(map_control_flow)?;
122    let pair = coerce_numeric_pair(lhs_host, rhs_host).await?;
123    match pair {
124        NumericPair::Real(lhs_r, rhs_r) => {
125            let (solution, rcond) = solve_real(lhs_r, rhs_r, &options)?;
126            Ok(LinsolveEval::new(
127                tensor::tensor_into_value(solution),
128                Some(rcond),
129            ))
130        }
131        NumericPair::Complex(lhs_c, rhs_c) => {
132            let (solution, rcond) = solve_complex(lhs_c, rhs_c, &options)?;
133            Ok(LinsolveEval::new(
134                Value::ComplexTensor(solution),
135                Some(rcond),
136            ))
137        }
138    }
139}
140
141/// Host implementation shared with acceleration providers that fall back to CPU execution.
142pub fn linsolve_host_real_for_provider(
143    lhs: &Tensor,
144    rhs: &Tensor,
145    options: &ProviderLinsolveOptions,
146) -> BuiltinResult<(Tensor, f64)> {
147    let opts = SolveOptions::from(options);
148    solve_real(lhs.clone(), rhs.clone(), &opts)
149}
150
151/// Result wrapper that exposes both primary and secondary outputs.
152#[derive(Clone)]
153pub struct LinsolveEval {
154    solution: Value,
155    rcond: Option<f64>,
156}
157
158impl LinsolveEval {
159    fn new(solution: Value, rcond: Option<f64>) -> Self {
160        Self { solution, rcond }
161    }
162
163    /// Primary solution output.
164    pub fn solution(&self) -> Value {
165        self.solution.clone()
166    }
167
168    /// Estimated reciprocal condition number (second output).
169    pub fn reciprocal_condition(&self) -> Value {
170        match self.rcond {
171            Some(r) => Value::Num(r),
172            None => Value::Num(f64::NAN),
173        }
174    }
175}
176
/// Parsed form of the MATLAB `opts` struct accepted by `linsolve`.
#[derive(Clone, Default)]
pub struct SolveOptions {
    // LT: coefficient matrix is lower triangular.
    lower: bool,
    // UT: coefficient matrix is upper triangular.
    upper: bool,
    // RECT: treat the system as rectangular.
    rectangular: bool,
    // TRANSA was 'T' or 'C': solve with the (conjugate-)transposed matrix.
    transposed: bool,
    // TRANSA was 'C' specifically.
    conjugate: bool,
    // SYM: coefficient matrix is symmetric.
    symmetric: bool,
    // POSDEF: coefficient matrix is symmetric positive definite.
    posdef: bool,
    // RCOND: fail the solve when the estimate drops below this threshold.
    rcond: Option<f64>,
}
188
/// Convert internal options into the provider-facing struct. `need_rcond`
/// defaults to `false` here; callers flip it when the second output or an
/// RCOND threshold actually requires the estimate.
impl From<&SolveOptions> for ProviderLinsolveOptions {
    fn from(opts: &SolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            need_rcond: false,
            rcond: opts.rcond,
        }
    }
}
204
/// Convert provider options back into the internal form (drops `need_rcond`,
/// which has no internal counterpart).
impl From<&ProviderLinsolveOptions> for SolveOptions {
    fn from(opts: &ProviderLinsolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            rcond: opts.rcond,
        }
    }
}
219
220fn options_from_rest(rest: &[Value]) -> BuiltinResult<SolveOptions> {
221    match rest.len() {
222        0 => Ok(SolveOptions::default()),
223        1 => parse_options(&rest[0]),
224        _ => Err(builtin_error("linsolve: too many input arguments")),
225    }
226}
227
228/// Public helper for the VM multi-output surface.
229pub async fn evaluate_args(lhs: Value, rhs: Value, rest: &[Value]) -> BuiltinResult<LinsolveEval> {
230    let options = options_from_rest(rest)?;
231    evaluate(lhs, rhs, options).await
232}
233
/// Attempt the solve through the registered acceleration provider.
///
/// Returns `Ok(None)` whenever the GPU path does not apply (no provider,
/// complex inputs, scalar operands, or more than two requested outputs), in
/// which case the caller falls back to the host solver. Uploaded temporaries
/// are always released before returning.
async fn try_gpu_linsolve(
    lhs: &Value,
    rhs: &Value,
    options: &SolveOptions,
) -> BuiltinResult<Option<LinsolveEval>> {
    if matches!(crate::output_count::current_output_count(), Some(n) if n > 2) {
        return Ok(None);
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    // The provider hook only handles real data.
    if contains_complex(lhs) || contains_complex(rhs) {
        return Ok(None);
    }

    let mut lhs_operand = match prepare_gpu_operand(lhs, provider)? {
        Some(op) => op,
        None => return Ok(None),
    };
    let mut rhs_operand = match prepare_gpu_operand(rhs, provider)? {
        Some(op) => op,
        None => {
            // Don't leak an upload performed for the left operand.
            release_operand(provider, &mut lhs_operand);
            return Ok(None);
        }
    };

    // Scalar coefficients/right-hand sides go through the host path.
    if is_scalar_handle(lhs_operand.handle()) || is_scalar_handle(rhs_operand.handle()) {
        release_operand(provider, &mut lhs_operand);
        release_operand(provider, &mut rhs_operand);
        return Ok(None);
    }

    let mut provider_opts: ProviderLinsolveOptions = options.into();
    // Only request the rcond estimate when the caller actually consumes it.
    provider_opts.need_rcond =
        matches!(crate::output_count::current_output_count(), Some(2)) || options.rcond.is_some();
    // Provider failure is non-fatal: `.ok()` turns it into a host fallback.
    let result = provider
        .linsolve(lhs_operand.handle(), rhs_operand.handle(), &provider_opts)
        .await
        .ok();

    release_operand(provider, &mut lhs_operand);
    release_operand(provider, &mut rhs_operand);

    if let Some(ProviderLinsolveResult {
        solution,
        reciprocal_condition,
    }) = result
    {
        let eval = LinsolveEval::new(Value::GpuTensor(solution), Some(reciprocal_condition));
        return Ok(Some(eval));
    }

    Ok(None)
}
291
292fn parse_options(value: &Value) -> BuiltinResult<SolveOptions> {
293    let struct_val = match value {
294        Value::Struct(s) => s,
295        other => {
296            return Err(builtin_error(format!(
297                "linsolve: opts must be a struct, got {other:?}"
298            )))
299        }
300    };
301    let mut opts = SolveOptions::default();
302    for (key, raw_value) in &struct_val.fields {
303        let name = key.to_ascii_uppercase();
304        match name.as_str() {
305            "LT" => opts.lower = parse_bool_field("LT", raw_value)?,
306            "UT" => opts.upper = parse_bool_field("UT", raw_value)?,
307            "RECT" => opts.rectangular = parse_bool_field("RECT", raw_value)?,
308            "SYM" => opts.symmetric = parse_bool_field("SYM", raw_value)?,
309            "POSDEF" => opts.posdef = parse_bool_field("POSDEF", raw_value)?,
310            "TRANSA" => {
311                let transa = parse_transa(raw_value)?;
312                opts.transposed = transa != TransposeMode::None;
313                opts.conjugate = transa == TransposeMode::Conjugate;
314            }
315            "RCOND" => {
316                let threshold = parse_scalar_f64("RCOND", raw_value)?;
317                if threshold < 0.0 {
318                    return Err(builtin_error("linsolve: RCOND must be non-negative"));
319                }
320                opts.rcond = Some(threshold);
321            }
322            other => return Err(builtin_error(format!("linsolve: unknown option '{other}'"))),
323        }
324    }
325    if opts.lower && opts.upper {
326        return Err(builtin_error("linsolve: LT and UT are mutually exclusive."));
327    }
328    Ok(opts)
329}
330
/// Coerce a struct field into a boolean flag, accepting logicals, integers,
/// doubles, scalar tensors, and one-element logical arrays (any nonzero value
/// is `true`), mirroring MATLAB's truthiness for option structs.
fn parse_bool_field(name: &str, value: &Value) -> BuiltinResult<bool> {
    match value {
        Value::Bool(b) => Ok(*b),
        Value::Int(i) => Ok(!i.is_zero()),
        Value::Num(n) => Ok(*n != 0.0),
        Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0] != 0.0),
        Value::LogicalArray(arr) if arr.len() == 1 => Ok(arr.data[0] != 0),
        other => Err(builtin_error(format!(
            "linsolve: option '{name}' must be logical or numeric, got {other:?}"
        ))),
    }
}
343
/// Coerce a struct field into a scalar `f64`, accepting doubles, integers,
/// and one-element tensors.
fn parse_scalar_f64(name: &str, value: &Value) -> BuiltinResult<f64> {
    match value {
        Value::Num(n) => Ok(*n),
        Value::Int(i) => Ok(i.to_f64()),
        Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0]),
        other => Err(builtin_error(format!(
            "linsolve: option '{name}' must be a scalar numeric value, got {other:?}"
        ))),
    }
}
354
/// Parsed TRANSA option: 'N' (no transpose), 'T' (transpose), or
/// 'C' (conjugate transpose).
#[derive(Copy, Clone, PartialEq, Eq)]
enum TransposeMode {
    None,
    Transpose,
    Conjugate,
}
361
362fn parse_transa(value: &Value) -> BuiltinResult<TransposeMode> {
363    let text = tensor::value_to_string(value).ok_or_else(|| {
364        builtin_error("linsolve: TRANSA must be a character vector or string scalar")
365    })?;
366    if text.is_empty() {
367        return Err(builtin_error("linsolve: TRANSA cannot be empty"));
368    }
369    match text.trim().to_ascii_uppercase().as_str() {
370        "N" => Ok(TransposeMode::None),
371        "T" => Ok(TransposeMode::Transpose),
372        "C" => Ok(TransposeMode::Conjugate),
373        other => Err(builtin_error(format!(
374            "linsolve: TRANSA must be 'N', 'T', or 'C', got '{other}'"
375        ))),
376    }
377}
378
/// A single input coerced to one of the two numeric classes the host
/// solvers understand.
enum NumericInput {
    Real(Tensor),
    Complex(ComplexTensor),
}

/// A pair of inputs promoted to a common numeric class (real if both are
/// real, complex otherwise).
enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}
388
389async fn coerce_numeric_pair(lhs: Value, rhs: Value) -> BuiltinResult<NumericPair> {
390    let lhs_num = coerce_numeric(lhs).await?;
391    let rhs_num = coerce_numeric(rhs).await?;
392    match (lhs_num, rhs_num) {
393        (NumericInput::Real(lhs_r), NumericInput::Real(rhs_r)) => {
394            Ok(NumericPair::Real(lhs_r, rhs_r))
395        }
396        (NumericInput::Complex(lhs_c), NumericInput::Complex(rhs_c)) => {
397            Ok(NumericPair::Complex(lhs_c, rhs_c))
398        }
399        (NumericInput::Complex(lhs_c), NumericInput::Real(rhs_r)) => {
400            let rhs_c = promote_real_tensor(&rhs_r)?;
401            Ok(NumericPair::Complex(lhs_c, rhs_c))
402        }
403        (NumericInput::Real(lhs_r), NumericInput::Complex(rhs_c)) => {
404            let lhs_c = promote_real_tensor(&lhs_r)?;
405            Ok(NumericPair::Complex(lhs_c, rhs_c))
406        }
407    }
408}
409
/// Normalize a runtime `Value` into a real or complex tensor for the host
/// solvers. Scalars become 1x1 tensors, logicals become doubles, and
/// GPU-resident tensors are gathered to the host. Inputs with more than two
/// non-singleton dimensions are rejected.
async fn coerce_numeric(value: Value) -> BuiltinResult<NumericInput> {
    match value {
        Value::Tensor(tensor) => {
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical).map_err(builtin_error)?;
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Num(n) => {
            let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Int(i) => {
            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Bool(b) => {
            let tensor =
                Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Complex(tensor))
        }
        Value::ComplexTensor(ct) => {
            ensure_matrix_shape(NAME, &ct.shape)?;
            Ok(NumericInput::Complex(ct))
        }
        Value::GpuTensor(handle) => {
            // Device-resident data reaches here only on the host fallback
            // path; pull it back before solving.
            let tensor = gpu_helpers::gather_tensor_async(&handle)
                .await
                .map_err(map_control_flow)?;
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        other => Err(builtin_error(format!(
            "{NAME}: unsupported input type {:?}; convert to numeric values first",
            other
        ))),
    }
}
455
/// True when the value is a complex scalar or complex tensor (the GPU path
/// only handles real data).
fn contains_complex(value: &Value) -> bool {
    matches!(value, Value::Complex(_, _) | Value::ComplexTensor(_))
}
459
/// True when the device tensor's shape describes a scalar.
fn is_scalar_handle(handle: &GpuTensorHandle) -> bool {
    crate::builtins::common::shape::is_scalar_shape(&handle.shape)
}
463
/// A GPU operand plus a flag recording whether this function uploaded it
/// (and therefore must free it) or merely borrowed a caller-owned handle.
struct PreparedOperand {
    handle: GpuTensorHandle,
    // True when the handle was uploaded here and must be freed via
    // `release_operand`; false for handles the caller already owned.
    owned: bool,
}

impl PreparedOperand {
    /// Wrap a caller-owned handle; `release_operand` will not free it.
    fn borrowed(handle: &GpuTensorHandle) -> Self {
        Self {
            handle: handle.clone(),
            owned: false,
        }
    }

    /// Wrap a freshly uploaded handle that must be freed after use.
    fn owned(handle: GpuTensorHandle) -> Self {
        Self {
            handle,
            owned: true,
        }
    }

    fn handle(&self) -> &GpuTensorHandle {
        &self.handle
    }
}
488
489fn prepare_gpu_operand(
490    value: &Value,
491    provider: &'static dyn AccelProvider,
492) -> BuiltinResult<Option<PreparedOperand>> {
493    match value {
494        Value::GpuTensor(handle) => {
495            if is_scalar_handle(handle) {
496                Ok(None)
497            } else {
498                Ok(Some(PreparedOperand::borrowed(handle)))
499            }
500        }
501        Value::Tensor(tensor) => {
502            if tensor::is_scalar_tensor(tensor) {
503                Ok(None)
504            } else {
505                let uploaded = upload_tensor(provider, tensor)?;
506                Ok(Some(PreparedOperand::owned(uploaded)))
507            }
508        }
509        Value::LogicalArray(logical) => {
510            if logical.data.len() == 1 {
511                Ok(None)
512            } else {
513                let tensor = tensor::logical_to_tensor(logical).map_err(builtin_error)?;
514                let uploaded = upload_tensor(provider, &tensor)?;
515                Ok(Some(PreparedOperand::owned(uploaded)))
516            }
517        }
518        _ => Ok(None),
519    }
520}
521
522fn upload_tensor(
523    provider: &'static dyn AccelProvider,
524    tensor: &Tensor,
525) -> BuiltinResult<GpuTensorHandle> {
526    let view = HostTensorView {
527        data: &tensor.data,
528        shape: &tensor.shape,
529    };
530    provider
531        .upload(&view)
532        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
533}
534
535fn release_operand(provider: &'static dyn AccelProvider, operand: &mut PreparedOperand) {
536    if operand.owned {
537        let _ = provider.free(&operand.handle);
538        operand.owned = false;
539    }
540}
541
/// Solve the real system, honoring the structural hints in `options`.
///
/// TRANSA is applied by materializing the (conjugate-)transposed matrix,
/// which also flips any LT/UT hint (the transpose of a lower-triangular
/// matrix is upper-triangular). Triangular hints route to the O(n^2)
/// substitution kernels; everything else uses the SVD-backed general solver.
fn solve_real(lhs: Tensor, rhs: Tensor, options: &SolveOptions) -> BuiltinResult<(Tensor, f64)> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;

    if options.transposed {
        lhs_effective = transpose_tensor(&lhs_effective);
        if options.conjugate {
            // No-op for real data; kept for structural parity with the
            // complex path.
            conjugate_in_place(&mut lhs_effective);
        }
        if lower || upper {
            std::mem::swap(&mut lower, &mut upper);
        }
    }

    // Reshape a bare vector RHS into a column and validate row counts.
    rhs_effective = normalize_rhs_tensor(rhs_effective, lhs_effective.rows())?;

    if lower {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = forward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    if upper {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = backward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    let (solution, rcond) = solve_general_real(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}
578
/// Complex counterpart of `solve_real`: applies TRANSA (with conjugation for
/// 'C'), swaps LT/UT under transposition, then dispatches to the triangular
/// or SVD-backed general solver.
fn solve_complex(
    lhs: ComplexTensor,
    rhs: ComplexTensor,
    options: &SolveOptions,
) -> BuiltinResult<(ComplexTensor, f64)> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;

    if options.transposed {
        lhs_effective = transpose_complex(&lhs_effective);
        if options.conjugate {
            conjugate_complex_in_place(&mut lhs_effective);
        }
        // The transpose of a triangular matrix has the opposite orientation.
        if lower || upper {
            std::mem::swap(&mut lower, &mut upper);
        }
    }

    rhs_effective = normalize_rhs_complex(rhs_effective, lhs_effective.rows)?;

    if lower {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = forward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    if upper {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = backward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    let (solution, rcond) = solve_general_complex(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}
619
620fn forward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
621    let n = lhs.rows();
622    let nrhs = rhs.data.len() / n;
623    let mut solution = rhs.data.clone();
624    let mut min_diag = f64::INFINITY;
625    let mut max_diag = 0.0_f64;
626
627    for col in 0..nrhs {
628        for i in 0..n {
629            let diag = lhs.data[i + i * n];
630            let diag_abs = diag.abs();
631            min_diag = min_diag.min(diag_abs);
632            max_diag = max_diag.max(diag_abs);
633            if diag_abs == 0.0 {
634                return Err(builtin_error(
635                    "linsolve: matrix is singular to working precision.",
636                ));
637            }
638            let mut accum = 0.0;
639            for j in 0..i {
640                accum += lhs.data[i + j * n] * solution[j + col * n];
641            }
642            let rhs_value = solution[i + col * n] - accum;
643            solution[i + col * n] = rhs_value / diag;
644        }
645    }
646
647    let rcond = diagonal_rcond(min_diag, max_diag);
648    let tensor = Tensor::new(solution, rhs.shape.clone())
649        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
650    Ok((tensor, rcond))
651}
652
653fn backward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
654    let n = lhs.rows();
655    let nrhs = rhs.data.len() / n;
656    let mut solution = rhs.data.clone();
657    let mut min_diag = f64::INFINITY;
658    let mut max_diag = 0.0_f64;
659
660    for col in 0..nrhs {
661        for row_rev in 0..n {
662            let i = n - 1 - row_rev;
663            let diag = lhs.data[i + i * n];
664            let diag_abs = diag.abs();
665            min_diag = min_diag.min(diag_abs);
666            max_diag = max_diag.max(diag_abs);
667            if diag_abs == 0.0 {
668                return Err(builtin_error(
669                    "linsolve: matrix is singular to working precision.",
670                ));
671            }
672            let mut accum = 0.0;
673            for j in (i + 1)..n {
674                accum += lhs.data[i + j * n] * solution[j + col * n];
675            }
676            let rhs_value = solution[i + col * n] - accum;
677            solution[i + col * n] = rhs_value / diag;
678        }
679    }
680
681    let rcond = diagonal_rcond(min_diag, max_diag);
682    let tensor = Tensor::new(solution, rhs.shape.clone())
683        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
684    Ok((tensor, rcond))
685}
686
687fn forward_substitution_complex(
688    lhs: &ComplexTensor,
689    rhs: &ComplexTensor,
690) -> BuiltinResult<(ComplexTensor, f64)> {
691    let n = lhs.rows;
692    let nrhs = rhs.data.len() / n;
693    let lhs_data: Vec<Complex64> = lhs
694        .data
695        .iter()
696        .map(|&(re, im)| Complex64::new(re, im))
697        .collect();
698    let mut solution: Vec<Complex64> = rhs
699        .data
700        .iter()
701        .map(|&(re, im)| Complex64::new(re, im))
702        .collect();
703    let mut min_diag = f64::INFINITY;
704    let mut max_diag = 0.0_f64;
705
706    for col in 0..nrhs {
707        for i in 0..n {
708            let diag = lhs_data[i + i * n];
709            let diag_abs = diag.norm();
710            min_diag = min_diag.min(diag_abs);
711            max_diag = max_diag.max(diag_abs);
712            if diag_abs == 0.0 {
713                return Err(builtin_error(
714                    "linsolve: matrix is singular to working precision.",
715                ));
716            }
717            let mut accum = Complex64::new(0.0, 0.0);
718            for j in 0..i {
719                accum += lhs_data[i + j * n] * solution[j + col * n];
720            }
721            let rhs_value = solution[i + col * n] - accum;
722            solution[i + col * n] = rhs_value / diag;
723        }
724    }
725
726    let rcond = diagonal_rcond(min_diag, max_diag);
727    let tensor = ComplexTensor::new(
728        solution.iter().map(|c| (c.re, c.im)).collect(),
729        rhs.shape.clone(),
730    )
731    .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
732    Ok((tensor, rcond))
733}
734
735fn backward_substitution_complex(
736    lhs: &ComplexTensor,
737    rhs: &ComplexTensor,
738) -> BuiltinResult<(ComplexTensor, f64)> {
739    let n = lhs.rows;
740    let nrhs = rhs.data.len() / n;
741    let lhs_data: Vec<Complex64> = lhs
742        .data
743        .iter()
744        .map(|&(re, im)| Complex64::new(re, im))
745        .collect();
746    let mut solution: Vec<Complex64> = rhs
747        .data
748        .iter()
749        .map(|&(re, im)| Complex64::new(re, im))
750        .collect();
751    let mut min_diag = f64::INFINITY;
752    let mut max_diag = 0.0_f64;
753
754    for col in 0..nrhs {
755        for row_rev in 0..n {
756            let i = n - 1 - row_rev;
757            let diag = lhs_data[i + i * n];
758            let diag_abs = diag.norm();
759            min_diag = min_diag.min(diag_abs);
760            max_diag = max_diag.max(diag_abs);
761            if diag_abs == 0.0 {
762                return Err(builtin_error(
763                    "linsolve: matrix is singular to working precision.",
764                ));
765            }
766            let mut accum = Complex64::new(0.0, 0.0);
767            for j in (i + 1)..n {
768                accum += lhs_data[i + j * n] * solution[j + col * n];
769            }
770            let rhs_value = solution[i + col * n] - accum;
771            solution[i + col * n] = rhs_value / diag;
772        }
773    }
774
775    let rcond = diagonal_rcond(min_diag, max_diag);
776    let tensor = ComplexTensor::new(
777        solution.iter().map(|c| (c.re, c.im)).collect(),
778        rhs.shape.clone(),
779    )
780    .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
781    Ok((tensor, rcond))
782}
783
784fn solve_general_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
785    let a = DMatrix::from_column_slice(lhs.rows(), lhs.cols(), &lhs.data);
786    let b = DMatrix::from_column_slice(rhs.rows(), rhs.cols(), &rhs.data);
787    let svd = SVD::new(a.clone(), true, true);
788    let rcond = singular_value_rcond(svd.singular_values.as_slice());
789    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows(), lhs.cols());
790    let solution = svd
791        .solve(&b, tol)
792        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
793    let tensor = matrix_real_to_tensor(solution)?;
794    Ok((tensor, rcond))
795}
796
797fn solve_general_complex(
798    lhs: &ComplexTensor,
799    rhs: &ComplexTensor,
800) -> BuiltinResult<(ComplexTensor, f64)> {
801    let a_data: Vec<Complex64> = lhs
802        .data
803        .iter()
804        .map(|&(re, im)| Complex64::new(re, im))
805        .collect();
806    let b_data: Vec<Complex64> = rhs
807        .data
808        .iter()
809        .map(|&(re, im)| Complex64::new(re, im))
810        .collect();
811    let a = DMatrix::from_column_slice(lhs.rows, lhs.cols, &a_data);
812    let b = DMatrix::from_column_slice(rhs.rows, rhs.cols, &b_data);
813    let svd = SVD::new(a.clone(), true, true);
814    let rcond = singular_value_rcond(svd.singular_values.as_slice());
815    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows, lhs.cols);
816    let solution = svd
817        .solve(&b, tol)
818        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
819    let tensor = matrix_complex_to_tensor(solution)?;
820    Ok((tensor, rcond))
821}
822
823fn normalize_rhs_tensor(rhs: Tensor, expected_rows: usize) -> BuiltinResult<Tensor> {
824    if rhs.rows() == expected_rows {
825        return Ok(rhs);
826    }
827    if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
828        return Tensor::new(rhs.data, vec![expected_rows, 1])
829            .map_err(|e| builtin_error(format!("{NAME}: {e}")));
830    }
831    if rhs.data.is_empty() && expected_rows == 0 {
832        return Ok(rhs);
833    }
834    Err(builtin_error("Matrix dimensions must agree."))
835}
836
837fn normalize_rhs_complex(rhs: ComplexTensor, expected_rows: usize) -> BuiltinResult<ComplexTensor> {
838    if rhs.rows == expected_rows {
839        return Ok(rhs);
840    }
841    if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
842        return ComplexTensor::new(rhs.data, vec![expected_rows, 1])
843            .map_err(|e| builtin_error(format!("{NAME}: {e}")));
844    }
845    if rhs.data.is_empty() && expected_rows == 0 {
846        return Ok(rhs);
847    }
848    Err(builtin_error("Matrix dimensions must agree."))
849}
850
851fn enforce_rcond(options: &SolveOptions, rcond: f64) -> BuiltinResult<()> {
852    if let Some(threshold) = options.rcond {
853        if rcond < threshold {
854            return Err(builtin_error(
855                "linsolve: matrix is singular to working precision.",
856            ));
857        }
858    }
859    Ok(())
860}
861
/// Rank-truncation tolerance for the SVD solve: `eps * max(rows, cols) *
/// max(largest singular value, 1)`, the conventional machine-precision scale.
fn compute_svd_tolerance(singular_values: &[f64], rows: usize, cols: usize) -> f64 {
    let largest = singular_values
        .iter()
        .map(|value| value.abs())
        .fold(0.0_f64, f64::max);
    let max_dim = rows.max(cols) as f64;
    f64::EPSILON * max_dim * largest.max(1.0)
}
870
871fn matrix_real_to_tensor(matrix: DMatrix<f64>) -> BuiltinResult<Tensor> {
872    let rows = matrix.nrows();
873    let cols = matrix.ncols();
874    Tensor::new(matrix.as_slice().to_vec(), vec![rows, cols])
875        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
876}
877
878fn matrix_complex_to_tensor(matrix: DMatrix<Complex64>) -> BuiltinResult<ComplexTensor> {
879    let rows = matrix.nrows();
880    let cols = matrix.ncols();
881    let data: Vec<(f64, f64)> = matrix.as_slice().iter().map(|c| (c.re, c.im)).collect();
882    ComplexTensor::new(data, vec![rows, cols]).map_err(|e| builtin_error(format!("{NAME}: {e}")))
883}
884
885fn promote_real_tensor(tensor: &Tensor) -> BuiltinResult<ComplexTensor> {
886    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
887    ComplexTensor::new(data, tensor.shape.clone())
888        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
889}
890
891fn ensure_matrix_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
892    if is_effectively_matrix(shape) {
893        Ok(())
894    } else {
895        Err(builtin_error(format!(
896            "{name}: inputs must be 2-D matrices or vectors"
897        )))
898    }
899}
900
/// True when the shape has at most two dimensions, or when every dimension
/// past the second is a singleton (so the value still behaves like a matrix).
fn is_effectively_matrix(shape: &[usize]) -> bool {
    shape.len() <= 2 || shape[2..].iter().all(|&dim| dim == 1)
}
907
908fn ensure_square(rows: usize, cols: usize) -> BuiltinResult<()> {
909    if rows == cols {
910        Ok(())
911    } else {
912        Err(builtin_error(
913            "linsolve: triangular solves require a square coefficient matrix.",
914        ))
915    }
916}
917
918fn transpose_tensor(tensor: &Tensor) -> Tensor {
919    let rows = tensor.rows();
920    let cols = tensor.cols();
921    let mut data = vec![0.0; tensor.data.len()];
922    for r in 0..rows {
923        for c in 0..cols {
924            data[c + r * cols] = tensor.data[r + c * rows];
925        }
926    }
927    Tensor::new(data, vec![cols, rows]).expect("transpose_tensor valid")
928}
929
930fn transpose_complex(tensor: &ComplexTensor) -> ComplexTensor {
931    let rows = tensor.rows;
932    let cols = tensor.cols;
933    let mut data = vec![(0.0, 0.0); tensor.data.len()];
934    for r in 0..rows {
935        for c in 0..cols {
936            data[c + r * cols] = tensor.data[r + c * rows];
937        }
938    }
939    ComplexTensor::new(data, vec![cols, rows]).expect("transpose_complex valid")
940}
941
/// No-op conjugation: conj(x) == x for real values, so the tensor is
/// deliberately left untouched. Kept so the real and complex code paths can
/// share the same call shape.
fn conjugate_in_place(_tensor: &mut Tensor) {
    // Real-valued matrices are unaffected by conjugation.
}
945
946fn conjugate_complex_in_place(tensor: &mut ComplexTensor) {
947    for value in &mut tensor.data {
948        value.1 = -value.1;
949    }
950}
951
952#[cfg(test)]
953pub(crate) mod tests {
954    use super::*;
955    use futures::executor::block_on;
956    use runmat_accelerate_api::HostTensorView;
957    use runmat_builtins::{CharArray, ResolveContext, StructValue, Type};
    /// Identity pass-through kept so call sites read as "unwrap the error";
    /// the runtime error type needs no further unwrapping here.
    fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
        err
    }
961
962    fn approx_eq(actual: f64, expected: f64) {
963        assert!((actual - expected).abs() < 1e-7);
964    }
965
    /// Test shim: drive the async `evaluate_args` to completion on the
    /// current thread so tests can stay synchronous.
    fn evaluate_args(a: Value, b: Value, rest: &[Value]) -> Result<LinsolveEval, RuntimeError> {
        block_on(super::evaluate_args(a, b, rest))
    }
969
970    #[test]
971    fn linsolve_type_uses_rhs_columns() {
972        let out = left_divide_type(
973            &[
974                Type::Tensor {
975                    shape: Some(vec![Some(2), Some(2)]),
976                },
977                Type::Tensor {
978                    shape: Some(vec![Some(2), Some(3)]),
979                },
980            ],
981            &ResolveContext::new(Vec::new()),
982        );
983        assert_eq!(
984            out,
985            Type::Tensor {
986                shape: Some(vec![Some(2), Some(3)])
987            }
988        );
989    }
990
991    use crate::builtins::common::test_support;
992    use runmat_accelerate_api::ProviderTelemetry;
993
    /// Test shim: synchronously run the async `linsolve` builtin.
    fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(super::linsolve_builtin(lhs, rhs, rest))
    }
997
    /// Test shim: synchronously run the async `evaluate` with explicit
    /// solver options.
    fn evaluate(lhs: Value, rhs: Value, options: SolveOptions) -> BuiltinResult<LinsolveEval> {
        block_on(super::evaluate(lhs, rhs, options))
    }
1001
1002    fn fallback_count(telemetry: &ProviderTelemetry, reason: &str) -> u64 {
1003        telemetry
1004            .solve_fallbacks
1005            .iter()
1006            .find(|entry| entry.reason == reason)
1007            .map(|entry| entry.count)
1008            .unwrap_or(0)
1009    }
1010
1011    #[cfg(feature = "wgpu")]
1012    fn kernel_launch_count(telemetry: &ProviderTelemetry, kernel: &str) -> usize {
1013        telemetry
1014            .kernel_launches
1015            .iter()
1016            .filter(|entry| entry.kernel == kernel)
1017            .count()
1018    }
1019
    /// Detach any acceleration provider so subsequent solves take the host
    /// path: the thread-local override is cleared first, then the global
    /// registration.
    fn clear_accel_provider_state() {
        runmat_accelerate_api::set_thread_provider(None);
        runmat_accelerate_api::clear_provider();
    }
1024
    /// Run the host (CPU) real-valued linsolve directly, returning the
    /// solution tensor and its reciprocal condition estimate. Panics on
    /// solver failure — callers are tests with known-good inputs.
    fn host_linsolve_real(
        a: &Tensor,
        b: &Tensor,
        options: ProviderLinsolveOptions,
    ) -> (Tensor, f64) {
        super::linsolve_host_real_for_provider(a, b, &options).expect("host linsolve")
    }
1032
1033    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1034    #[test]
1035    fn linsolve_basic_square() {
1036        let _accel_guard = test_support::accel_test_lock();
1037        clear_accel_provider_state();
1038        let a = Tensor::new(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
1039        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1040        let result =
1041            linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new()).expect("linsolve");
1042        let t = test_support::gather(result).expect("gather");
1043        assert_eq!(t.shape, vec![2, 1]);
1044        approx_eq(t.data[0], 1.0);
1045        approx_eq(t.data[1], 2.0);
1046    }
1047
1048    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1049    #[test]
1050    fn linsolve_lower_triangular_hint() {
1051        let a = Tensor::new(
1052            vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
1053            vec![3, 3],
1054        )
1055        .unwrap();
1056        let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
1057        let mut opts = StructValue::new();
1058        opts.fields.insert("LT".to_string(), Value::Bool(true));
1059        let result = linsolve_builtin(
1060            Value::Tensor(a),
1061            Value::Tensor(b),
1062            vec![Value::Struct(opts)],
1063        )
1064        .expect("linsolve");
1065        let tensor = test_support::gather(result).expect("gather");
1066        assert_eq!(tensor.shape, vec![3, 1]);
1067        approx_eq(tensor.data[0], 3.0);
1068        approx_eq(tensor.data[1], 2.0);
1069        approx_eq(tensor.data[2], 1.0);
1070    }
1071
    /// `LT` + `TRANSA='T'` solves the system with A transposed; the result
    /// must agree with explicitly transposing A on the host and solving the
    /// general system.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn linsolve_transposed_triangular_hint() {
        let _accel_guard = test_support::accel_test_lock();
        clear_accel_provider_state();
        // Lower-triangular A in column-major order.
        let a = Tensor::new(
            vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
            vec![3, 3],
        )
        .unwrap();
        let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
        let mut opts = StructValue::new();
        opts.fields.insert("LT".to_string(), Value::Bool(true));
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("T")),
        );

        let result = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            vec![Value::Struct(opts)],
        )
        .expect("linsolve");
        let tensor = test_support::gather(result).expect("gather");
        assert_eq!(tensor.shape, vec![3, 1]);

        // Reference: transpose explicitly, then use the general host solver.
        let a_transposed = transpose_tensor(&a);
        let (expected_tensor, _) =
            host_linsolve_real(&a_transposed, &b, ProviderLinsolveOptions::default());

        for (actual, expected) in tensor.data.iter().zip(expected_tensor.data.iter()) {
            approx_eq(*actual, *expected);
        }
    }
1107
    /// Complex solves are validated via the residual ||A*x - b||, which must
    /// vanish to near machine precision for this well-conditioned system.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn linsolve_complex_inputs_match_residual() {
        let a = ComplexTensor::new(
            vec![(2.0, 1.0), (-1.0, 0.0), (1.0, -2.0), (3.0, -2.0)],
            vec![2, 2],
        )
        .unwrap();
        let b = ComplexTensor::new(vec![(1.0, 0.0), (4.0, 1.0)], vec![2, 1]).unwrap();
        let result = linsolve_builtin(
            Value::ComplexTensor(a.clone()),
            Value::ComplexTensor(b.clone()),
            Vec::new(),
        )
        .expect("linsolve");
        let Value::ComplexTensor(out) = result else {
            panic!("expected complex tensor result");
        };

        // Rebuild nalgebra matrices from the (re, im) pair storage to
        // compute the residual A*x - b.
        let mat_a: Vec<Complex64> = a
            .data
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        let mat_b: Vec<Complex64> = b
            .data
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        let mat_x: Vec<Complex64> = out
            .data
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        let a_mat = DMatrix::from_column_slice(a.rows, a.cols, &mat_a);
        let b_mat = DMatrix::from_column_slice(b.rows, b.cols, &mat_b);
        let x_mat = DMatrix::from_column_slice(out.rows, out.cols, &mat_x);
        let residual = a_mat * x_mat - b_mat;
        assert!(residual.norm() < 1e-10, "residual={}", residual.norm());
    }
1148
    /// `TRANSA='C'` must behave exactly like explicitly forming the
    /// conjugate transpose of A and solving that system.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn linsolve_complex_conjugate_transpose_matches_explicit_reference() {
        let a = ComplexTensor::new(
            vec![(2.0, 1.0), (0.0, -1.0), (1.0, 2.0), (3.0, 0.5)],
            vec![2, 2],
        )
        .unwrap();
        let b = ComplexTensor::new(vec![(1.0, -1.0), (2.0, 0.5)], vec![2, 1]).unwrap();

        let mut opts = StructValue::new();
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("C")),
        );
        let result = linsolve_builtin(
            Value::ComplexTensor(a.clone()),
            Value::ComplexTensor(b.clone()),
            vec![Value::Struct(opts)],
        )
        .expect("linsolve");
        let Value::ComplexTensor(out) = result else {
            panic!("expected complex tensor result");
        };

        // Reference: build A^H by hand and run the default evaluation path.
        let mut a_conj_t = transpose_complex(&a);
        conjugate_complex_in_place(&mut a_conj_t);
        let reference = evaluate(
            Value::ComplexTensor(a_conj_t),
            Value::ComplexTensor(b.clone()),
            SolveOptions::default(),
        )
        .expect("reference");
        let Value::ComplexTensor(expected) = reference.solution() else {
            panic!("expected complex tensor reference");
        };

        assert_eq!(out.shape, expected.shape);
        for ((out_re, out_im), (exp_re, exp_im)) in out.data.iter().zip(expected.data.iter()) {
            assert!(
                (out_re - exp_re).abs() < 1e-10,
                "out_re={out_re} exp_re={exp_re}"
            );
            assert!(
                (out_im - exp_im).abs() < 1e-10,
                "out_im={out_im} exp_im={exp_im}"
            );
        }
    }
1198
1199    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1200    #[test]
1201    fn linsolve_rcond_enforced() {
1202        let _accel_guard = test_support::accel_test_lock();
1203        clear_accel_provider_state();
1204        let a = Tensor::new(vec![1.0, 1.0, 1.0, 1.0 + 1e-12], vec![2, 2]).unwrap();
1205        let b = Tensor::new(vec![2.0, 2.0 + 1e-12], vec![2, 1]).unwrap();
1206        let mut opts = StructValue::new();
1207        opts.fields.insert("RCOND".to_string(), Value::Num(1e-3));
1208        let err = unwrap_error(
1209            linsolve_builtin(
1210                Value::Tensor(a),
1211                Value::Tensor(b),
1212                vec![Value::Struct(opts)],
1213            )
1214            .expect_err("singular matrix must fail"),
1215        );
1216        assert!(
1217            err.message().contains("singular to working precision"),
1218            "unexpected error message: {err}"
1219        );
1220    }
1221
    /// `evaluate_args` must expose both the solution and the reciprocal
    /// condition estimate; for the identity matrix the rcond is exactly 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn linsolve_recovers_rcond_output() {
        let _accel_guard = test_support::accel_test_lock();
        clear_accel_provider_state();
        let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let eval = evaluate_args(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
            .expect("evaluate");
        // The solution may come back resident on a device; gather if needed.
        let solution_tensor = match eval.solution() {
            Value::Tensor(sol) => sol.clone(),
            Value::GpuTensor(handle) => {
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather solution")
            }
            other => panic!("unexpected solution value {other:?}"),
        };
        assert_eq!(solution_tensor.shape, vec![2, 1]);
        approx_eq(solution_tensor.data[0], 1.0);
        approx_eq(solution_tensor.data[1], 2.0);

        // The rcond may likewise be a host scalar or a one-element device
        // tensor.
        let rcond_value = match eval.reciprocal_condition() {
            Value::Num(r) => r,
            Value::GpuTensor(handle) => {
                let gathered =
                    test_support::gather(Value::GpuTensor(handle.clone())).expect("gather rcond");
                gathered.data[0]
            }
            other => panic!("unexpected rcond value {other:?}"),
        };
        approx_eq(rcond_value, 1.0);
    }
1253
    /// Uploading operands to the test provider and solving must match the
    /// CPU result element-for-element after gathering back to the host.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn gpu_round_trip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

            // Reference answer via the host path.
            let cpu = linsolve_builtin(
                Value::Tensor(a.clone()),
                Value::Tensor(b.clone()),
                Vec::new(),
            )
            .expect("cpu linsolve");
            let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let ha = provider.upload(&view_a).expect("upload A");
            let hb = provider.upload(&view_b).expect("upload B");

            let gpu_value = linsolve_builtin(
                Value::GpuTensor(ha.clone()),
                Value::GpuTensor(hb.clone()),
                Vec::new(),
            )
            .expect("gpu linsolve");
            let gathered = test_support::gather(gpu_value).expect("gather");
            let _ = provider.free(&ha);
            let _ = provider.free(&hb);

            assert_eq!(gathered.shape, cpu_tensor.shape);
            for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
                assert!((gpu - cpu).abs() < 1e-12);
            }
        });
    }
1296
    /// Plain host tensors should be promoted into the provider's linsolve
    /// hook, recording one dispatch plus a `host_reupload` fallback and
    /// traffic in both directions.
    #[test]
    fn host_inputs_auto_promote_into_provider_solve_path() {
        test_support::with_test_provider(|provider| {
            provider.reset_telemetry();
            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
            let _ = linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new())
                .expect("host linsolve");
            let telemetry = provider.telemetry_snapshot();
            assert_eq!(telemetry.linsolve.count, 1);
            assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);
            // Upload and download traffic proves the round trip happened.
            assert!(telemetry.upload_bytes > 0);
            assert!(telemetry.download_bytes > 0);
        });
    }
1312
    /// GPU-resident operands solved through the test provider must record a
    /// single linsolve dispatch, the `host_reupload` fallback, and traffic in
    /// both directions.
    #[test]
    fn provider_telemetry_records_gpu_host_reupload_path() {
        test_support::with_test_provider(|provider| {
            provider.reset_telemetry();
            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
            let ha = provider
                .upload(&HostTensorView {
                    data: &a.data,
                    shape: &a.shape,
                })
                .expect("upload A");
            let hb = provider
                .upload(&HostTensorView {
                    data: &b.data,
                    shape: &b.shape,
                })
                .expect("upload B");

            let _ = linsolve_builtin(
                Value::GpuTensor(ha.clone()),
                Value::GpuTensor(hb.clone()),
                Vec::new(),
            )
            .expect("gpu linsolve");

            let telemetry = provider.telemetry_snapshot();
            assert_eq!(telemetry.linsolve.count, 1);
            assert!(telemetry.upload_bytes > 0);
            assert!(telemetry.download_bytes > 0);
            assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);

            let _ = provider.free(&ha);
            let _ = provider.free(&hb);
        });
    }
1349
    /// Scalar (1x1) GPU inputs are solved without dispatching the provider's
    /// linsolve hook: no dispatch and no re-upload fallback are recorded,
    /// only the download needed to resolve the operands.
    #[test]
    fn scalar_gpu_inputs_fall_back_without_provider_solve_dispatch() {
        test_support::with_test_provider(|provider| {
            provider.reset_telemetry();
            let a = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
            let b = Tensor::new(vec![6.0], vec![1, 1]).unwrap();
            let ha = provider
                .upload(&HostTensorView {
                    data: &a.data,
                    shape: &a.shape,
                })
                .expect("upload A");
            let hb = provider
                .upload(&HostTensorView {
                    data: &b.data,
                    shape: &b.shape,
                })
                .expect("upload B");

            let result = linsolve_builtin(
                Value::GpuTensor(ha.clone()),
                Value::GpuTensor(hb.clone()),
                Vec::new(),
            )
            .expect("fallback linsolve");
            let gathered = test_support::gather(result).expect("gather fallback");
            // 6 / 2 = 3.
            assert_eq!(gathered.data, vec![3.0]);

            let telemetry = provider.telemetry_snapshot();
            assert_eq!(telemetry.linsolve.count, 0);
            assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
            assert!(telemetry.download_bytes > 0);

            let _ = provider.free(&ha);
            let _ = provider.free(&hb);
        });
    }
1387
    /// Square real F32 solves on the WGPU provider must run on device: one
    /// linsolve dispatch, no host re-upload fallback, and exactly one solver
    /// kernel launch.
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn wgpu_square_linsolve_avoids_host_reupload_fallback() {
        let _accel_guard = test_support::accel_test_lock();
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        // The device path under test is F32-only; skip on other precisions.
        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
            return;
        }
        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

        // Compute the CPU reference first, then reset telemetry so only the
        // GPU solve is counted below.
        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
        provider.reset_telemetry();

        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let _output_guard = crate::output_count::push_output_count(Some(1));
        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu square linsolve");
        let gpu_solution = match gpu_value {
            Value::OutputList(mut outputs) => outputs.remove(0),
            other => other,
        };
        let gathered = test_support::gather(gpu_solution).expect("gather");
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        // F32 device math vs F64 host math: loose tolerance.
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            assert!((gpu - cpu).abs() < 1e-4);
        }

        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
        // This square solve is expected on the QR-backed kernel, not the
        // POSDEF/Cholesky one.
        assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 0);
        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
    }
1451
    /// Same device-path expectations as the test above, but without pushing
    /// an output count — the builtin must still stay on device.
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn wgpu_square_linsolve_uses_device_path_without_output_count() {
        let _accel_guard = test_support::accel_test_lock();
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        // The device path under test is F32-only; skip on other precisions.
        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
            return;
        }
        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

        // CPU reference first; telemetry reset so only the GPU solve counts.
        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
        provider.reset_telemetry();

        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu square linsolve");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        // F32 device math vs F64 host math: loose tolerance.
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
        }

        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
    }
1509
    /// With two outputs requested, the WGPU path must return both the
    /// solution and an rcond estimate close to the host's, while staying on
    /// device (no re-upload fallback).
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn wgpu_square_linsolve_recovers_rcond_output_on_device() {
        let _accel_guard = test_support::accel_test_lock();
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        // The device path under test is F32-only; skip on other precisions.
        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
            return;
        }
        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

        // Host rcond reference; telemetry reset before the GPU solve.
        let (_, cpu_rcond) = host_linsolve_real(&a, &b, ProviderLinsolveOptions::default());
        provider.reset_telemetry();

        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        // Request two outputs so the rcond is materialized.
        let _output_guard = crate::output_count::push_output_count(Some(2));
        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu square linsolve");
        let outputs = match gpu_value {
            Value::OutputList(outputs) => outputs,
            other => panic!("expected output list, got {other:?}"),
        };
        assert_eq!(outputs.len(), 2);
        let gathered = test_support::gather(outputs[0].clone()).expect("gather");
        let gpu_rcond = match &outputs[1] {
            Value::Num(value) => *value,
            other => panic!("unexpected gpu rcond {other:?}"),
        };
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, vec![2, 1]);
        assert!(
            (gpu_rcond - cpu_rcond).abs() < 1e-4,
            "gpu={gpu_rcond} cpu={cpu_rcond}"
        );

        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
    }
1572
    /// Supplying an `RCOND` option (that the system passes) must not force
    /// the WGPU path off device: results match the CPU run and no re-upload
    /// fallback is recorded.
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn wgpu_square_linsolve_with_rcond_option_stays_on_device() {
        let _accel_guard = test_support::accel_test_lock();
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        // The device path under test is F32-only; skip on other precisions.
        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
            return;
        }

        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
        // CPU reference with the same RCOND option.
        let mut cpu_opts = StructValue::new();
        cpu_opts
            .fields
            .insert("RCOND".to_string(), Value::Num(0.05));
        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            vec![Value::Struct(cpu_opts)],
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
        provider.reset_telemetry();

        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let _output_guard = crate::output_count::push_output_count(Some(1));
        let mut gpu_opts = StructValue::new();
        gpu_opts
            .fields
            .insert("RCOND".to_string(), Value::Num(0.05));
        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            vec![Value::Struct(gpu_opts)],
        )
        .expect("gpu square linsolve");
        let gpu_solution = match gpu_value {
            Value::OutputList(mut outputs) => outputs.remove(0),
            other => other,
        };
        let gathered = test_support::gather(gpu_solution).expect("gather");
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        // F32 device math vs F64 host math: loose tolerance.
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
        }

        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
    }
1643
    /// Rectangular (tall, least-squares) real F32 solves must also stay on
    /// the WGPU device without the host re-upload fallback.
    #[cfg(feature = "wgpu")]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn wgpu_tall_linsolve_avoids_host_reupload_fallback() {
        let _accel_guard = test_support::accel_test_lock();
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        // The device path under test is F32-only; skip on other precisions.
        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
            return;
        }
        // 3x2 overdetermined system.
        let a = Tensor::new(vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();

        // CPU reference first; telemetry reset so only the GPU solve counts.
        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
        provider.reset_telemetry();

        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let _output_guard = crate::output_count::push_output_count(Some(1));
        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu tall linsolve");
        let gpu_solution = match gpu_value {
            Value::OutputList(mut outputs) => outputs.remove(0),
            other => other,
        };
        let gathered = test_support::gather(gpu_solution).expect("gather");
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        // F32 device math vs F64 host math: loose tolerance.
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
        }

        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    }
1705
1706    #[cfg(feature = "wgpu")]
1707    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1708    #[test]
1709    fn wgpu_posdef_linsolve_avoids_host_reupload_fallback() {
1710        let _accel_guard = test_support::accel_test_lock();
1711        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1712            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1713        );
1714        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1715        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1716            return;
1717        }
1718        let a = Tensor::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
1719        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
1720
1721        let mut cpu_opts = StructValue::new();
1722        cpu_opts
1723            .fields
1724            .insert("POSDEF".to_string(), Value::Bool(true));
1725        let cpu = linsolve_builtin(
1726            Value::Tensor(a.clone()),
1727            Value::Tensor(b.clone()),
1728            vec![Value::Struct(cpu_opts)],
1729        )
1730        .expect("cpu linsolve");
1731        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1732        let (_, cpu_rcond) = host_linsolve_real(
1733            &a,
1734            &b,
1735            ProviderLinsolveOptions {
1736                posdef: true,
1737                ..Default::default()
1738            },
1739        );
1740        provider.reset_telemetry();
1741
1742        let ha = provider
1743            .upload(&HostTensorView {
1744                data: &a.data,
1745                shape: &a.shape,
1746            })
1747            .expect("upload A");
1748        let hb = provider
1749            .upload(&HostTensorView {
1750                data: &b.data,
1751                shape: &b.shape,
1752            })
1753            .expect("upload B");
1754
1755        let _output_guard = crate::output_count::push_output_count(Some(2));
1756        let mut gpu_opts = StructValue::new();
1757        gpu_opts
1758            .fields
1759            .insert("POSDEF".to_string(), Value::Bool(true));
1760        let gpu_value = linsolve_builtin(
1761            Value::GpuTensor(ha.clone()),
1762            Value::GpuTensor(hb.clone()),
1763            vec![Value::Struct(gpu_opts)],
1764        )
1765        .expect("gpu posdef linsolve");
1766        let mut outputs = match gpu_value {
1767            Value::OutputList(outputs) => outputs,
1768            other => panic!("expected output list, got {other:?}"),
1769        };
1770        let gpu_rcond = match outputs.remove(1) {
1771            Value::Num(value) => value,
1772            other => panic!("unexpected rcond value {other:?}"),
1773        };
1774        let gpu_solution = outputs.remove(0);
1775        let gathered = test_support::gather(gpu_solution).expect("gather");
1776        let _ = provider.free(&ha);
1777        let _ = provider.free(&hb);
1778
1779        assert_eq!(gathered.shape, cpu_tensor.shape);
1780        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1781            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1782        }
1783        assert!(
1784            (gpu_rcond - cpu_rcond).abs() < 1e-4,
1785            "gpu={gpu_rcond} cpu={cpu_rcond}"
1786        );
1787
1788        let telemetry = provider.telemetry_snapshot();
1789        assert_eq!(telemetry.linsolve.count, 1);
1790        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1791        assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
1792        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
1793    }
1794
1795    #[cfg(feature = "wgpu")]
1796    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1797    #[test]
1798    fn wgpu_transposed_posdef_linsolve_uses_cholesky_path() {
1799        let _accel_guard = test_support::accel_test_lock();
1800        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1801            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1802        );
1803        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1804        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1805            return;
1806        }
1807        let a = Tensor::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
1808        let b = Tensor::new(vec![8.0, 9.0], vec![2, 1]).unwrap();
1809
1810        let mut cpu_opts = StructValue::new();
1811        cpu_opts
1812            .fields
1813            .insert("POSDEF".to_string(), Value::Bool(true));
1814        cpu_opts.fields.insert(
1815            "TRANSA".to_string(),
1816            Value::CharArray(CharArray::new_row("T")),
1817        );
1818        let cpu = linsolve_builtin(
1819            Value::Tensor(a.clone()),
1820            Value::Tensor(b.clone()),
1821            vec![Value::Struct(cpu_opts)],
1822        )
1823        .expect("cpu linsolve");
1824        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1825        provider.reset_telemetry();
1826
1827        let ha = provider
1828            .upload(&HostTensorView {
1829                data: &a.data,
1830                shape: &a.shape,
1831            })
1832            .expect("upload A");
1833        let hb = provider
1834            .upload(&HostTensorView {
1835                data: &b.data,
1836                shape: &b.shape,
1837            })
1838            .expect("upload B");
1839
1840        let _output_guard = crate::output_count::push_output_count(Some(1));
1841        let mut gpu_opts = StructValue::new();
1842        gpu_opts
1843            .fields
1844            .insert("POSDEF".to_string(), Value::Bool(true));
1845        gpu_opts.fields.insert(
1846            "TRANSA".to_string(),
1847            Value::CharArray(CharArray::new_row("T")),
1848        );
1849        let gpu_value = linsolve_builtin(
1850            Value::GpuTensor(ha.clone()),
1851            Value::GpuTensor(hb.clone()),
1852            vec![Value::Struct(gpu_opts)],
1853        )
1854        .expect("gpu transposed posdef linsolve");
1855        let gpu_solution = match gpu_value {
1856            Value::OutputList(mut outputs) => outputs.remove(0),
1857            other => other,
1858        };
1859        let gathered = test_support::gather(gpu_solution).expect("gather");
1860        let _ = provider.free(&ha);
1861        let _ = provider.free(&hb);
1862
1863        assert_eq!(gathered.shape, cpu_tensor.shape);
1864        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1865            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1866        }
1867
1868        let telemetry = provider.telemetry_snapshot();
1869        assert_eq!(telemetry.linsolve.count, 1);
1870        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1871        assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
1872        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
1873    }
1874
1875    #[cfg(feature = "wgpu")]
1876    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1877    #[test]
1878    fn wgpu_symmetric_linsolve_avoids_host_reupload_fallback() {
1879        let _accel_guard = test_support::accel_test_lock();
1880        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1881            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1882        );
1883        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1884        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1885            return;
1886        }
1887        let a = Tensor::new(vec![5.0, 2.0, 2.0, 6.0], vec![2, 2]).unwrap();
1888        let b = Tensor::new(vec![9.0, 8.0], vec![2, 1]).unwrap();
1889
1890        let mut cpu_opts = StructValue::new();
1891        cpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
1892        let cpu = linsolve_builtin(
1893            Value::Tensor(a.clone()),
1894            Value::Tensor(b.clone()),
1895            vec![Value::Struct(cpu_opts)],
1896        )
1897        .expect("cpu linsolve");
1898        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1899        provider.reset_telemetry();
1900
1901        let ha = provider
1902            .upload(&HostTensorView {
1903                data: &a.data,
1904                shape: &a.shape,
1905            })
1906            .expect("upload A");
1907        let hb = provider
1908            .upload(&HostTensorView {
1909                data: &b.data,
1910                shape: &b.shape,
1911            })
1912            .expect("upload B");
1913
1914        let _output_guard = crate::output_count::push_output_count(Some(1));
1915        let mut gpu_opts = StructValue::new();
1916        gpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
1917        let gpu_value = linsolve_builtin(
1918            Value::GpuTensor(ha.clone()),
1919            Value::GpuTensor(hb.clone()),
1920            vec![Value::Struct(gpu_opts)],
1921        )
1922        .expect("gpu symmetric linsolve");
1923        let gpu_solution = match gpu_value {
1924            Value::OutputList(mut outputs) => outputs.remove(0),
1925            other => other,
1926        };
1927        let gathered = test_support::gather(gpu_solution).expect("gather");
1928        let _ = provider.free(&ha);
1929        let _ = provider.free(&hb);
1930
1931        assert_eq!(gathered.shape, cpu_tensor.shape);
1932        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1933            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1934        }
1935
1936        let telemetry = provider.telemetry_snapshot();
1937        assert_eq!(telemetry.linsolve.count, 1);
1938        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1939    }
1940
1941    #[cfg(feature = "wgpu")]
1942    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1943    #[test]
1944    fn wgpu_transposed_square_linsolve_avoids_host_reupload_fallback() {
1945        let _accel_guard = test_support::accel_test_lock();
1946        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1947            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1948        );
1949        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1950        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1951            return;
1952        }
1953        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
1954        let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
1955
1956        let mut cpu_opts = StructValue::new();
1957        cpu_opts.fields.insert(
1958            "TRANSA".to_string(),
1959            Value::CharArray(CharArray::new_row("T")),
1960        );
1961        let cpu = linsolve_builtin(
1962            Value::Tensor(a.clone()),
1963            Value::Tensor(b.clone()),
1964            vec![Value::Struct(cpu_opts)],
1965        )
1966        .expect("cpu linsolve");
1967        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1968        provider.reset_telemetry();
1969
1970        let ha = provider
1971            .upload(&HostTensorView {
1972                data: &a.data,
1973                shape: &a.shape,
1974            })
1975            .expect("upload A");
1976        let hb = provider
1977            .upload(&HostTensorView {
1978                data: &b.data,
1979                shape: &b.shape,
1980            })
1981            .expect("upload B");
1982
1983        let _output_guard = crate::output_count::push_output_count(Some(1));
1984        let mut gpu_opts = StructValue::new();
1985        gpu_opts.fields.insert(
1986            "TRANSA".to_string(),
1987            Value::CharArray(CharArray::new_row("T")),
1988        );
1989        let gpu_value = linsolve_builtin(
1990            Value::GpuTensor(ha.clone()),
1991            Value::GpuTensor(hb.clone()),
1992            vec![Value::Struct(gpu_opts)],
1993        )
1994        .expect("gpu transposed square linsolve");
1995        let gpu_solution = match gpu_value {
1996            Value::OutputList(mut outputs) => outputs.remove(0),
1997            other => other,
1998        };
1999        let gathered = test_support::gather(gpu_solution).expect("gather");
2000        let _ = provider.free(&ha);
2001        let _ = provider.free(&hb);
2002
2003        assert_eq!(gathered.shape, cpu_tensor.shape);
2004        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2005            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2006        }
2007
2008        let telemetry = provider.telemetry_snapshot();
2009        assert_eq!(telemetry.linsolve.count, 1);
2010        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2011    }
2012
2013    #[cfg(feature = "wgpu")]
2014    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2015    #[test]
2016    fn wgpu_conjugate_square_linsolve_avoids_host_reupload_fallback_for_real_inputs() {
2017        let _accel_guard = test_support::accel_test_lock();
2018        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2019            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2020        );
2021        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2022        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2023            return;
2024        }
2025
2026        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2027        let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
2028        let mut cpu_opts = StructValue::new();
2029        cpu_opts.fields.insert(
2030            "TRANSA".to_string(),
2031            Value::CharArray(CharArray::new_row("C")),
2032        );
2033        let cpu = linsolve_builtin(
2034            Value::Tensor(a.clone()),
2035            Value::Tensor(b.clone()),
2036            vec![Value::Struct(cpu_opts)],
2037        )
2038        .expect("cpu linsolve");
2039        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2040        provider.reset_telemetry();
2041
2042        let ha = provider
2043            .upload(&HostTensorView {
2044                data: &a.data,
2045                shape: &a.shape,
2046            })
2047            .expect("upload A");
2048        let hb = provider
2049            .upload(&HostTensorView {
2050                data: &b.data,
2051                shape: &b.shape,
2052            })
2053            .expect("upload B");
2054
2055        let _output_guard = crate::output_count::push_output_count(Some(1));
2056        let mut gpu_opts = StructValue::new();
2057        gpu_opts.fields.insert(
2058            "TRANSA".to_string(),
2059            Value::CharArray(CharArray::new_row("C")),
2060        );
2061        let gpu_value = linsolve_builtin(
2062            Value::GpuTensor(ha.clone()),
2063            Value::GpuTensor(hb.clone()),
2064            vec![Value::Struct(gpu_opts)],
2065        )
2066        .expect("gpu conjugate square linsolve");
2067        let gpu_solution = match gpu_value {
2068            Value::OutputList(mut outputs) => outputs.remove(0),
2069            other => other,
2070        };
2071        let gathered = test_support::gather(gpu_solution).expect("gather");
2072        let _ = provider.free(&ha);
2073        let _ = provider.free(&hb);
2074
2075        assert_eq!(gathered.shape, cpu_tensor.shape);
2076        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2077            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2078        }
2079
2080        let telemetry = provider.telemetry_snapshot();
2081        assert_eq!(telemetry.linsolve.count, 1);
2082        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2083        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
2084    }
2085
2086    #[cfg(feature = "wgpu")]
2087    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2088    #[test]
2089    fn wgpu_transposed_rectangular_linsolve_avoids_host_reupload_fallback() {
2090        let _accel_guard = test_support::accel_test_lock();
2091        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2092            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2093        );
2094        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2095        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2096            return;
2097        }
2098        let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![2, 3]).unwrap();
2099        let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();
2100
2101        let mut cpu_opts = StructValue::new();
2102        cpu_opts.fields.insert(
2103            "TRANSA".to_string(),
2104            Value::CharArray(CharArray::new_row("T")),
2105        );
2106        cpu_opts
2107            .fields
2108            .insert("RECT".to_string(), Value::Bool(true));
2109        let cpu = linsolve_builtin(
2110            Value::Tensor(a.clone()),
2111            Value::Tensor(b.clone()),
2112            vec![Value::Struct(cpu_opts)],
2113        )
2114        .expect("cpu linsolve");
2115        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2116        provider.reset_telemetry();
2117
2118        let ha = provider
2119            .upload(&HostTensorView {
2120                data: &a.data,
2121                shape: &a.shape,
2122            })
2123            .expect("upload A");
2124        let hb = provider
2125            .upload(&HostTensorView {
2126                data: &b.data,
2127                shape: &b.shape,
2128            })
2129            .expect("upload B");
2130
2131        let _output_guard = crate::output_count::push_output_count(Some(1));
2132        let mut gpu_opts = StructValue::new();
2133        gpu_opts.fields.insert(
2134            "TRANSA".to_string(),
2135            Value::CharArray(CharArray::new_row("T")),
2136        );
2137        gpu_opts
2138            .fields
2139            .insert("RECT".to_string(), Value::Bool(true));
2140        let gpu_value = linsolve_builtin(
2141            Value::GpuTensor(ha.clone()),
2142            Value::GpuTensor(hb.clone()),
2143            vec![Value::Struct(gpu_opts)],
2144        )
2145        .expect("gpu transposed rectangular linsolve");
2146        let gpu_solution = match gpu_value {
2147            Value::OutputList(mut outputs) => outputs.remove(0),
2148            other => other,
2149        };
2150        let gathered = test_support::gather(gpu_solution).expect("gather");
2151        let _ = provider.free(&ha);
2152        let _ = provider.free(&hb);
2153
2154        assert_eq!(gathered.shape, cpu_tensor.shape);
2155        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2156            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2157        }
2158
2159        let telemetry = provider.telemetry_snapshot();
2160        assert_eq!(telemetry.linsolve.count, 1);
2161        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2162    }
2163
2164    #[cfg(feature = "wgpu")]
2165    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2166    #[test]
2167    fn wgpu_triangular_hint_avoids_host_reupload_fallback() {
2168        let _accel_guard = test_support::accel_test_lock();
2169        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2170            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2171        );
2172        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2173        let a = Tensor::new(
2174            vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
2175            vec![3, 3],
2176        )
2177        .unwrap();
2178        let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
2179
2180        let cpu = linsolve_builtin(Value::Tensor(a.clone()), Value::Tensor(b.clone()), {
2181            let mut opts = StructValue::new();
2182            opts.fields.insert("LT".to_string(), Value::Bool(true));
2183            vec![Value::Struct(opts)]
2184        })
2185        .expect("cpu linsolve");
2186        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2187        provider.reset_telemetry();
2188
2189        let ha = provider
2190            .upload(&HostTensorView {
2191                data: &a.data,
2192                shape: &a.shape,
2193            })
2194            .expect("upload A");
2195        let hb = provider
2196            .upload(&HostTensorView {
2197                data: &b.data,
2198                shape: &b.shape,
2199            })
2200            .expect("upload B");
2201
2202        let _output_guard = crate::output_count::push_output_count(Some(1));
2203        let mut opts = StructValue::new();
2204        opts.fields.insert("LT".to_string(), Value::Bool(true));
2205        let gpu_value = linsolve_builtin(
2206            Value::GpuTensor(ha.clone()),
2207            Value::GpuTensor(hb.clone()),
2208            vec![Value::Struct(opts)],
2209        )
2210        .expect("gpu triangular linsolve");
2211        let gpu_solution = match gpu_value {
2212            Value::OutputList(mut outputs) => outputs.remove(0),
2213            other => other,
2214        };
2215        let gathered = test_support::gather(gpu_solution).expect("gather");
2216        let _ = provider.free(&ha);
2217        let _ = provider.free(&hb);
2218
2219        assert_eq!(gathered.shape, cpu_tensor.shape);
2220        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2221            assert!((gpu - cpu).abs() < 1e-5);
2222        }
2223
2224        let telemetry = provider.telemetry_snapshot();
2225        assert_eq!(telemetry.linsolve.count, 1);
2226        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2227    }
2228
2229    #[cfg(feature = "wgpu")]
2230    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2231    #[test]
2232    fn wgpu_transposed_triangular_hint_avoids_host_reupload_fallback() {
2233        let _accel_guard = test_support::accel_test_lock();
2234        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2235            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2236        );
2237        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2238        let a = Tensor::new(
2239            vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
2240            vec![3, 3],
2241        )
2242        .unwrap();
2243        let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
2244
2245        let mut cpu_opts = StructValue::new();
2246        cpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
2247        cpu_opts.fields.insert(
2248            "TRANSA".to_string(),
2249            Value::CharArray(CharArray::new_row("T")),
2250        );
2251        let cpu = linsolve_builtin(
2252            Value::Tensor(a.clone()),
2253            Value::Tensor(b.clone()),
2254            vec![Value::Struct(cpu_opts)],
2255        )
2256        .expect("cpu linsolve");
2257        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2258        provider.reset_telemetry();
2259
2260        let ha = provider
2261            .upload(&HostTensorView {
2262                data: &a.data,
2263                shape: &a.shape,
2264            })
2265            .expect("upload A");
2266        let hb = provider
2267            .upload(&HostTensorView {
2268                data: &b.data,
2269                shape: &b.shape,
2270            })
2271            .expect("upload B");
2272
2273        let _output_guard = crate::output_count::push_output_count(Some(1));
2274        let mut gpu_opts = StructValue::new();
2275        gpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
2276        gpu_opts.fields.insert(
2277            "TRANSA".to_string(),
2278            Value::CharArray(CharArray::new_row("T")),
2279        );
2280        let gpu_value = linsolve_builtin(
2281            Value::GpuTensor(ha.clone()),
2282            Value::GpuTensor(hb.clone()),
2283            vec![Value::Struct(gpu_opts)],
2284        )
2285        .expect("gpu transposed triangular linsolve");
2286        let gpu_solution = match gpu_value {
2287            Value::OutputList(mut outputs) => outputs.remove(0),
2288            other => other,
2289        };
2290        let gathered = test_support::gather(gpu_solution).expect("gather");
2291        let _ = provider.free(&ha);
2292        let _ = provider.free(&hb);
2293
2294        assert_eq!(gathered.shape, cpu_tensor.shape);
2295        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2296            assert!((gpu - cpu).abs() < 1e-5);
2297        }
2298
2299        let telemetry = provider.telemetry_snapshot();
2300        assert_eq!(telemetry.linsolve.count, 1);
2301        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2302    }
2303
2304    #[cfg(feature = "wgpu")]
2305    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2306    #[test]
2307    fn wgpu_round_trip_matches_cpu() {
2308        let _accel_guard = test_support::accel_test_lock();
2309        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2310            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2311        );
2312        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2313        let tol = match provider.precision() {
2314            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
2315            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
2316        };
2317
2318        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2319        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
2320
2321        let cpu = linsolve_builtin(
2322            Value::Tensor(a.clone()),
2323            Value::Tensor(b.clone()),
2324            Vec::new(),
2325        )
2326        .expect("cpu linsolve");
2327        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2328
2329        let view_a = HostTensorView {
2330            data: &a.data,
2331            shape: &a.shape,
2332        };
2333        let view_b = HostTensorView {
2334            data: &b.data,
2335            shape: &b.shape,
2336        };
2337        let ha = provider.upload(&view_a).expect("upload A");
2338        let hb = provider.upload(&view_b).expect("upload B");
2339        let gpu_value = linsolve_builtin(
2340            Value::GpuTensor(ha.clone()),
2341            Value::GpuTensor(hb.clone()),
2342            Vec::new(),
2343        )
2344        .expect("gpu linsolve");
2345        let gathered = test_support::gather(gpu_value).expect("gather");
2346        let _ = provider.free(&ha);
2347        let _ = provider.free(&hb);
2348
2349        assert_eq!(gathered.shape, cpu_tensor.shape);
2350        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2351            assert!((gpu - cpu).abs() < tol);
2352        }
2353    }
2354}