Skip to main content

runmat_runtime/builtins/math/linalg/solve/
linsolve.rs

1//! MATLAB-compatible `linsolve` builtin with structural hints and GPU-aware fallbacks.
2
3use nalgebra::{linalg::SVD, DMatrix};
4use num_complex::Complex64;
5use runmat_accelerate_api::{
6    AccelProvider, GpuTensorHandle, HostTensorView, ProviderLinsolveOptions, ProviderLinsolveResult,
7};
8use runmat_builtins::{
9    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11    ComplexTensor, Tensor, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use crate::builtins::common::spec::{
16    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::{
20    gpu_helpers,
21    linalg::{diagonal_rcond, singular_value_rcond},
22    tensor,
23};
24use crate::builtins::math::linalg::type_resolvers::left_divide_type;
25use crate::{build_runtime_error, BuiltinResult, RuntimeError};
26
27const NAME: &str = "linsolve";
28
29const LINSOLVE_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
30    name: "X",
31    ty: BuiltinParamType::NumericArray,
32    arity: BuiltinParamArity::Required,
33    default: None,
34    description: "Solution to A * X = B.",
35}];
36
37const LINSOLVE_OUTPUT_XR: [BuiltinParamDescriptor; 2] = [
38    BuiltinParamDescriptor {
39        name: "X",
40        ty: BuiltinParamType::NumericArray,
41        arity: BuiltinParamArity::Required,
42        default: None,
43        description: "Solution to A * X = B.",
44    },
45    BuiltinParamDescriptor {
46        name: "R",
47        ty: BuiltinParamType::NumericScalar,
48        arity: BuiltinParamArity::Required,
49        default: None,
50        description: "Reciprocal condition estimate.",
51    },
52];
53
54const LINSOLVE_INPUTS_AB: [BuiltinParamDescriptor; 2] = [
55    BuiltinParamDescriptor {
56        name: "A",
57        ty: BuiltinParamType::Any,
58        arity: BuiltinParamArity::Required,
59        default: None,
60        description: "Coefficient matrix.",
61    },
62    BuiltinParamDescriptor {
63        name: "B",
64        ty: BuiltinParamType::Any,
65        arity: BuiltinParamArity::Required,
66        default: None,
67        description: "Right-hand side matrix or vector.",
68    },
69];
70
71const LINSOLVE_INPUTS_AB_OPTS: [BuiltinParamDescriptor; 3] = [
72    BuiltinParamDescriptor {
73        name: "A",
74        ty: BuiltinParamType::Any,
75        arity: BuiltinParamArity::Required,
76        default: None,
77        description: "Coefficient matrix.",
78    },
79    BuiltinParamDescriptor {
80        name: "B",
81        ty: BuiltinParamType::Any,
82        arity: BuiltinParamArity::Required,
83        default: None,
84        description: "Right-hand side matrix or vector.",
85    },
86    BuiltinParamDescriptor {
87        name: "opts",
88        ty: BuiltinParamType::Any,
89        arity: BuiltinParamArity::Optional,
90        default: None,
91        description: "Structural options (LT, UT, RECT, SYM, POSDEF, TRANSA, RCOND).",
92    },
93];
94
95const LINSOLVE_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
96    BuiltinSignatureDescriptor {
97        label: "X = linsolve(A, B)",
98        inputs: &LINSOLVE_INPUTS_AB,
99        outputs: &LINSOLVE_OUTPUT_X,
100    },
101    BuiltinSignatureDescriptor {
102        label: "X = linsolve(A, B, opts)",
103        inputs: &LINSOLVE_INPUTS_AB_OPTS,
104        outputs: &LINSOLVE_OUTPUT_X,
105    },
106    BuiltinSignatureDescriptor {
107        label: "[X, R] = linsolve(A, B)",
108        inputs: &LINSOLVE_INPUTS_AB,
109        outputs: &LINSOLVE_OUTPUT_XR,
110    },
111    BuiltinSignatureDescriptor {
112        label: "[X, R] = linsolve(A, B, opts)",
113        inputs: &LINSOLVE_INPUTS_AB_OPTS,
114        outputs: &LINSOLVE_OUTPUT_XR,
115    },
116];
117
118const LINSOLVE_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
119    code: "RM.LINSOLVE.INVALID_ARGUMENT",
120    identifier: Some("RunMat:linsolve:InvalidArgument"),
121    when: "Options/output count/auxiliary arguments are malformed or unsupported.",
122    message: "linsolve: invalid argument",
123};
124
125const LINSOLVE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
126    code: "RM.LINSOLVE.INVALID_INPUT",
127    identifier: Some("RunMat:linsolve:InvalidInput"),
128    when: "Input shape/type cannot be solved under linsolve semantics.",
129    message: "linsolve: invalid input",
130};
131
132const LINSOLVE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
133    code: "RM.LINSOLVE.INTERNAL",
134    identifier: Some("RunMat:linsolve:Internal"),
135    when: "Runtime fails while solving or executing provider fallback paths.",
136    message: "linsolve: internal runtime failure",
137};
138
139const LINSOLVE_ERRORS: [BuiltinErrorDescriptor; 3] = [
140    LINSOLVE_ERROR_INVALID_ARGUMENT,
141    LINSOLVE_ERROR_INVALID_INPUT,
142    LINSOLVE_ERROR_INTERNAL,
143];
144
145pub const LINSOLVE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
146    signatures: &LINSOLVE_SIGNATURES,
147    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
148    completion_policy: BuiltinCompletionPolicy::Public,
149    errors: &LINSOLVE_ERRORS,
150};
151
/// GPU capability metadata for `linsolve`, registered through `register_gpu_spec`.
///
/// Acceleration goes through the custom `linsolve` provider hook (see `try_gpu_linsolve`).
/// `notes` records which variants the WGPU provider handles natively and when it falls back
/// to the host solver.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::solve::linsolve")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "linsolve",
    op_kind: GpuOpKind::Custom("solve"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("linsolve")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    // The solution is returned as a fresh device handle (`Value::GpuTensor(solution)`).
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider linsolve hook; WGPU currently supports triangular solves, real F32 TRANSA='T'/'C' variants, a dedicated real F32 POSDEF/Cholesky path, and selected real F32 QR-backed square and rectangular solves, otherwise it gathers to the host solver and re-uploads the result.",
};
167
168fn linsolve_error_with_message(
169    message: impl Into<String>,
170    error: &'static BuiltinErrorDescriptor,
171) -> RuntimeError {
172    let mut builder = build_runtime_error(message).with_builtin(NAME);
173    if let Some(identifier) = error.identifier {
174        builder = builder.with_identifier(identifier);
175    }
176    builder.build()
177}
178
179fn builtin_error(message: impl Into<String>) -> RuntimeError {
180    linsolve_error_with_message(message, &LINSOLVE_ERROR_INVALID_INPUT)
181}
182
183fn argument_error(message: impl Into<String>) -> RuntimeError {
184    linsolve_error_with_message(message, &LINSOLVE_ERROR_INVALID_ARGUMENT)
185}
186
187fn map_control_flow(err: RuntimeError) -> RuntimeError {
188    let mut builder = build_runtime_error(err.message()).with_builtin(NAME);
189    if let Some(identifier) = err.identifier() {
190        builder = builder.with_identifier(identifier.to_string());
191    }
192    if let Some(task_id) = err.context.task_id.clone() {
193        builder = builder.with_task_id(task_id);
194    }
195    if !err.context.call_stack.is_empty() {
196        builder = builder.with_call_stack(err.context.call_stack.clone());
197    }
198    if let Some(phase) = err.context.phase.clone() {
199        builder = builder.with_phase(phase);
200    }
201    builder.with_source(err).build()
202}
203
/// Fusion metadata for `linsolve`, registered through `register_fusion_spec`.
///
/// No elementwise or reduction template is supplied (`elementwise`/`reduction` are `None`),
/// so the solve is never fused with neighbouring kernels.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "linsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Linear solves are terminal operations and do not fuse with surrounding kernels.",
};
216
217#[runtime_builtin(
218    name = "linsolve",
219    category = "math/linalg/solve",
220    summary = "Solve A * X = B with structural hints such as LT, UT, POSDEF, or TRANSA.",
221    keywords = "linsolve,linear system,triangular,gpu",
222    accel = "linsolve",
223    type_resolver(left_divide_type),
224    descriptor(crate::builtins::math::linalg::solve::linsolve::LINSOLVE_DESCRIPTOR),
225    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
226)]
227async fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
228    let eval = evaluate_args(lhs, rhs, &rest).await?;
229    if let Some(out_count) = crate::output_count::current_output_count() {
230        if out_count == 0 {
231            return Ok(Value::OutputList(Vec::new()));
232        }
233        if out_count == 1 {
234            return Ok(Value::OutputList(vec![eval.solution()]));
235        }
236        if out_count == 2 {
237            return Ok(Value::OutputList(vec![
238                eval.solution(),
239                eval.reciprocal_condition(),
240            ]));
241        }
242        return Err(argument_error(
243            "linsolve currently supports at most two outputs",
244        ));
245    }
246    Ok(eval.solution())
247}
248
249/// Evaluate `linsolve`, returning both the solution and the estimated reciprocal condition number.
250pub async fn evaluate(
251    lhs: Value,
252    rhs: Value,
253    options: SolveOptions,
254) -> BuiltinResult<LinsolveEval> {
255    if let Some(eval) = try_gpu_linsolve(&lhs, &rhs, &options).await? {
256        return Ok(eval);
257    }
258
259    let lhs_host = crate::dispatcher::gather_if_needed_async(&lhs)
260        .await
261        .map_err(map_control_flow)?;
262    let rhs_host = crate::dispatcher::gather_if_needed_async(&rhs)
263        .await
264        .map_err(map_control_flow)?;
265    let pair = coerce_numeric_pair(lhs_host, rhs_host).await?;
266    match pair {
267        NumericPair::Real(lhs_r, rhs_r) => {
268            let (solution, rcond) = solve_real(lhs_r, rhs_r, &options)?;
269            Ok(LinsolveEval::new(
270                tensor::tensor_into_value(solution),
271                Some(rcond),
272            ))
273        }
274        NumericPair::Complex(lhs_c, rhs_c) => {
275            let (solution, rcond) = solve_complex(lhs_c, rhs_c, &options)?;
276            Ok(LinsolveEval::new(
277                Value::ComplexTensor(solution),
278                Some(rcond),
279            ))
280        }
281    }
282}
283
284/// Host implementation shared with acceleration providers that fall back to CPU execution.
285pub fn linsolve_host_real_for_provider(
286    lhs: &Tensor,
287    rhs: &Tensor,
288    options: &ProviderLinsolveOptions,
289) -> BuiltinResult<(Tensor, f64)> {
290    let opts = SolveOptions::from(options);
291    solve_real(lhs.clone(), rhs.clone(), &opts)
292}
293
294/// Result wrapper that exposes both primary and secondary outputs.
295#[derive(Clone)]
296pub struct LinsolveEval {
297    solution: Value,
298    rcond: Option<f64>,
299}
300
301impl LinsolveEval {
302    fn new(solution: Value, rcond: Option<f64>) -> Self {
303        Self { solution, rcond }
304    }
305
306    /// Primary solution output.
307    pub fn solution(&self) -> Value {
308        self.solution.clone()
309    }
310
311    /// Estimated reciprocal condition number (second output).
312    pub fn reciprocal_condition(&self) -> Value {
313        match self.rcond {
314            Some(r) => Value::Num(r),
315            None => Value::Num(f64::NAN),
316        }
317    }
318}
319
/// Structural hints and thresholds parsed from the `opts` argument.
///
/// `Default` leaves every flag `false` and `rcond` unset.
#[derive(Default, Clone)]
pub struct SolveOptions {
    /// `LT`: solve with forward substitution, treating `A` as lower triangular.
    lower: bool,
    /// `UT`: solve with backward substitution, treating `A` as upper triangular.
    upper: bool,
    /// `RECT` hint, forwarded to acceleration providers.
    rectangular: bool,
    /// `TRANSA`: solve with the transposed coefficient matrix.
    transposed: bool,
    /// `TRANSA = 'C'`: also conjugate when transposing.
    conjugate: bool,
    /// `SYM` hint, forwarded to acceleration providers.
    symmetric: bool,
    /// `POSDEF` hint, forwarded to acceleration providers.
    posdef: bool,
    /// `RCOND`: non-negative threshold passed to the solver's rcond check.
    rcond: Option<f64>,
}
331
332impl From<&SolveOptions> for ProviderLinsolveOptions {
333    fn from(opts: &SolveOptions) -> Self {
334        Self {
335            lower: opts.lower,
336            upper: opts.upper,
337            rectangular: opts.rectangular,
338            transposed: opts.transposed,
339            conjugate: opts.conjugate,
340            symmetric: opts.symmetric,
341            posdef: opts.posdef,
342            need_rcond: false,
343            rcond: opts.rcond,
344        }
345    }
346}
347
348impl From<&ProviderLinsolveOptions> for SolveOptions {
349    fn from(opts: &ProviderLinsolveOptions) -> Self {
350        Self {
351            lower: opts.lower,
352            upper: opts.upper,
353            rectangular: opts.rectangular,
354            transposed: opts.transposed,
355            conjugate: opts.conjugate,
356            symmetric: opts.symmetric,
357            posdef: opts.posdef,
358            rcond: opts.rcond,
359        }
360    }
361}
362
363fn options_from_rest(rest: &[Value]) -> BuiltinResult<SolveOptions> {
364    match rest.len() {
365        0 => Ok(SolveOptions::default()),
366        1 => parse_options(&rest[0]),
367        _ => Err(argument_error("linsolve: too many input arguments")),
368    }
369}
370
371/// Public helper for the VM multi-output surface.
372pub async fn evaluate_args(lhs: Value, rhs: Value, rest: &[Value]) -> BuiltinResult<LinsolveEval> {
373    let options = options_from_rest(rest)?;
374    evaluate(lhs, rhs, options).await
375}
376
377async fn try_gpu_linsolve(
378    lhs: &Value,
379    rhs: &Value,
380    options: &SolveOptions,
381) -> BuiltinResult<Option<LinsolveEval>> {
382    if matches!(crate::output_count::current_output_count(), Some(n) if n > 2) {
383        return Ok(None);
384    }
385    let provider = match runmat_accelerate_api::provider() {
386        Some(p) => p,
387        None => return Ok(None),
388    };
389
390    if contains_complex(lhs) || contains_complex(rhs) {
391        return Ok(None);
392    }
393
394    let mut lhs_operand = match prepare_gpu_operand(lhs, provider)? {
395        Some(op) => op,
396        None => return Ok(None),
397    };
398    let mut rhs_operand = match prepare_gpu_operand(rhs, provider)? {
399        Some(op) => op,
400        None => {
401            release_operand(provider, &mut lhs_operand);
402            return Ok(None);
403        }
404    };
405
406    if is_scalar_handle(lhs_operand.handle()) || is_scalar_handle(rhs_operand.handle()) {
407        release_operand(provider, &mut lhs_operand);
408        release_operand(provider, &mut rhs_operand);
409        return Ok(None);
410    }
411
412    let mut provider_opts: ProviderLinsolveOptions = options.into();
413    provider_opts.need_rcond =
414        matches!(crate::output_count::current_output_count(), Some(2)) || options.rcond.is_some();
415    let result = provider
416        .linsolve(lhs_operand.handle(), rhs_operand.handle(), &provider_opts)
417        .await
418        .ok();
419
420    release_operand(provider, &mut lhs_operand);
421    release_operand(provider, &mut rhs_operand);
422
423    if let Some(ProviderLinsolveResult {
424        solution,
425        reciprocal_condition,
426    }) = result
427    {
428        let eval = LinsolveEval::new(Value::GpuTensor(solution), Some(reciprocal_condition));
429        return Ok(Some(eval));
430    }
431
432    Ok(None)
433}
434
435fn parse_options(value: &Value) -> BuiltinResult<SolveOptions> {
436    let struct_val = match value {
437        Value::Struct(s) => s,
438        other => {
439            return Err(argument_error(format!(
440                "linsolve: opts must be a struct, got {other:?}"
441            )))
442        }
443    };
444    let mut opts = SolveOptions::default();
445    for (key, raw_value) in &struct_val.fields {
446        let name = key.to_ascii_uppercase();
447        match name.as_str() {
448            "LT" => opts.lower = parse_bool_field("LT", raw_value)?,
449            "UT" => opts.upper = parse_bool_field("UT", raw_value)?,
450            "RECT" => opts.rectangular = parse_bool_field("RECT", raw_value)?,
451            "SYM" => opts.symmetric = parse_bool_field("SYM", raw_value)?,
452            "POSDEF" => opts.posdef = parse_bool_field("POSDEF", raw_value)?,
453            "TRANSA" => {
454                let transa = parse_transa(raw_value)?;
455                opts.transposed = transa != TransposeMode::None;
456                opts.conjugate = transa == TransposeMode::Conjugate;
457            }
458            "RCOND" => {
459                let threshold = parse_scalar_f64("RCOND", raw_value)?;
460                if threshold < 0.0 {
461                    return Err(argument_error("linsolve: RCOND must be non-negative"));
462                }
463                opts.rcond = Some(threshold);
464            }
465            other => {
466                return Err(argument_error(format!(
467                    "linsolve: unknown option '{other}'"
468                )))
469            }
470        }
471    }
472    if opts.lower && opts.upper {
473        return Err(argument_error(
474            "linsolve: LT and UT are mutually exclusive.",
475        ));
476    }
477    Ok(opts)
478}
479
480fn parse_bool_field(name: &str, value: &Value) -> BuiltinResult<bool> {
481    match value {
482        Value::Bool(b) => Ok(*b),
483        Value::Int(i) => Ok(!i.is_zero()),
484        Value::Num(n) => Ok(*n != 0.0),
485        Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0] != 0.0),
486        Value::LogicalArray(arr) if arr.len() == 1 => Ok(arr.data[0] != 0),
487        other => Err(argument_error(format!(
488            "linsolve: option '{name}' must be logical or numeric, got {other:?}"
489        ))),
490    }
491}
492
493fn parse_scalar_f64(name: &str, value: &Value) -> BuiltinResult<f64> {
494    match value {
495        Value::Num(n) => Ok(*n),
496        Value::Int(i) => Ok(i.to_f64()),
497        Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0]),
498        other => Err(argument_error(format!(
499            "linsolve: option '{name}' must be a scalar numeric value, got {other:?}"
500        ))),
501    }
502}
503
/// How the coefficient matrix is applied, as selected by the `TRANSA` option.
#[derive(Clone, Copy, Eq, PartialEq)]
enum TransposeMode {
    /// Solve `A * X = B`.
    None,
    /// Solve `A.' * X = B`.
    Transpose,
    /// Solve `A' * X = B` (conjugate transpose).
    Conjugate,
}
510
511fn parse_transa(value: &Value) -> BuiltinResult<TransposeMode> {
512    let text = tensor::value_to_string(value).ok_or_else(|| {
513        argument_error("linsolve: TRANSA must be a character vector or string scalar")
514    })?;
515    if text.is_empty() {
516        return Err(argument_error("linsolve: TRANSA cannot be empty"));
517    }
518    match text.trim().to_ascii_uppercase().as_str() {
519        "N" => Ok(TransposeMode::None),
520        "T" => Ok(TransposeMode::Transpose),
521        "C" => Ok(TransposeMode::Conjugate),
522        other => Err(argument_error(format!(
523            "linsolve: TRANSA must be 'N', 'T', or 'C', got '{other}'"
524        ))),
525    }
526}
527
528enum NumericInput {
529    Real(Tensor),
530    Complex(ComplexTensor),
531}
532
533enum NumericPair {
534    Real(Tensor, Tensor),
535    Complex(ComplexTensor, ComplexTensor),
536}
537
538async fn coerce_numeric_pair(lhs: Value, rhs: Value) -> BuiltinResult<NumericPair> {
539    let lhs_num = coerce_numeric(lhs).await?;
540    let rhs_num = coerce_numeric(rhs).await?;
541    match (lhs_num, rhs_num) {
542        (NumericInput::Real(lhs_r), NumericInput::Real(rhs_r)) => {
543            Ok(NumericPair::Real(lhs_r, rhs_r))
544        }
545        (NumericInput::Complex(lhs_c), NumericInput::Complex(rhs_c)) => {
546            Ok(NumericPair::Complex(lhs_c, rhs_c))
547        }
548        (NumericInput::Complex(lhs_c), NumericInput::Real(rhs_r)) => {
549            let rhs_c = promote_real_tensor(&rhs_r)?;
550            Ok(NumericPair::Complex(lhs_c, rhs_c))
551        }
552        (NumericInput::Real(lhs_r), NumericInput::Complex(rhs_c)) => {
553            let lhs_c = promote_real_tensor(&lhs_r)?;
554            Ok(NumericPair::Complex(lhs_c, rhs_c))
555        }
556    }
557}
558
559async fn coerce_numeric(value: Value) -> BuiltinResult<NumericInput> {
560    match value {
561        Value::Tensor(tensor) => {
562            ensure_matrix_shape(NAME, &tensor.shape)?;
563            Ok(NumericInput::Real(tensor))
564        }
565        Value::LogicalArray(logical) => {
566            let tensor = tensor::logical_to_tensor(&logical).map_err(builtin_error)?;
567            ensure_matrix_shape(NAME, &tensor.shape)?;
568            Ok(NumericInput::Real(tensor))
569        }
570        Value::Num(n) => {
571            let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(builtin_error)?;
572            Ok(NumericInput::Real(tensor))
573        }
574        Value::Int(i) => {
575            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(builtin_error)?;
576            Ok(NumericInput::Real(tensor))
577        }
578        Value::Bool(b) => {
579            let tensor =
580                Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1]).map_err(builtin_error)?;
581            Ok(NumericInput::Real(tensor))
582        }
583        Value::Complex(re, im) => {
584            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(builtin_error)?;
585            Ok(NumericInput::Complex(tensor))
586        }
587        Value::ComplexTensor(ct) => {
588            ensure_matrix_shape(NAME, &ct.shape)?;
589            Ok(NumericInput::Complex(ct))
590        }
591        Value::GpuTensor(handle) => {
592            let tensor = gpu_helpers::gather_tensor_async(&handle)
593                .await
594                .map_err(map_control_flow)?;
595            ensure_matrix_shape(NAME, &tensor.shape)?;
596            Ok(NumericInput::Real(tensor))
597        }
598        other => Err(builtin_error(format!(
599            "{NAME}: unsupported input type {:?}; convert to numeric values first",
600            other
601        ))),
602    }
603}
604
605fn contains_complex(value: &Value) -> bool {
606    matches!(value, Value::Complex(_, _) | Value::ComplexTensor(_))
607}
608
609fn is_scalar_handle(handle: &GpuTensorHandle) -> bool {
610    crate::builtins::common::shape::is_scalar_shape(&handle.shape)
611}
612
613struct PreparedOperand {
614    handle: GpuTensorHandle,
615    owned: bool,
616}
617
618impl PreparedOperand {
619    fn borrowed(handle: &GpuTensorHandle) -> Self {
620        Self {
621            handle: handle.clone(),
622            owned: false,
623        }
624    }
625
626    fn owned(handle: GpuTensorHandle) -> Self {
627        Self {
628            handle,
629            owned: true,
630        }
631    }
632
633    fn handle(&self) -> &GpuTensorHandle {
634        &self.handle
635    }
636}
637
638fn prepare_gpu_operand(
639    value: &Value,
640    provider: &'static dyn AccelProvider,
641) -> BuiltinResult<Option<PreparedOperand>> {
642    match value {
643        Value::GpuTensor(handle) => {
644            if is_scalar_handle(handle) {
645                Ok(None)
646            } else {
647                Ok(Some(PreparedOperand::borrowed(handle)))
648            }
649        }
650        Value::Tensor(tensor) => {
651            if tensor::is_scalar_tensor(tensor) {
652                Ok(None)
653            } else {
654                let uploaded = upload_tensor(provider, tensor)?;
655                Ok(Some(PreparedOperand::owned(uploaded)))
656            }
657        }
658        Value::LogicalArray(logical) => {
659            if logical.data.len() == 1 {
660                Ok(None)
661            } else {
662                let tensor = tensor::logical_to_tensor(logical).map_err(builtin_error)?;
663                let uploaded = upload_tensor(provider, &tensor)?;
664                Ok(Some(PreparedOperand::owned(uploaded)))
665            }
666        }
667        _ => Ok(None),
668    }
669}
670
671fn upload_tensor(
672    provider: &'static dyn AccelProvider,
673    tensor: &Tensor,
674) -> BuiltinResult<GpuTensorHandle> {
675    let view = HostTensorView {
676        data: &tensor.data,
677        shape: &tensor.shape,
678    };
679    provider
680        .upload(&view)
681        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
682}
683
684fn release_operand(provider: &'static dyn AccelProvider, operand: &mut PreparedOperand) {
685    if operand.owned {
686        let _ = provider.free(&operand.handle);
687        operand.owned = false;
688    }
689}
690
691fn solve_real(lhs: Tensor, rhs: Tensor, options: &SolveOptions) -> BuiltinResult<(Tensor, f64)> {
692    let mut lhs_effective = lhs;
693    let mut rhs_effective = rhs;
694    let mut lower = options.lower;
695    let mut upper = options.upper;
696
697    if options.transposed {
698        lhs_effective = transpose_tensor(&lhs_effective);
699        if options.conjugate {
700            conjugate_in_place(&mut lhs_effective);
701        }
702        if lower || upper {
703            std::mem::swap(&mut lower, &mut upper);
704        }
705    }
706
707    rhs_effective = normalize_rhs_tensor(rhs_effective, lhs_effective.rows())?;
708
709    if lower {
710        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
711        let (solution, rcond) = forward_substitution_real(&lhs_effective, &rhs_effective)?;
712        enforce_rcond(options, rcond)?;
713        return Ok((solution, rcond));
714    }
715
716    if upper {
717        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
718        let (solution, rcond) = backward_substitution_real(&lhs_effective, &rhs_effective)?;
719        enforce_rcond(options, rcond)?;
720        return Ok((solution, rcond));
721    }
722
723    let (solution, rcond) = solve_general_real(&lhs_effective, &rhs_effective)?;
724    enforce_rcond(options, rcond)?;
725    Ok((solution, rcond))
726}
727
728fn solve_complex(
729    lhs: ComplexTensor,
730    rhs: ComplexTensor,
731    options: &SolveOptions,
732) -> BuiltinResult<(ComplexTensor, f64)> {
733    let mut lhs_effective = lhs;
734    let mut rhs_effective = rhs;
735    let mut lower = options.lower;
736    let mut upper = options.upper;
737
738    if options.transposed {
739        lhs_effective = transpose_complex(&lhs_effective);
740        if options.conjugate {
741            conjugate_complex_in_place(&mut lhs_effective);
742        }
743        if lower || upper {
744            std::mem::swap(&mut lower, &mut upper);
745        }
746    }
747
748    rhs_effective = normalize_rhs_complex(rhs_effective, lhs_effective.rows)?;
749
750    if lower {
751        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
752        let (solution, rcond) = forward_substitution_complex(&lhs_effective, &rhs_effective)?;
753        enforce_rcond(options, rcond)?;
754        return Ok((solution, rcond));
755    }
756
757    if upper {
758        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
759        let (solution, rcond) = backward_substitution_complex(&lhs_effective, &rhs_effective)?;
760        enforce_rcond(options, rcond)?;
761        return Ok((solution, rcond));
762    }
763
764    let (solution, rcond) = solve_general_complex(&lhs_effective, &rhs_effective)?;
765    enforce_rcond(options, rcond)?;
766    Ok((solution, rcond))
767}
768
769fn forward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
770    let n = lhs.rows();
771    let nrhs = rhs.data.len() / n;
772    let mut solution = rhs.data.clone();
773    let mut min_diag = f64::INFINITY;
774    let mut max_diag = 0.0_f64;
775
776    for col in 0..nrhs {
777        for i in 0..n {
778            let diag = lhs.data[i + i * n];
779            let diag_abs = diag.abs();
780            min_diag = min_diag.min(diag_abs);
781            max_diag = max_diag.max(diag_abs);
782            if diag_abs == 0.0 {
783                return Err(builtin_error(
784                    "linsolve: matrix is singular to working precision.",
785                ));
786            }
787            let mut accum = 0.0;
788            for j in 0..i {
789                accum += lhs.data[i + j * n] * solution[j + col * n];
790            }
791            let rhs_value = solution[i + col * n] - accum;
792            solution[i + col * n] = rhs_value / diag;
793        }
794    }
795
796    let rcond = diagonal_rcond(min_diag, max_diag);
797    let tensor = Tensor::new(solution, rhs.shape.clone())
798        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
799    Ok((tensor, rcond))
800}
801
802fn backward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
803    let n = lhs.rows();
804    let nrhs = rhs.data.len() / n;
805    let mut solution = rhs.data.clone();
806    let mut min_diag = f64::INFINITY;
807    let mut max_diag = 0.0_f64;
808
809    for col in 0..nrhs {
810        for row_rev in 0..n {
811            let i = n - 1 - row_rev;
812            let diag = lhs.data[i + i * n];
813            let diag_abs = diag.abs();
814            min_diag = min_diag.min(diag_abs);
815            max_diag = max_diag.max(diag_abs);
816            if diag_abs == 0.0 {
817                return Err(builtin_error(
818                    "linsolve: matrix is singular to working precision.",
819                ));
820            }
821            let mut accum = 0.0;
822            for j in (i + 1)..n {
823                accum += lhs.data[i + j * n] * solution[j + col * n];
824            }
825            let rhs_value = solution[i + col * n] - accum;
826            solution[i + col * n] = rhs_value / diag;
827        }
828    }
829
830    let rcond = diagonal_rcond(min_diag, max_diag);
831    let tensor = Tensor::new(solution, rhs.shape.clone())
832        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
833    Ok((tensor, rcond))
834}
835
836fn forward_substitution_complex(
837    lhs: &ComplexTensor,
838    rhs: &ComplexTensor,
839) -> BuiltinResult<(ComplexTensor, f64)> {
840    let n = lhs.rows;
841    let nrhs = rhs.data.len() / n;
842    let lhs_data: Vec<Complex64> = lhs
843        .data
844        .iter()
845        .map(|&(re, im)| Complex64::new(re, im))
846        .collect();
847    let mut solution: Vec<Complex64> = rhs
848        .data
849        .iter()
850        .map(|&(re, im)| Complex64::new(re, im))
851        .collect();
852    let mut min_diag = f64::INFINITY;
853    let mut max_diag = 0.0_f64;
854
855    for col in 0..nrhs {
856        for i in 0..n {
857            let diag = lhs_data[i + i * n];
858            let diag_abs = diag.norm();
859            min_diag = min_diag.min(diag_abs);
860            max_diag = max_diag.max(diag_abs);
861            if diag_abs == 0.0 {
862                return Err(builtin_error(
863                    "linsolve: matrix is singular to working precision.",
864                ));
865            }
866            let mut accum = Complex64::new(0.0, 0.0);
867            for j in 0..i {
868                accum += lhs_data[i + j * n] * solution[j + col * n];
869            }
870            let rhs_value = solution[i + col * n] - accum;
871            solution[i + col * n] = rhs_value / diag;
872        }
873    }
874
875    let rcond = diagonal_rcond(min_diag, max_diag);
876    let tensor = ComplexTensor::new(
877        solution.iter().map(|c| (c.re, c.im)).collect(),
878        rhs.shape.clone(),
879    )
880    .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
881    Ok((tensor, rcond))
882}
883
884fn backward_substitution_complex(
885    lhs: &ComplexTensor,
886    rhs: &ComplexTensor,
887) -> BuiltinResult<(ComplexTensor, f64)> {
888    let n = lhs.rows;
889    let nrhs = rhs.data.len() / n;
890    let lhs_data: Vec<Complex64> = lhs
891        .data
892        .iter()
893        .map(|&(re, im)| Complex64::new(re, im))
894        .collect();
895    let mut solution: Vec<Complex64> = rhs
896        .data
897        .iter()
898        .map(|&(re, im)| Complex64::new(re, im))
899        .collect();
900    let mut min_diag = f64::INFINITY;
901    let mut max_diag = 0.0_f64;
902
903    for col in 0..nrhs {
904        for row_rev in 0..n {
905            let i = n - 1 - row_rev;
906            let diag = lhs_data[i + i * n];
907            let diag_abs = diag.norm();
908            min_diag = min_diag.min(diag_abs);
909            max_diag = max_diag.max(diag_abs);
910            if diag_abs == 0.0 {
911                return Err(builtin_error(
912                    "linsolve: matrix is singular to working precision.",
913                ));
914            }
915            let mut accum = Complex64::new(0.0, 0.0);
916            for j in (i + 1)..n {
917                accum += lhs_data[i + j * n] * solution[j + col * n];
918            }
919            let rhs_value = solution[i + col * n] - accum;
920            solution[i + col * n] = rhs_value / diag;
921        }
922    }
923
924    let rcond = diagonal_rcond(min_diag, max_diag);
925    let tensor = ComplexTensor::new(
926        solution.iter().map(|c| (c.re, c.im)).collect(),
927        rhs.shape.clone(),
928    )
929    .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
930    Ok((tensor, rcond))
931}
932
933fn solve_general_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
934    let a = DMatrix::from_column_slice(lhs.rows(), lhs.cols(), &lhs.data);
935    let b = DMatrix::from_column_slice(rhs.rows(), rhs.cols(), &rhs.data);
936    let svd = SVD::new(a.clone(), true, true);
937    let rcond = singular_value_rcond(svd.singular_values.as_slice());
938    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows(), lhs.cols());
939    let solution = svd
940        .solve(&b, tol)
941        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
942    let tensor = matrix_real_to_tensor(solution)?;
943    Ok((tensor, rcond))
944}
945
946fn solve_general_complex(
947    lhs: &ComplexTensor,
948    rhs: &ComplexTensor,
949) -> BuiltinResult<(ComplexTensor, f64)> {
950    let a_data: Vec<Complex64> = lhs
951        .data
952        .iter()
953        .map(|&(re, im)| Complex64::new(re, im))
954        .collect();
955    let b_data: Vec<Complex64> = rhs
956        .data
957        .iter()
958        .map(|&(re, im)| Complex64::new(re, im))
959        .collect();
960    let a = DMatrix::from_column_slice(lhs.rows, lhs.cols, &a_data);
961    let b = DMatrix::from_column_slice(rhs.rows, rhs.cols, &b_data);
962    let svd = SVD::new(a.clone(), true, true);
963    let rcond = singular_value_rcond(svd.singular_values.as_slice());
964    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows, lhs.cols);
965    let solution = svd
966        .solve(&b, tol)
967        .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
968    let tensor = matrix_complex_to_tensor(solution)?;
969    Ok((tensor, rcond))
970}
971
972fn normalize_rhs_tensor(rhs: Tensor, expected_rows: usize) -> BuiltinResult<Tensor> {
973    if rhs.rows() == expected_rows {
974        return Ok(rhs);
975    }
976    if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
977        return Tensor::new(rhs.data, vec![expected_rows, 1])
978            .map_err(|e| builtin_error(format!("{NAME}: {e}")));
979    }
980    if rhs.data.is_empty() && expected_rows == 0 {
981        return Ok(rhs);
982    }
983    Err(builtin_error("Matrix dimensions must agree."))
984}
985
986fn normalize_rhs_complex(rhs: ComplexTensor, expected_rows: usize) -> BuiltinResult<ComplexTensor> {
987    if rhs.rows == expected_rows {
988        return Ok(rhs);
989    }
990    if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
991        return ComplexTensor::new(rhs.data, vec![expected_rows, 1])
992            .map_err(|e| builtin_error(format!("{NAME}: {e}")));
993    }
994    if rhs.data.is_empty() && expected_rows == 0 {
995        return Ok(rhs);
996    }
997    Err(builtin_error("Matrix dimensions must agree."))
998}
999
1000fn enforce_rcond(options: &SolveOptions, rcond: f64) -> BuiltinResult<()> {
1001    if let Some(threshold) = options.rcond {
1002        if rcond < threshold {
1003            return Err(builtin_error(
1004                "linsolve: matrix is singular to working precision.",
1005            ));
1006        }
1007    }
1008    Ok(())
1009}
1010
/// Computes the cutoff below which singular values are treated as zero in the
/// SVD-based least-squares solve.
///
/// Uses the relative tolerance `eps * max(rows, cols) * sigma_max`, the same
/// convention as MATLAB's `pinv` default (`max(size(A)) * norm(A) * eps`).
/// The tolerance scales with the matrix. A small-magnitude but
/// well-conditioned matrix (e.g. `1e-20 * I`) therefore keeps all of its
/// singular values. An absolute floor would zero them all out and return a
/// wrong, all-zero solution.
///
/// Returns `0.0` for an empty or all-zero spectrum. Singular values that are
/// not strictly greater than the tolerance are then still treated as zero.
fn compute_svd_tolerance(singular_values: &[f64], rows: usize, cols: usize) -> f64 {
    let max_sv = singular_values
        .iter()
        .copied()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_dim = rows.max(cols) as f64;
    f64::EPSILON * max_dim * max_sv
}
1019
1020fn matrix_real_to_tensor(matrix: DMatrix<f64>) -> BuiltinResult<Tensor> {
1021    let rows = matrix.nrows();
1022    let cols = matrix.ncols();
1023    Tensor::new(matrix.as_slice().to_vec(), vec![rows, cols])
1024        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
1025}
1026
1027fn matrix_complex_to_tensor(matrix: DMatrix<Complex64>) -> BuiltinResult<ComplexTensor> {
1028    let rows = matrix.nrows();
1029    let cols = matrix.ncols();
1030    let data: Vec<(f64, f64)> = matrix.as_slice().iter().map(|c| (c.re, c.im)).collect();
1031    ComplexTensor::new(data, vec![rows, cols]).map_err(|e| builtin_error(format!("{NAME}: {e}")))
1032}
1033
1034fn promote_real_tensor(tensor: &Tensor) -> BuiltinResult<ComplexTensor> {
1035    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1036    ComplexTensor::new(data, tensor.shape.clone())
1037        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
1038}
1039
1040fn ensure_matrix_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
1041    if is_effectively_matrix(shape) {
1042        Ok(())
1043    } else {
1044        Err(builtin_error(format!(
1045            "{name}: inputs must be 2-D matrices or vectors"
1046        )))
1047    }
1048}
1049
/// Reports whether `shape` is at most 2-D once trailing singleton dimensions
/// are ignored.
fn is_effectively_matrix(shape: &[usize]) -> bool {
    // Every dimension past the second must be 1. For rank <= 2 shapes,
    // `skip(2)` yields nothing, so `all` is vacuously true.
    shape.iter().skip(2).all(|&dim| dim == 1)
}
1056
1057fn ensure_square(rows: usize, cols: usize) -> BuiltinResult<()> {
1058    if rows == cols {
1059        Ok(())
1060    } else {
1061        Err(builtin_error(
1062            "linsolve: triangular solves require a square coefficient matrix.",
1063        ))
1064    }
1065}
1066
1067fn transpose_tensor(tensor: &Tensor) -> Tensor {
1068    let rows = tensor.rows();
1069    let cols = tensor.cols();
1070    let mut data = vec![0.0; tensor.data.len()];
1071    for r in 0..rows {
1072        for c in 0..cols {
1073            data[c + r * cols] = tensor.data[r + c * rows];
1074        }
1075    }
1076    Tensor::new(data, vec![cols, rows]).expect("transpose_tensor valid")
1077}
1078
1079fn transpose_complex(tensor: &ComplexTensor) -> ComplexTensor {
1080    let rows = tensor.rows;
1081    let cols = tensor.cols;
1082    let mut data = vec![(0.0, 0.0); tensor.data.len()];
1083    for r in 0..rows {
1084        for c in 0..cols {
1085            data[c + r * cols] = tensor.data[r + c * rows];
1086        }
1087    }
1088    ComplexTensor::new(data, vec![cols, rows]).expect("transpose_complex valid")
1089}
1090
1091fn conjugate_in_place(_tensor: &mut Tensor) {
1092    // Real-valued matrices are unaffected by conjugation.
1093}
1094
1095fn conjugate_complex_in_place(tensor: &mut ComplexTensor) {
1096    for value in &mut tensor.data {
1097        value.1 = -value.1;
1098    }
1099}
1100
1101#[cfg(test)]
1102pub(crate) mod tests {
1103    use super::*;
1104    use futures::executor::block_on;
1105    use runmat_accelerate_api::HostTensorView;
1106    use runmat_builtins::{CharArray, ResolveContext, StructValue, Type};
1107    fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
1108        err
1109    }
1110
1111    fn approx_eq(actual: f64, expected: f64) {
1112        assert!((actual - expected).abs() < 1e-7);
1113    }
1114
1115    fn evaluate_args(a: Value, b: Value, rest: &[Value]) -> Result<LinsolveEval, RuntimeError> {
1116        block_on(super::evaluate_args(a, b, rest))
1117    }
1118
1119    #[test]
1120    fn linsolve_type_uses_rhs_columns() {
1121        let out = left_divide_type(
1122            &[
1123                Type::Tensor {
1124                    shape: Some(vec![Some(2), Some(2)]),
1125                },
1126                Type::Tensor {
1127                    shape: Some(vec![Some(2), Some(3)]),
1128                },
1129            ],
1130            &ResolveContext::new(Vec::new()),
1131        );
1132        assert_eq!(
1133            out,
1134            Type::Tensor {
1135                shape: Some(vec![Some(2), Some(3)])
1136            }
1137        );
1138    }
1139
1140    #[test]
1141    fn linsolve_descriptor_signatures_cover_core_forms() {
1142        let labels: Vec<&str> = LINSOLVE_DESCRIPTOR
1143            .signatures
1144            .iter()
1145            .map(|signature| signature.label)
1146            .collect();
1147        assert!(labels.contains(&"X = linsolve(A, B)"));
1148        assert!(labels.contains(&"X = linsolve(A, B, opts)"));
1149        assert!(labels.contains(&"[X, R] = linsolve(A, B)"));
1150        assert!(labels.contains(&"[X, R] = linsolve(A, B, opts)"));
1151    }
1152
1153    #[test]
1154    fn linsolve_descriptor_errors_have_stable_codes() {
1155        let codes: Vec<&str> = LINSOLVE_DESCRIPTOR
1156            .errors
1157            .iter()
1158            .map(|err| err.code)
1159            .collect();
1160        assert!(codes.contains(&"RM.LINSOLVE.INVALID_ARGUMENT"));
1161        assert!(codes.contains(&"RM.LINSOLVE.INVALID_INPUT"));
1162        assert!(codes.contains(&"RM.LINSOLVE.INTERNAL"));
1163    }
1164
1165    use crate::builtins::common::test_support;
1166    use runmat_accelerate_api::ProviderTelemetry;
1167
1168    fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
1169        block_on(super::linsolve_builtin(lhs, rhs, rest))
1170    }
1171
1172    fn evaluate(lhs: Value, rhs: Value, options: SolveOptions) -> BuiltinResult<LinsolveEval> {
1173        block_on(super::evaluate(lhs, rhs, options))
1174    }
1175
1176    fn fallback_count(telemetry: &ProviderTelemetry, reason: &str) -> u64 {
1177        telemetry
1178            .solve_fallbacks
1179            .iter()
1180            .find(|entry| entry.reason == reason)
1181            .map(|entry| entry.count)
1182            .unwrap_or(0)
1183    }
1184
1185    #[cfg(feature = "wgpu")]
1186    fn kernel_launch_count(telemetry: &ProviderTelemetry, kernel: &str) -> usize {
1187        telemetry
1188            .kernel_launches
1189            .iter()
1190            .filter(|entry| entry.kernel == kernel)
1191            .count()
1192    }
1193
1194    fn clear_accel_provider_state() {
1195        runmat_accelerate_api::set_thread_provider(None);
1196        runmat_accelerate_api::clear_provider();
1197    }
1198
1199    fn host_linsolve_real(
1200        a: &Tensor,
1201        b: &Tensor,
1202        options: ProviderLinsolveOptions,
1203    ) -> (Tensor, f64) {
1204        super::linsolve_host_real_for_provider(a, b, &options).expect("host linsolve")
1205    }
1206
1207    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1208    #[test]
1209    fn linsolve_basic_square() {
1210        let _accel_guard = test_support::accel_test_lock();
1211        clear_accel_provider_state();
1212        let a = Tensor::new(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
1213        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1214        let result =
1215            linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new()).expect("linsolve");
1216        let t = test_support::gather(result).expect("gather");
1217        assert_eq!(t.shape, vec![2, 1]);
1218        approx_eq(t.data[0], 1.0);
1219        approx_eq(t.data[1], 2.0);
1220    }
1221
1222    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1223    #[test]
1224    fn linsolve_lower_triangular_hint() {
1225        let _accel_guard = test_support::accel_test_lock();
1226        clear_accel_provider_state();
1227        let a = Tensor::new(
1228            vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
1229            vec![3, 3],
1230        )
1231        .unwrap();
1232        let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
1233        let mut opts = StructValue::new();
1234        opts.fields.insert("LT".to_string(), Value::Bool(true));
1235        let result = linsolve_builtin(
1236            Value::Tensor(a),
1237            Value::Tensor(b),
1238            vec![Value::Struct(opts)],
1239        )
1240        .expect("linsolve");
1241        let tensor = test_support::gather(result).expect("gather");
1242        assert_eq!(tensor.shape, vec![3, 1]);
1243        approx_eq(tensor.data[0], 3.0);
1244        approx_eq(tensor.data[1], 2.0);
1245        approx_eq(tensor.data[2], 1.0);
1246    }
1247
1248    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1249    #[test]
1250    fn linsolve_transposed_triangular_hint() {
1251        let _accel_guard = test_support::accel_test_lock();
1252        clear_accel_provider_state();
1253        let a = Tensor::new(
1254            vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
1255            vec![3, 3],
1256        )
1257        .unwrap();
1258        let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
1259        let mut opts = StructValue::new();
1260        opts.fields.insert("LT".to_string(), Value::Bool(true));
1261        opts.fields.insert(
1262            "TRANSA".to_string(),
1263            Value::CharArray(CharArray::new_row("T")),
1264        );
1265
1266        let result = linsolve_builtin(
1267            Value::Tensor(a.clone()),
1268            Value::Tensor(b.clone()),
1269            vec![Value::Struct(opts)],
1270        )
1271        .expect("linsolve");
1272        let tensor = test_support::gather(result).expect("gather");
1273        assert_eq!(tensor.shape, vec![3, 1]);
1274
1275        let a_transposed = transpose_tensor(&a);
1276        let (expected_tensor, _) =
1277            host_linsolve_real(&a_transposed, &b, ProviderLinsolveOptions::default());
1278
1279        for (actual, expected) in tensor.data.iter().zip(expected_tensor.data.iter()) {
1280            approx_eq(*actual, *expected);
1281        }
1282    }
1283
1284    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1285    #[test]
1286    fn linsolve_complex_inputs_match_residual() {
1287        let a = ComplexTensor::new(
1288            vec![(2.0, 1.0), (-1.0, 0.0), (1.0, -2.0), (3.0, -2.0)],
1289            vec![2, 2],
1290        )
1291        .unwrap();
1292        let b = ComplexTensor::new(vec![(1.0, 0.0), (4.0, 1.0)], vec![2, 1]).unwrap();
1293        let result = linsolve_builtin(
1294            Value::ComplexTensor(a.clone()),
1295            Value::ComplexTensor(b.clone()),
1296            Vec::new(),
1297        )
1298        .expect("linsolve");
1299        let Value::ComplexTensor(out) = result else {
1300            panic!("expected complex tensor result");
1301        };
1302
1303        let mat_a: Vec<Complex64> = a
1304            .data
1305            .iter()
1306            .map(|&(re, im)| Complex64::new(re, im))
1307            .collect();
1308        let mat_b: Vec<Complex64> = b
1309            .data
1310            .iter()
1311            .map(|&(re, im)| Complex64::new(re, im))
1312            .collect();
1313        let mat_x: Vec<Complex64> = out
1314            .data
1315            .iter()
1316            .map(|&(re, im)| Complex64::new(re, im))
1317            .collect();
1318        let a_mat = DMatrix::from_column_slice(a.rows, a.cols, &mat_a);
1319        let b_mat = DMatrix::from_column_slice(b.rows, b.cols, &mat_b);
1320        let x_mat = DMatrix::from_column_slice(out.rows, out.cols, &mat_x);
1321        let residual = a_mat * x_mat - b_mat;
1322        assert!(residual.norm() < 1e-10, "residual={}", residual.norm());
1323    }
1324
1325    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1326    #[test]
1327    fn linsolve_complex_conjugate_transpose_matches_explicit_reference() {
1328        let a = ComplexTensor::new(
1329            vec![(2.0, 1.0), (0.0, -1.0), (1.0, 2.0), (3.0, 0.5)],
1330            vec![2, 2],
1331        )
1332        .unwrap();
1333        let b = ComplexTensor::new(vec![(1.0, -1.0), (2.0, 0.5)], vec![2, 1]).unwrap();
1334
1335        let mut opts = StructValue::new();
1336        opts.fields.insert(
1337            "TRANSA".to_string(),
1338            Value::CharArray(CharArray::new_row("C")),
1339        );
1340        let result = linsolve_builtin(
1341            Value::ComplexTensor(a.clone()),
1342            Value::ComplexTensor(b.clone()),
1343            vec![Value::Struct(opts)],
1344        )
1345        .expect("linsolve");
1346        let Value::ComplexTensor(out) = result else {
1347            panic!("expected complex tensor result");
1348        };
1349
1350        let mut a_conj_t = transpose_complex(&a);
1351        conjugate_complex_in_place(&mut a_conj_t);
1352        let reference = evaluate(
1353            Value::ComplexTensor(a_conj_t),
1354            Value::ComplexTensor(b.clone()),
1355            SolveOptions::default(),
1356        )
1357        .expect("reference");
1358        let Value::ComplexTensor(expected) = reference.solution() else {
1359            panic!("expected complex tensor reference");
1360        };
1361
1362        assert_eq!(out.shape, expected.shape);
1363        for ((out_re, out_im), (exp_re, exp_im)) in out.data.iter().zip(expected.data.iter()) {
1364            assert!(
1365                (out_re - exp_re).abs() < 1e-10,
1366                "out_re={out_re} exp_re={exp_re}"
1367            );
1368            assert!(
1369                (out_im - exp_im).abs() < 1e-10,
1370                "out_im={out_im} exp_im={exp_im}"
1371            );
1372        }
1373    }
1374
1375    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1376    #[test]
1377    fn linsolve_rcond_enforced() {
1378        let _accel_guard = test_support::accel_test_lock();
1379        clear_accel_provider_state();
1380        let a = Tensor::new(vec![1.0, 1.0, 1.0, 1.0 + 1e-12], vec![2, 2]).unwrap();
1381        let b = Tensor::new(vec![2.0, 2.0 + 1e-12], vec![2, 1]).unwrap();
1382        let mut opts = StructValue::new();
1383        opts.fields.insert("RCOND".to_string(), Value::Num(1e-3));
1384        let err = unwrap_error(
1385            linsolve_builtin(
1386                Value::Tensor(a),
1387                Value::Tensor(b),
1388                vec![Value::Struct(opts)],
1389            )
1390            .expect_err("singular matrix must fail"),
1391        );
1392        assert!(
1393            err.message().contains("singular to working precision"),
1394            "unexpected error message: {err}"
1395        );
1396        assert_eq!(err.identifier(), LINSOLVE_ERROR_INVALID_INPUT.identifier);
1397    }
1398
1399    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1400    #[test]
1401    fn linsolve_unknown_option_identifier() {
1402        let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
1403        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1404        let mut opts = StructValue::new();
1405        opts.fields.insert("UNKNOWN".to_string(), Value::Bool(true));
1406        let err = unwrap_error(
1407            linsolve_builtin(
1408                Value::Tensor(a),
1409                Value::Tensor(b),
1410                vec![Value::Struct(opts)],
1411            )
1412            .expect_err("unknown option should fail"),
1413        );
1414        assert!(err.message().contains("unknown option"));
1415        assert_eq!(err.identifier(), LINSOLVE_ERROR_INVALID_ARGUMENT.identifier);
1416    }
1417
1418    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1419    #[test]
1420    fn linsolve_output_count_limit_identifier() {
1421        let a = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
1422        let b = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
1423        let _guard = crate::output_count::push_output_count(Some(3));
1424        let err = unwrap_error(
1425            linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new())
1426                .expect_err("three outputs should fail"),
1427        );
1428        assert!(err.message().contains("at most two outputs"));
1429        assert_eq!(err.identifier(), LINSOLVE_ERROR_INVALID_ARGUMENT.identifier);
1430    }
1431
1432    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1433    #[test]
1434    fn linsolve_recovers_rcond_output() {
1435        let _accel_guard = test_support::accel_test_lock();
1436        clear_accel_provider_state();
1437        let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
1438        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
1439        let eval = evaluate_args(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
1440            .expect("evaluate");
1441        let solution_tensor = match eval.solution() {
1442            Value::Tensor(sol) => sol.clone(),
1443            Value::GpuTensor(handle) => {
1444                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather solution")
1445            }
1446            other => panic!("unexpected solution value {other:?}"),
1447        };
1448        assert_eq!(solution_tensor.shape, vec![2, 1]);
1449        approx_eq(solution_tensor.data[0], 1.0);
1450        approx_eq(solution_tensor.data[1], 2.0);
1451
1452        let rcond_value = match eval.reciprocal_condition() {
1453            Value::Num(r) => r,
1454            Value::GpuTensor(handle) => {
1455                let gathered =
1456                    test_support::gather(Value::GpuTensor(handle.clone())).expect("gather rcond");
1457                gathered.data[0]
1458            }
1459            other => panic!("unexpected rcond value {other:?}"),
1460        };
1461        approx_eq(rcond_value, 1.0);
1462    }
1463
1464    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1465    #[test]
1466    fn gpu_round_trip_matches_cpu() {
1467        test_support::with_test_provider(|provider| {
1468            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
1469            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1470
1471            let cpu = linsolve_builtin(
1472                Value::Tensor(a.clone()),
1473                Value::Tensor(b.clone()),
1474                Vec::new(),
1475            )
1476            .expect("cpu linsolve");
1477            let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1478
1479            let view_a = HostTensorView {
1480                data: &a.data,
1481                shape: &a.shape,
1482            };
1483            let view_b = HostTensorView {
1484                data: &b.data,
1485                shape: &b.shape,
1486            };
1487            let ha = provider.upload(&view_a).expect("upload A");
1488            let hb = provider.upload(&view_b).expect("upload B");
1489
1490            let gpu_value = linsolve_builtin(
1491                Value::GpuTensor(ha.clone()),
1492                Value::GpuTensor(hb.clone()),
1493                Vec::new(),
1494            )
1495            .expect("gpu linsolve");
1496            let gathered = test_support::gather(gpu_value).expect("gather");
1497            let _ = provider.free(&ha);
1498            let _ = provider.free(&hb);
1499
1500            assert_eq!(gathered.shape, cpu_tensor.shape);
1501            for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1502                assert!((gpu - cpu).abs() < 1e-12);
1503            }
1504        });
1505    }
1506
1507    #[test]
1508    fn host_inputs_auto_promote_into_provider_solve_path() {
1509        test_support::with_test_provider(|provider| {
1510            provider.reset_telemetry();
1511            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
1512            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1513            let _ = linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new())
1514                .expect("host linsolve");
1515            let telemetry = provider.telemetry_snapshot();
1516            assert!(telemetry.linsolve.count >= 1);
1517            assert!(fallback_count(&telemetry, "linsolve:host_reupload") >= 1);
1518            assert!(telemetry.upload_bytes > 0);
1519            assert!(telemetry.download_bytes > 0);
1520        });
1521    }
1522
1523    #[test]
1524    fn provider_telemetry_records_gpu_host_reupload_path() {
1525        test_support::with_test_provider(|provider| {
1526            provider.reset_telemetry();
1527            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
1528            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1529            let ha = provider
1530                .upload(&HostTensorView {
1531                    data: &a.data,
1532                    shape: &a.shape,
1533                })
1534                .expect("upload A");
1535            let hb = provider
1536                .upload(&HostTensorView {
1537                    data: &b.data,
1538                    shape: &b.shape,
1539                })
1540                .expect("upload B");
1541
1542            let _ = linsolve_builtin(
1543                Value::GpuTensor(ha.clone()),
1544                Value::GpuTensor(hb.clone()),
1545                Vec::new(),
1546            )
1547            .expect("gpu linsolve");
1548
1549            let telemetry = provider.telemetry_snapshot();
1550            assert_eq!(telemetry.linsolve.count, 1);
1551            assert!(telemetry.upload_bytes > 0);
1552            assert!(telemetry.download_bytes > 0);
1553            assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);
1554
1555            let _ = provider.free(&ha);
1556            let _ = provider.free(&hb);
1557        });
1558    }
1559
1560    #[test]
1561    fn scalar_gpu_inputs_fall_back_without_provider_solve_dispatch() {
1562        test_support::with_test_provider(|provider| {
1563            provider.reset_telemetry();
1564            let a = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
1565            let b = Tensor::new(vec![6.0], vec![1, 1]).unwrap();
1566            let ha = provider
1567                .upload(&HostTensorView {
1568                    data: &a.data,
1569                    shape: &a.shape,
1570                })
1571                .expect("upload A");
1572            let hb = provider
1573                .upload(&HostTensorView {
1574                    data: &b.data,
1575                    shape: &b.shape,
1576                })
1577                .expect("upload B");
1578
1579            let result = linsolve_builtin(
1580                Value::GpuTensor(ha.clone()),
1581                Value::GpuTensor(hb.clone()),
1582                Vec::new(),
1583            )
1584            .expect("fallback linsolve");
1585            let gathered = test_support::gather(result).expect("gather fallback");
1586            assert_eq!(gathered.data, vec![3.0]);
1587
1588            let telemetry = provider.telemetry_snapshot();
1589            assert_eq!(telemetry.linsolve.count, 0);
1590            assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1591            assert!(telemetry.download_bytes > 0);
1592
1593            let _ = provider.free(&ha);
1594            let _ = provider.free(&hb);
1595        });
1596    }
1597
1598    #[cfg(feature = "wgpu")]
1599    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1600    #[test]
1601    fn wgpu_square_linsolve_avoids_host_reupload_fallback() {
1602        let _accel_guard = test_support::accel_test_lock();
1603        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1604            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1605        );
1606        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1607        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1608            return;
1609        }
1610        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
1611        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
1612
1613        let cpu = linsolve_builtin(
1614            Value::Tensor(a.clone()),
1615            Value::Tensor(b.clone()),
1616            Vec::new(),
1617        )
1618        .expect("cpu linsolve");
1619        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1620        provider.reset_telemetry();
1621
1622        let ha = provider
1623            .upload(&HostTensorView {
1624                data: &a.data,
1625                shape: &a.shape,
1626            })
1627            .expect("upload A");
1628        let hb = provider
1629            .upload(&HostTensorView {
1630                data: &b.data,
1631                shape: &b.shape,
1632            })
1633            .expect("upload B");
1634
1635        let _output_guard = crate::output_count::push_output_count(Some(1));
1636        let gpu_value = linsolve_builtin(
1637            Value::GpuTensor(ha.clone()),
1638            Value::GpuTensor(hb.clone()),
1639            Vec::new(),
1640        )
1641        .expect("gpu square linsolve");
1642        let gpu_solution = match gpu_value {
1643            Value::OutputList(mut outputs) => outputs.remove(0),
1644            other => other,
1645        };
1646        let gathered = test_support::gather(gpu_solution).expect("gather");
1647        let _ = provider.free(&ha);
1648        let _ = provider.free(&hb);
1649
1650        assert_eq!(gathered.shape, cpu_tensor.shape);
1651        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1652            assert!((gpu - cpu).abs() < 1e-4);
1653        }
1654
1655        let telemetry = provider.telemetry_snapshot();
1656        assert_eq!(telemetry.linsolve.count, 1);
1657        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1658        assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 0);
1659        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
1660    }
1661
1662    #[cfg(feature = "wgpu")]
1663    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1664    #[test]
1665    fn wgpu_square_linsolve_uses_device_path_without_output_count() {
1666        let _accel_guard = test_support::accel_test_lock();
1667        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1668            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1669        );
1670        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1671        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1672            return;
1673        }
1674        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
1675        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
1676
1677        let cpu = linsolve_builtin(
1678            Value::Tensor(a.clone()),
1679            Value::Tensor(b.clone()),
1680            Vec::new(),
1681        )
1682        .expect("cpu linsolve");
1683        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1684        provider.reset_telemetry();
1685
1686        let ha = provider
1687            .upload(&HostTensorView {
1688                data: &a.data,
1689                shape: &a.shape,
1690            })
1691            .expect("upload A");
1692        let hb = provider
1693            .upload(&HostTensorView {
1694                data: &b.data,
1695                shape: &b.shape,
1696            })
1697            .expect("upload B");
1698
1699        let gpu_value = linsolve_builtin(
1700            Value::GpuTensor(ha.clone()),
1701            Value::GpuTensor(hb.clone()),
1702            Vec::new(),
1703        )
1704        .expect("gpu square linsolve");
1705        let gathered = test_support::gather(gpu_value).expect("gather");
1706        let _ = provider.free(&ha);
1707        let _ = provider.free(&hb);
1708
1709        assert_eq!(gathered.shape, cpu_tensor.shape);
1710        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1711            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1712        }
1713
1714        let telemetry = provider.telemetry_snapshot();
1715        assert_eq!(telemetry.linsolve.count, 1);
1716        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1717        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
1718    }
1719
1720    #[cfg(feature = "wgpu")]
1721    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1722    #[test]
1723    fn wgpu_square_linsolve_recovers_rcond_output_on_device() {
1724        let _accel_guard = test_support::accel_test_lock();
1725        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1726            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1727        );
1728        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1729        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1730            return;
1731        }
1732        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
1733        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
1734
1735        let (_, cpu_rcond) = host_linsolve_real(&a, &b, ProviderLinsolveOptions::default());
1736        provider.reset_telemetry();
1737
1738        let ha = provider
1739            .upload(&HostTensorView {
1740                data: &a.data,
1741                shape: &a.shape,
1742            })
1743            .expect("upload A");
1744        let hb = provider
1745            .upload(&HostTensorView {
1746                data: &b.data,
1747                shape: &b.shape,
1748            })
1749            .expect("upload B");
1750
1751        let _output_guard = crate::output_count::push_output_count(Some(2));
1752        let gpu_value = linsolve_builtin(
1753            Value::GpuTensor(ha.clone()),
1754            Value::GpuTensor(hb.clone()),
1755            Vec::new(),
1756        )
1757        .expect("gpu square linsolve");
1758        let outputs = match gpu_value {
1759            Value::OutputList(outputs) => outputs,
1760            other => panic!("expected output list, got {other:?}"),
1761        };
1762        assert_eq!(outputs.len(), 2);
1763        let gathered = test_support::gather(outputs[0].clone()).expect("gather");
1764        let gpu_rcond = match &outputs[1] {
1765            Value::Num(value) => *value,
1766            other => panic!("unexpected gpu rcond {other:?}"),
1767        };
1768        let _ = provider.free(&ha);
1769        let _ = provider.free(&hb);
1770
1771        assert_eq!(gathered.shape, vec![2, 1]);
1772        assert!(
1773            (gpu_rcond - cpu_rcond).abs() < 1e-4,
1774            "gpu={gpu_rcond} cpu={cpu_rcond}"
1775        );
1776
1777        let telemetry = provider.telemetry_snapshot();
1778        assert_eq!(telemetry.linsolve.count, 1);
1779        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1780        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
1781    }
1782
1783    #[cfg(feature = "wgpu")]
1784    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1785    #[test]
1786    fn wgpu_square_linsolve_with_rcond_option_stays_on_device() {
1787        let _accel_guard = test_support::accel_test_lock();
1788        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1789            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1790        );
1791        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1792        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1793            return;
1794        }
1795
1796        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
1797        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
1798        let mut cpu_opts = StructValue::new();
1799        cpu_opts
1800            .fields
1801            .insert("RCOND".to_string(), Value::Num(0.05));
1802        let cpu = linsolve_builtin(
1803            Value::Tensor(a.clone()),
1804            Value::Tensor(b.clone()),
1805            vec![Value::Struct(cpu_opts)],
1806        )
1807        .expect("cpu linsolve");
1808        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1809        provider.reset_telemetry();
1810
1811        let ha = provider
1812            .upload(&HostTensorView {
1813                data: &a.data,
1814                shape: &a.shape,
1815            })
1816            .expect("upload A");
1817        let hb = provider
1818            .upload(&HostTensorView {
1819                data: &b.data,
1820                shape: &b.shape,
1821            })
1822            .expect("upload B");
1823
1824        let _output_guard = crate::output_count::push_output_count(Some(1));
1825        let mut gpu_opts = StructValue::new();
1826        gpu_opts
1827            .fields
1828            .insert("RCOND".to_string(), Value::Num(0.05));
1829        let gpu_value = linsolve_builtin(
1830            Value::GpuTensor(ha.clone()),
1831            Value::GpuTensor(hb.clone()),
1832            vec![Value::Struct(gpu_opts)],
1833        )
1834        .expect("gpu square linsolve");
1835        let gpu_solution = match gpu_value {
1836            Value::OutputList(mut outputs) => outputs.remove(0),
1837            other => other,
1838        };
1839        let gathered = test_support::gather(gpu_solution).expect("gather");
1840        let _ = provider.free(&ha);
1841        let _ = provider.free(&hb);
1842
1843        assert_eq!(gathered.shape, cpu_tensor.shape);
1844        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1845            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1846        }
1847
1848        let telemetry = provider.telemetry_snapshot();
1849        assert_eq!(telemetry.linsolve.count, 1);
1850        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1851        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
1852    }
1853
1854    #[cfg(feature = "wgpu")]
1855    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1856    #[test]
1857    fn wgpu_tall_linsolve_avoids_host_reupload_fallback() {
1858        let _accel_guard = test_support::accel_test_lock();
1859        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1860            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1861        );
1862        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1863        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1864            return;
1865        }
1866        let a = Tensor::new(vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
1867        let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();
1868
1869        let cpu = linsolve_builtin(
1870            Value::Tensor(a.clone()),
1871            Value::Tensor(b.clone()),
1872            Vec::new(),
1873        )
1874        .expect("cpu linsolve");
1875        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1876        provider.reset_telemetry();
1877
1878        let ha = provider
1879            .upload(&HostTensorView {
1880                data: &a.data,
1881                shape: &a.shape,
1882            })
1883            .expect("upload A");
1884        let hb = provider
1885            .upload(&HostTensorView {
1886                data: &b.data,
1887                shape: &b.shape,
1888            })
1889            .expect("upload B");
1890
1891        let _output_guard = crate::output_count::push_output_count(Some(1));
1892        let gpu_value = linsolve_builtin(
1893            Value::GpuTensor(ha.clone()),
1894            Value::GpuTensor(hb.clone()),
1895            Vec::new(),
1896        )
1897        .expect("gpu tall linsolve");
1898        let gpu_solution = match gpu_value {
1899            Value::OutputList(mut outputs) => outputs.remove(0),
1900            other => other,
1901        };
1902        let gathered = test_support::gather(gpu_solution).expect("gather");
1903        let _ = provider.free(&ha);
1904        let _ = provider.free(&hb);
1905
1906        assert_eq!(gathered.shape, cpu_tensor.shape);
1907        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1908            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1909        }
1910
1911        let telemetry = provider.telemetry_snapshot();
1912        assert_eq!(telemetry.linsolve.count, 1);
1913        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
1914    }
1915
1916    #[cfg(feature = "wgpu")]
1917    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1918    #[test]
1919    fn wgpu_posdef_linsolve_avoids_host_reupload_fallback() {
1920        let _accel_guard = test_support::accel_test_lock();
1921        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
1922            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
1923        );
1924        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
1925        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
1926            return;
1927        }
1928        let a = Tensor::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
1929        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
1930
1931        let mut cpu_opts = StructValue::new();
1932        cpu_opts
1933            .fields
1934            .insert("POSDEF".to_string(), Value::Bool(true));
1935        let cpu = linsolve_builtin(
1936            Value::Tensor(a.clone()),
1937            Value::Tensor(b.clone()),
1938            vec![Value::Struct(cpu_opts)],
1939        )
1940        .expect("cpu linsolve");
1941        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
1942        let (_, cpu_rcond) = host_linsolve_real(
1943            &a,
1944            &b,
1945            ProviderLinsolveOptions {
1946                posdef: true,
1947                ..Default::default()
1948            },
1949        );
1950        provider.reset_telemetry();
1951
1952        let ha = provider
1953            .upload(&HostTensorView {
1954                data: &a.data,
1955                shape: &a.shape,
1956            })
1957            .expect("upload A");
1958        let hb = provider
1959            .upload(&HostTensorView {
1960                data: &b.data,
1961                shape: &b.shape,
1962            })
1963            .expect("upload B");
1964
1965        let _output_guard = crate::output_count::push_output_count(Some(2));
1966        let mut gpu_opts = StructValue::new();
1967        gpu_opts
1968            .fields
1969            .insert("POSDEF".to_string(), Value::Bool(true));
1970        let gpu_value = linsolve_builtin(
1971            Value::GpuTensor(ha.clone()),
1972            Value::GpuTensor(hb.clone()),
1973            vec![Value::Struct(gpu_opts)],
1974        )
1975        .expect("gpu posdef linsolve");
1976        let mut outputs = match gpu_value {
1977            Value::OutputList(outputs) => outputs,
1978            other => panic!("expected output list, got {other:?}"),
1979        };
1980        let gpu_rcond = match outputs.remove(1) {
1981            Value::Num(value) => value,
1982            other => panic!("unexpected rcond value {other:?}"),
1983        };
1984        let gpu_solution = outputs.remove(0);
1985        let gathered = test_support::gather(gpu_solution).expect("gather");
1986        let _ = provider.free(&ha);
1987        let _ = provider.free(&hb);
1988
1989        assert_eq!(gathered.shape, cpu_tensor.shape);
1990        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
1991            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
1992        }
1993        assert!(
1994            (gpu_rcond - cpu_rcond).abs() < 1e-4,
1995            "gpu={gpu_rcond} cpu={cpu_rcond}"
1996        );
1997
1998        let telemetry = provider.telemetry_snapshot();
1999        assert_eq!(telemetry.linsolve.count, 1);
2000        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2001        assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
2002        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
2003    }
2004
2005    #[cfg(feature = "wgpu")]
2006    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2007    #[test]
2008    fn wgpu_transposed_posdef_linsolve_uses_cholesky_path() {
2009        let _accel_guard = test_support::accel_test_lock();
2010        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2011            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2012        );
2013        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2014        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2015            return;
2016        }
2017        let a = Tensor::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
2018        let b = Tensor::new(vec![8.0, 9.0], vec![2, 1]).unwrap();
2019
2020        let mut cpu_opts = StructValue::new();
2021        cpu_opts
2022            .fields
2023            .insert("POSDEF".to_string(), Value::Bool(true));
2024        cpu_opts.fields.insert(
2025            "TRANSA".to_string(),
2026            Value::CharArray(CharArray::new_row("T")),
2027        );
2028        let cpu = linsolve_builtin(
2029            Value::Tensor(a.clone()),
2030            Value::Tensor(b.clone()),
2031            vec![Value::Struct(cpu_opts)],
2032        )
2033        .expect("cpu linsolve");
2034        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2035        provider.reset_telemetry();
2036
2037        let ha = provider
2038            .upload(&HostTensorView {
2039                data: &a.data,
2040                shape: &a.shape,
2041            })
2042            .expect("upload A");
2043        let hb = provider
2044            .upload(&HostTensorView {
2045                data: &b.data,
2046                shape: &b.shape,
2047            })
2048            .expect("upload B");
2049
2050        let _output_guard = crate::output_count::push_output_count(Some(1));
2051        let mut gpu_opts = StructValue::new();
2052        gpu_opts
2053            .fields
2054            .insert("POSDEF".to_string(), Value::Bool(true));
2055        gpu_opts.fields.insert(
2056            "TRANSA".to_string(),
2057            Value::CharArray(CharArray::new_row("T")),
2058        );
2059        let gpu_value = linsolve_builtin(
2060            Value::GpuTensor(ha.clone()),
2061            Value::GpuTensor(hb.clone()),
2062            vec![Value::Struct(gpu_opts)],
2063        )
2064        .expect("gpu transposed posdef linsolve");
2065        let gpu_solution = match gpu_value {
2066            Value::OutputList(mut outputs) => outputs.remove(0),
2067            other => other,
2068        };
2069        let gathered = test_support::gather(gpu_solution).expect("gather");
2070        let _ = provider.free(&ha);
2071        let _ = provider.free(&hb);
2072
2073        assert_eq!(gathered.shape, cpu_tensor.shape);
2074        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2075            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2076        }
2077
2078        let telemetry = provider.telemetry_snapshot();
2079        assert_eq!(telemetry.linsolve.count, 1);
2080        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2081        assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
2082        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
2083    }
2084
2085    #[cfg(feature = "wgpu")]
2086    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2087    #[test]
2088    fn wgpu_symmetric_linsolve_avoids_host_reupload_fallback() {
2089        let _accel_guard = test_support::accel_test_lock();
2090        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2091            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2092        );
2093        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2094        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2095            return;
2096        }
2097        let a = Tensor::new(vec![5.0, 2.0, 2.0, 6.0], vec![2, 2]).unwrap();
2098        let b = Tensor::new(vec![9.0, 8.0], vec![2, 1]).unwrap();
2099
2100        let mut cpu_opts = StructValue::new();
2101        cpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
2102        let cpu = linsolve_builtin(
2103            Value::Tensor(a.clone()),
2104            Value::Tensor(b.clone()),
2105            vec![Value::Struct(cpu_opts)],
2106        )
2107        .expect("cpu linsolve");
2108        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2109        provider.reset_telemetry();
2110
2111        let ha = provider
2112            .upload(&HostTensorView {
2113                data: &a.data,
2114                shape: &a.shape,
2115            })
2116            .expect("upload A");
2117        let hb = provider
2118            .upload(&HostTensorView {
2119                data: &b.data,
2120                shape: &b.shape,
2121            })
2122            .expect("upload B");
2123
2124        let _output_guard = crate::output_count::push_output_count(Some(1));
2125        let mut gpu_opts = StructValue::new();
2126        gpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
2127        let gpu_value = linsolve_builtin(
2128            Value::GpuTensor(ha.clone()),
2129            Value::GpuTensor(hb.clone()),
2130            vec![Value::Struct(gpu_opts)],
2131        )
2132        .expect("gpu symmetric linsolve");
2133        let gpu_solution = match gpu_value {
2134            Value::OutputList(mut outputs) => outputs.remove(0),
2135            other => other,
2136        };
2137        let gathered = test_support::gather(gpu_solution).expect("gather");
2138        let _ = provider.free(&ha);
2139        let _ = provider.free(&hb);
2140
2141        assert_eq!(gathered.shape, cpu_tensor.shape);
2142        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2143            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2144        }
2145
2146        let telemetry = provider.telemetry_snapshot();
2147        assert_eq!(telemetry.linsolve.count, 1);
2148        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2149    }
2150
2151    #[cfg(feature = "wgpu")]
2152    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2153    #[test]
2154    fn wgpu_transposed_square_linsolve_avoids_host_reupload_fallback() {
2155        let _accel_guard = test_support::accel_test_lock();
2156        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2157            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2158        );
2159        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2160        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2161            return;
2162        }
2163        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2164        let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
2165
2166        let mut cpu_opts = StructValue::new();
2167        cpu_opts.fields.insert(
2168            "TRANSA".to_string(),
2169            Value::CharArray(CharArray::new_row("T")),
2170        );
2171        let cpu = linsolve_builtin(
2172            Value::Tensor(a.clone()),
2173            Value::Tensor(b.clone()),
2174            vec![Value::Struct(cpu_opts)],
2175        )
2176        .expect("cpu linsolve");
2177        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2178        provider.reset_telemetry();
2179
2180        let ha = provider
2181            .upload(&HostTensorView {
2182                data: &a.data,
2183                shape: &a.shape,
2184            })
2185            .expect("upload A");
2186        let hb = provider
2187            .upload(&HostTensorView {
2188                data: &b.data,
2189                shape: &b.shape,
2190            })
2191            .expect("upload B");
2192
2193        let _output_guard = crate::output_count::push_output_count(Some(1));
2194        let mut gpu_opts = StructValue::new();
2195        gpu_opts.fields.insert(
2196            "TRANSA".to_string(),
2197            Value::CharArray(CharArray::new_row("T")),
2198        );
2199        let gpu_value = linsolve_builtin(
2200            Value::GpuTensor(ha.clone()),
2201            Value::GpuTensor(hb.clone()),
2202            vec![Value::Struct(gpu_opts)],
2203        )
2204        .expect("gpu transposed square linsolve");
2205        let gpu_solution = match gpu_value {
2206            Value::OutputList(mut outputs) => outputs.remove(0),
2207            other => other,
2208        };
2209        let gathered = test_support::gather(gpu_solution).expect("gather");
2210        let _ = provider.free(&ha);
2211        let _ = provider.free(&hb);
2212
2213        assert_eq!(gathered.shape, cpu_tensor.shape);
2214        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2215            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2216        }
2217
2218        let telemetry = provider.telemetry_snapshot();
2219        assert_eq!(telemetry.linsolve.count, 1);
2220        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2221    }
2222
2223    #[cfg(feature = "wgpu")]
2224    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2225    #[test]
2226    fn wgpu_conjugate_square_linsolve_avoids_host_reupload_fallback_for_real_inputs() {
2227        let _accel_guard = test_support::accel_test_lock();
2228        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2229            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2230        );
2231        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2232        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2233            return;
2234        }
2235
2236        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2237        let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
2238        let mut cpu_opts = StructValue::new();
2239        cpu_opts.fields.insert(
2240            "TRANSA".to_string(),
2241            Value::CharArray(CharArray::new_row("C")),
2242        );
2243        let cpu = linsolve_builtin(
2244            Value::Tensor(a.clone()),
2245            Value::Tensor(b.clone()),
2246            vec![Value::Struct(cpu_opts)],
2247        )
2248        .expect("cpu linsolve");
2249        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2250        provider.reset_telemetry();
2251
2252        let ha = provider
2253            .upload(&HostTensorView {
2254                data: &a.data,
2255                shape: &a.shape,
2256            })
2257            .expect("upload A");
2258        let hb = provider
2259            .upload(&HostTensorView {
2260                data: &b.data,
2261                shape: &b.shape,
2262            })
2263            .expect("upload B");
2264
2265        let _output_guard = crate::output_count::push_output_count(Some(1));
2266        let mut gpu_opts = StructValue::new();
2267        gpu_opts.fields.insert(
2268            "TRANSA".to_string(),
2269            Value::CharArray(CharArray::new_row("C")),
2270        );
2271        let gpu_value = linsolve_builtin(
2272            Value::GpuTensor(ha.clone()),
2273            Value::GpuTensor(hb.clone()),
2274            vec![Value::Struct(gpu_opts)],
2275        )
2276        .expect("gpu conjugate square linsolve");
2277        let gpu_solution = match gpu_value {
2278            Value::OutputList(mut outputs) => outputs.remove(0),
2279            other => other,
2280        };
2281        let gathered = test_support::gather(gpu_solution).expect("gather");
2282        let _ = provider.free(&ha);
2283        let _ = provider.free(&hb);
2284
2285        assert_eq!(gathered.shape, cpu_tensor.shape);
2286        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2287            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2288        }
2289
2290        let telemetry = provider.telemetry_snapshot();
2291        assert_eq!(telemetry.linsolve.count, 1);
2292        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2293        assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
2294    }
2295
2296    #[cfg(feature = "wgpu")]
2297    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2298    #[test]
2299    fn wgpu_transposed_rectangular_linsolve_avoids_host_reupload_fallback() {
2300        let _accel_guard = test_support::accel_test_lock();
2301        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2302            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2303        );
2304        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2305        if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2306            return;
2307        }
2308        let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![2, 3]).unwrap();
2309        let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();
2310
2311        let mut cpu_opts = StructValue::new();
2312        cpu_opts.fields.insert(
2313            "TRANSA".to_string(),
2314            Value::CharArray(CharArray::new_row("T")),
2315        );
2316        cpu_opts
2317            .fields
2318            .insert("RECT".to_string(), Value::Bool(true));
2319        let cpu = linsolve_builtin(
2320            Value::Tensor(a.clone()),
2321            Value::Tensor(b.clone()),
2322            vec![Value::Struct(cpu_opts)],
2323        )
2324        .expect("cpu linsolve");
2325        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2326        provider.reset_telemetry();
2327
2328        let ha = provider
2329            .upload(&HostTensorView {
2330                data: &a.data,
2331                shape: &a.shape,
2332            })
2333            .expect("upload A");
2334        let hb = provider
2335            .upload(&HostTensorView {
2336                data: &b.data,
2337                shape: &b.shape,
2338            })
2339            .expect("upload B");
2340
2341        let _output_guard = crate::output_count::push_output_count(Some(1));
2342        let mut gpu_opts = StructValue::new();
2343        gpu_opts.fields.insert(
2344            "TRANSA".to_string(),
2345            Value::CharArray(CharArray::new_row("T")),
2346        );
2347        gpu_opts
2348            .fields
2349            .insert("RECT".to_string(), Value::Bool(true));
2350        let gpu_value = linsolve_builtin(
2351            Value::GpuTensor(ha.clone()),
2352            Value::GpuTensor(hb.clone()),
2353            vec![Value::Struct(gpu_opts)],
2354        )
2355        .expect("gpu transposed rectangular linsolve");
2356        let gpu_solution = match gpu_value {
2357            Value::OutputList(mut outputs) => outputs.remove(0),
2358            other => other,
2359        };
2360        let gathered = test_support::gather(gpu_solution).expect("gather");
2361        let _ = provider.free(&ha);
2362        let _ = provider.free(&hb);
2363
2364        assert_eq!(gathered.shape, cpu_tensor.shape);
2365        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2366            assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2367        }
2368
2369        let telemetry = provider.telemetry_snapshot();
2370        assert_eq!(telemetry.linsolve.count, 1);
2371        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2372    }
2373
2374    #[cfg(feature = "wgpu")]
2375    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2376    #[test]
2377    fn wgpu_triangular_hint_avoids_host_reupload_fallback() {
2378        let _accel_guard = test_support::accel_test_lock();
2379        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2380            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2381        );
2382        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2383        let a = Tensor::new(
2384            vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
2385            vec![3, 3],
2386        )
2387        .unwrap();
2388        let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
2389
2390        let cpu = linsolve_builtin(Value::Tensor(a.clone()), Value::Tensor(b.clone()), {
2391            let mut opts = StructValue::new();
2392            opts.fields.insert("LT".to_string(), Value::Bool(true));
2393            vec![Value::Struct(opts)]
2394        })
2395        .expect("cpu linsolve");
2396        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2397        provider.reset_telemetry();
2398
2399        let ha = provider
2400            .upload(&HostTensorView {
2401                data: &a.data,
2402                shape: &a.shape,
2403            })
2404            .expect("upload A");
2405        let hb = provider
2406            .upload(&HostTensorView {
2407                data: &b.data,
2408                shape: &b.shape,
2409            })
2410            .expect("upload B");
2411
2412        let _output_guard = crate::output_count::push_output_count(Some(1));
2413        let mut opts = StructValue::new();
2414        opts.fields.insert("LT".to_string(), Value::Bool(true));
2415        let gpu_value = linsolve_builtin(
2416            Value::GpuTensor(ha.clone()),
2417            Value::GpuTensor(hb.clone()),
2418            vec![Value::Struct(opts)],
2419        )
2420        .expect("gpu triangular linsolve");
2421        let gpu_solution = match gpu_value {
2422            Value::OutputList(mut outputs) => outputs.remove(0),
2423            other => other,
2424        };
2425        let gathered = test_support::gather(gpu_solution).expect("gather");
2426        let _ = provider.free(&ha);
2427        let _ = provider.free(&hb);
2428
2429        assert_eq!(gathered.shape, cpu_tensor.shape);
2430        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2431            assert!((gpu - cpu).abs() < 1e-5);
2432        }
2433
2434        let telemetry = provider.telemetry_snapshot();
2435        assert_eq!(telemetry.linsolve.count, 1);
2436        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2437    }
2438
2439    #[cfg(feature = "wgpu")]
2440    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2441    #[test]
2442    fn wgpu_transposed_triangular_hint_avoids_host_reupload_fallback() {
2443        let _accel_guard = test_support::accel_test_lock();
2444        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2445            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2446        );
2447        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2448        let a = Tensor::new(
2449            vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
2450            vec![3, 3],
2451        )
2452        .unwrap();
2453        let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
2454
2455        let mut cpu_opts = StructValue::new();
2456        cpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
2457        cpu_opts.fields.insert(
2458            "TRANSA".to_string(),
2459            Value::CharArray(CharArray::new_row("T")),
2460        );
2461        let cpu = linsolve_builtin(
2462            Value::Tensor(a.clone()),
2463            Value::Tensor(b.clone()),
2464            vec![Value::Struct(cpu_opts)],
2465        )
2466        .expect("cpu linsolve");
2467        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2468        provider.reset_telemetry();
2469
2470        let ha = provider
2471            .upload(&HostTensorView {
2472                data: &a.data,
2473                shape: &a.shape,
2474            })
2475            .expect("upload A");
2476        let hb = provider
2477            .upload(&HostTensorView {
2478                data: &b.data,
2479                shape: &b.shape,
2480            })
2481            .expect("upload B");
2482
2483        let _output_guard = crate::output_count::push_output_count(Some(1));
2484        let mut gpu_opts = StructValue::new();
2485        gpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
2486        gpu_opts.fields.insert(
2487            "TRANSA".to_string(),
2488            Value::CharArray(CharArray::new_row("T")),
2489        );
2490        let gpu_value = linsolve_builtin(
2491            Value::GpuTensor(ha.clone()),
2492            Value::GpuTensor(hb.clone()),
2493            vec![Value::Struct(gpu_opts)],
2494        )
2495        .expect("gpu transposed triangular linsolve");
2496        let gpu_solution = match gpu_value {
2497            Value::OutputList(mut outputs) => outputs.remove(0),
2498            other => other,
2499        };
2500        let gathered = test_support::gather(gpu_solution).expect("gather");
2501        let _ = provider.free(&ha);
2502        let _ = provider.free(&hb);
2503
2504        assert_eq!(gathered.shape, cpu_tensor.shape);
2505        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2506            assert!((gpu - cpu).abs() < 1e-5);
2507        }
2508
2509        let telemetry = provider.telemetry_snapshot();
2510        assert_eq!(telemetry.linsolve.count, 1);
2511        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2512    }
2513
2514    #[cfg(feature = "wgpu")]
2515    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2516    #[test]
2517    fn wgpu_round_trip_matches_cpu() {
2518        let _accel_guard = test_support::accel_test_lock();
2519        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2520            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2521        );
2522        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2523        let tol = match provider.precision() {
2524            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
2525            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
2526        };
2527
2528        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2529        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
2530
2531        let cpu = linsolve_builtin(
2532            Value::Tensor(a.clone()),
2533            Value::Tensor(b.clone()),
2534            Vec::new(),
2535        )
2536        .expect("cpu linsolve");
2537        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2538
2539        let view_a = HostTensorView {
2540            data: &a.data,
2541            shape: &a.shape,
2542        };
2543        let view_b = HostTensorView {
2544            data: &b.data,
2545            shape: &b.shape,
2546        };
2547        let ha = provider.upload(&view_a).expect("upload A");
2548        let hb = provider.upload(&view_b).expect("upload B");
2549        let gpu_value = linsolve_builtin(
2550            Value::GpuTensor(ha.clone()),
2551            Value::GpuTensor(hb.clone()),
2552            Vec::new(),
2553        )
2554        .expect("gpu linsolve");
2555        let gathered = test_support::gather(gpu_value).expect("gather");
2556        let _ = provider.free(&ha);
2557        let _ = provider.free(&hb);
2558
2559        assert_eq!(gathered.shape, cpu_tensor.shape);
2560        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2561            assert!((gpu - cpu).abs() < tol);
2562        }
2563    }
2564}