runmat_runtime/builtins/math/linalg/solve/linsolve.rs

//! MATLAB-compatible `linsolve` builtin with structural hints and GPU-aware fallbacks.

use nalgebra::{linalg::SVD, DMatrix};
use num_complex::Complex64;
use runmat_accelerate_api::{
    AccelProvider, GpuTensorHandle, HostTensorView, ProviderLinsolveOptions, ProviderLinsolveResult,
};
use runmat_builtins::{ComplexTensor, Tensor, Value};
use runmat_macros::runtime_builtin;

use crate::builtins::common::spec::{
    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
    ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
};
use crate::builtins::common::{
    gpu_helpers,
    linalg::{diagonal_rcond, singular_value_rcond},
    tensor,
};
use crate::{register_builtin_fusion_spec, register_builtin_gpu_spec};

#[cfg(feature = "doc_export")]
use crate::register_builtin_doc_text;

const NAME: &str = "linsolve";

#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "linsolve"
category: "math/linalg/solve"
keywords: ["linsolve", "linear solve", "triangular system", "posdef", "gpu"]
summary: "Solve linear systems A * X = B with optional structural hints (triangular, symmetric, positive-definite, or transposed)."
references: ["https://www.mathworks.com/help/matlab/ref/linsolve.html"]
gpu_support:
  elementwise: false
  reduction: false
  precisions: ["f32", "f64"]
  broadcasting: "none"
  notes: "Prefers the accel provider's linsolve hook; the current WGPU backend downloads operands to the host, runs the shared solver, then re-uploads the result to preserve GPU residency."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 2
  constants: "uniform"
requires_feature: "wgpu"
tested:
  unit: "builtins::math::linalg::solve::linsolve::tests"
  gpu: "builtins::math::linalg::solve::linsolve::tests::gpu_round_trip_matches_cpu"
  wgpu: "builtins::math::linalg::solve::linsolve::tests::wgpu_round_trip_matches_cpu"
  doc: "builtins::math::linalg::solve::linsolve::tests::doc_examples_present"
---

# What does the `linsolve` function do in MATLAB / RunMat?
`X = linsolve(A, B)` solves the linear system `A * X = B`. The optional `opts` structure lets you
declare that `A` is lower- or upper-triangular, symmetric, positive-definite, rectangular, or that
the transposed system should be solved instead. These hints mirror MATLAB and allow the runtime to
skip unnecessary factorizations.

## How does the `linsolve` function behave in MATLAB / RunMat?
- Inputs must behave like 2-D matrices (trailing singleton dimensions are accepted). `size(A, 1)` must
  match `size(B, 1)` after accounting for `opts.TRANSA`.
- When `opts.LT` or `opts.UT` is supplied, `linsolve` performs forward/back substitution instead of a
  full factorization. Singular pivots trigger the MATLAB error `"linsolve: matrix is singular to working precision."`
- `opts.TRANSA = 'T'` or `'C'` solves `Aᵀ * X = B` (conjugate transpose for complex matrices).
- `opts.POSDEF` and `opts.SYM` are accepted for compatibility; the current implementation still falls
  back to the SVD-based dense solver when a specialised route is not yet wired in.
- The optional second output `[X, rcond_est] = linsolve(...)` (exposed via the VM multi-output path)
  returns the estimated reciprocal condition number used to honour `opts.RCOND`, as sketched below.
- Logical and integer inputs are promoted to double precision. Complex inputs are handled in complex
  arithmetic.
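
A minimal sketch of the hint plumbing (option names are case-insensitive, so `struct('lt', true)` works too):

```matlab
opts = struct('UT', true, 'RCOND', 1e-10);   % upper-triangular hint plus a conditioning guard
[x, rcond_est] = linsolve(A, b, opts);       % errors if the estimate falls below 1e-10
```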

## `linsolve` GPU execution behaviour
When a gpuArray provider is active, RunMat offers the solve to its `linsolve` hook. The current WGPU
backend downloads the operands to the host, executes the shared CPU solver, and uploads the result
back to the device so downstream kernels retain their residency. If no provider is registered—or a
provider declines the hook—RunMat gathers inputs to the host and returns a host tensor.
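
A minimal sketch of the round trip when a provider is registered:

```matlab
Ad = gpuArray(A);         % upload the operands
bd = gpuArray(b);
xd = linsolve(Ad, bd);    % served by the provider's linsolve hook
x  = gather(xd);          % download the solution when host values are needed
```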

## Examples of using the `linsolve` function in MATLAB / RunMat

### Solving a 2×2 linear system
```matlab
A = [4 -2; 1 3];
b = [6; 5];
x = linsolve(A, b);
```
Expected output:
```matlab
x =
     2
     1
```

### Using a lower-triangular hint
```matlab
L = [3 0 0; -1 2 0; 4 1 5];
b = [9; 1; 19];
opts.LT = true;
x = linsolve(L, b, opts);
```
Expected output:
```matlab
x =
     3
     2
     1
```

### Solving the transposed system
```matlab
A = [2 1 0; 0 3 4; 0 0 5];
b = [2; 7; 13];
opts.UT = true;
opts.TRANSA = 'T';
x = linsolve(A, b, opts);
```
Expected output:
```matlab
x =
     1
     2
     1
```

### Complex triangular solve
```matlab
U = [2+1i  -1i; 0  4-2i];
b = [3+2i; 7];
opts.UT = true;
x = linsolve(U, b, opts);
```
Expected output:
```matlab
x =
   1.6000 + 0.9000i
   1.4000 + 0.7000i
```

### Estimating the reciprocal condition number
```matlab
A = [1 1; 1 1+1e-12];
b = [2; 2+1e-12];
[x, rcond_est] = linsolve(A, b);
```
Expected output (up to small round-off):
```matlab
x =
     1
     1

rcond_est =
    4.4409e-12
```

## GPU residency in RunMat (Do I need `gpuArray`?)
No additional residency management is required. When both operands already reside on the GPU,
RunMat executes the provider's `linsolve` hook. The current WGPU backend gathers the data to the
host, runs the shared solver, and re-uploads the output automatically, so downstream GPU work keeps
its residency. Providers that implement an on-device kernel can execute entirely on the GPU without
any MATLAB-level changes.

## FAQ

### What happens if I pass both `opts.LT` and `opts.UT`?
RunMat raises the MATLAB error `"linsolve: LT and UT are mutually exclusive."` The two hints
contradict each other, so supply only the one that matches the structure of `A`.
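
For example:

```matlab
opts.LT = true;
opts.UT = true;
x = linsolve(A, b, opts);   % error: linsolve: LT and UT are mutually exclusive.
```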

### Does `opts.TRANSA` accept lowercase characters?
Yes. `opts.TRANSA` is case-insensitive and accepts `'N'`, `'T'`, `'C'`, or their lowercase variants.
`'C'` and `'T'` are equivalent for real matrices; `'C'` takes the conjugate transpose for complex
matrices (mirroring MATLAB).
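
A lowercase hint behaves identically:

```matlab
opts.TRANSA = 't';          % accepted; same as 'T'
x = linsolve(A, b, opts);   % solves A.' * X = B
```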

### How is `opts.RCOND` used?
`opts.RCOND` provides a lower bound on the acceptable reciprocal condition number. If the estimated
`rcond` falls below the requested threshold the builtin raises
`"linsolve: matrix is singular to working precision."`

### Do `opts.SYM` or `opts.POSDEF` change the algorithm today?
They are accepted for MATLAB compatibility. The current implementation still uses the dense SVD
solver when no specialised routine is wired in; future work will route positive-definite systems to
Cholesky-based kernels.

### Can I use higher-dimensional arrays?
Inputs must behave like matrices. Trailing singleton dimensions are permitted, but other higher-rank
arrays should be reshaped before calling `linsolve`, just like in MATLAB.
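
A minimal sketch of flattening a higher-rank right-hand side first (assuming its pages should become extra columns):

```matlab
B2 = reshape(B, size(B, 1), []);   % m-by-n-by-k becomes m-by-(n*k)
X = linsolve(A, B2);
```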

## See Also
[mldivide](../../ops/mldivide), [mrdivide](../../ops/mrdivide), [lu](../../factor/lu), [chol](../../factor/chol), [gpuArray](../../../acceleration/gpu/gpuArray), [gather](../../../acceleration/gpu/gather)
"#;

#[cfg(not(feature = "doc_export"))]
#[allow(dead_code)]
const DOC_MD: &str = "";

pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "linsolve",
    op_kind: GpuOpKind::Custom("solve"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("linsolve")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider linsolve hook; WGPU currently gathers to the host solver and re-uploads the result.",
};

register_builtin_gpu_spec!(GPU_SPEC);

pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "linsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Linear solves are terminal operations and do not fuse with surrounding kernels.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

#[cfg(feature = "doc_export")]
register_builtin_doc_text!("linsolve", DOC_MD);

#[runtime_builtin(
    name = "linsolve",
    category = "math/linalg/solve",
    summary = "Solve A * X = B with structural hints such as LT, UT, POSDEF, or TRANSA.",
    keywords = "linsolve,linear system,triangular,gpu",
    accel = "linsolve"
)]
fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> Result<Value, String> {
    let eval = evaluate_args(lhs, rhs, &rest)?;
    Ok(eval.solution())
}

/// Evaluate `linsolve`, returning both the solution and the estimated reciprocal condition number.
pub fn evaluate(lhs: Value, rhs: Value, options: SolveOptions) -> Result<LinsolveEval, String> {
    if let Some(eval) = try_gpu_linsolve(&lhs, &rhs, &options)? {
        return Ok(eval);
    }

    let lhs_host = crate::dispatcher::gather_if_needed(&lhs)?;
    let rhs_host = crate::dispatcher::gather_if_needed(&rhs)?;
    let pair = coerce_numeric_pair(lhs_host, rhs_host)?;
    match pair {
        NumericPair::Real(lhs_r, rhs_r) => {
            let (solution, rcond) = solve_real(lhs_r, rhs_r, &options)?;
            Ok(LinsolveEval::new(
                tensor::tensor_into_value(solution),
                Some(rcond),
            ))
        }
        NumericPair::Complex(lhs_c, rhs_c) => {
            let (solution, rcond) = solve_complex(lhs_c, rhs_c, &options)?;
            Ok(LinsolveEval::new(
                Value::ComplexTensor(solution),
                Some(rcond),
            ))
        }
    }
}

/// Host implementation shared with acceleration providers that fall back to CPU execution.
pub fn linsolve_host_real_for_provider(
    lhs: &Tensor,
    rhs: &Tensor,
    options: &ProviderLinsolveOptions,
) -> Result<(Tensor, f64), String> {
    let opts = SolveOptions::from(options);
    solve_real(lhs.clone(), rhs.clone(), &opts)
}

/// Result wrapper that exposes both primary and secondary outputs.
#[derive(Clone)]
pub struct LinsolveEval {
    solution: Value,
    rcond: Option<f64>,
}

impl LinsolveEval {
    fn new(solution: Value, rcond: Option<f64>) -> Self {
        Self { solution, rcond }
    }

    /// Primary solution output.
    pub fn solution(&self) -> Value {
        self.solution.clone()
    }

    /// Estimated reciprocal condition number (second output).
    pub fn reciprocal_condition(&self) -> Value {
        match self.rcond {
            Some(r) => Value::Num(r),
            None => Value::Num(f64::NAN),
        }
    }
}

#[derive(Clone, Default)]
pub struct SolveOptions {
    lower: bool,
    upper: bool,
    rectangular: bool,
    transposed: bool,
    conjugate: bool,
    symmetric: bool,
    posdef: bool,
    rcond: Option<f64>,
}

impl From<&SolveOptions> for ProviderLinsolveOptions {
    fn from(opts: &SolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            rcond: opts.rcond,
        }
    }
}

impl From<&ProviderLinsolveOptions> for SolveOptions {
    fn from(opts: &ProviderLinsolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            rcond: opts.rcond,
        }
    }
}

fn options_from_rest(rest: &[Value]) -> Result<SolveOptions, String> {
    match rest.len() {
        0 => Ok(SolveOptions::default()),
        1 => parse_options(&rest[0]),
        _ => Err("linsolve: too many input arguments".to_string()),
    }
}

/// Public helper for the VM multi-output surface.
pub fn evaluate_args(lhs: Value, rhs: Value, rest: &[Value]) -> Result<LinsolveEval, String> {
    let options = options_from_rest(rest)?;
    evaluate(lhs, rhs, options)
}

fn try_gpu_linsolve(
    lhs: &Value,
    rhs: &Value,
    options: &SolveOptions,
) -> Result<Option<LinsolveEval>, String> {
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

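    // Providers currently expose only a real-valued linsolve hook, so complex
    // operands always take the host path below.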
    if contains_complex(lhs) || contains_complex(rhs) {
        return Ok(None);
    }

    let mut lhs_operand = match prepare_gpu_operand(lhs, provider)? {
        Some(op) => op,
        None => return Ok(None),
    };
    let mut rhs_operand = match prepare_gpu_operand(rhs, provider)? {
        Some(op) => op,
        None => {
            release_operand(provider, &mut lhs_operand);
            return Ok(None);
        }
    };

    if is_scalar_handle(lhs_operand.handle()) || is_scalar_handle(rhs_operand.handle()) {
        release_operand(provider, &mut lhs_operand);
        release_operand(provider, &mut rhs_operand);
        return Ok(None);
    }

    let provider_opts: ProviderLinsolveOptions = options.into();
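    // Treat a provider error as a decline: fall back to the host solver
    // instead of surfacing the failure.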
    let result = provider
        .linsolve(lhs_operand.handle(), rhs_operand.handle(), &provider_opts)
        .ok();

    release_operand(provider, &mut lhs_operand);
    release_operand(provider, &mut rhs_operand);

    if let Some(ProviderLinsolveResult {
        solution,
        reciprocal_condition,
    }) = result
    {
        let eval = LinsolveEval::new(Value::GpuTensor(solution), Some(reciprocal_condition));
        return Ok(Some(eval));
    }

    Ok(None)
}

fn parse_options(value: &Value) -> Result<SolveOptions, String> {
    let struct_val = match value {
        Value::Struct(s) => s,
        other => return Err(format!("linsolve: opts must be a struct, got {other:?}")),
    };
    let mut opts = SolveOptions::default();
    for (key, raw_value) in &struct_val.fields {
        let name = key.to_ascii_uppercase();
        match name.as_str() {
            "LT" => opts.lower = parse_bool_field("LT", raw_value)?,
            "UT" => opts.upper = parse_bool_field("UT", raw_value)?,
            "RECT" => opts.rectangular = parse_bool_field("RECT", raw_value)?,
            "SYM" => opts.symmetric = parse_bool_field("SYM", raw_value)?,
            "POSDEF" => opts.posdef = parse_bool_field("POSDEF", raw_value)?,
            "TRANSA" => {
                let transa = parse_transa(raw_value)?;
                opts.transposed = transa != TransposeMode::None;
                opts.conjugate = transa == TransposeMode::Conjugate;
            }
            "RCOND" => {
                let threshold = parse_scalar_f64("RCOND", raw_value)?;
                if threshold < 0.0 {
                    return Err("linsolve: RCOND must be non-negative".to_string());
                }
                opts.rcond = Some(threshold);
            }
            other => return Err(format!("linsolve: unknown option '{other}'")),
        }
    }
    if opts.lower && opts.upper {
        return Err("linsolve: LT and UT are mutually exclusive.".to_string());
    }
    Ok(opts)
}

fn parse_bool_field(name: &str, value: &Value) -> Result<bool, String> {
    match value {
        Value::Bool(b) => Ok(*b),
        Value::Int(i) => Ok(!i.is_zero()),
        Value::Num(n) => Ok(*n != 0.0),
        Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0] != 0.0),
        Value::LogicalArray(arr) if arr.len() == 1 => Ok(arr.data[0] != 0),
        other => Err(format!(
            "linsolve: option '{name}' must be logical or numeric, got {other:?}"
        )),
    }
}

fn parse_scalar_f64(name: &str, value: &Value) -> Result<f64, String> {
    match value {
        Value::Num(n) => Ok(*n),
        Value::Int(i) => Ok(i.to_f64()),
        Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0]),
        other => Err(format!(
            "linsolve: option '{name}' must be a scalar numeric value, got {other:?}"
        )),
    }
}

#[derive(Copy, Clone, PartialEq, Eq)]
enum TransposeMode {
    None,
    Transpose,
    Conjugate,
}

fn parse_transa(value: &Value) -> Result<TransposeMode, String> {
    let text = tensor::value_to_string(value).ok_or_else(|| {
        "linsolve: TRANSA must be a character vector or string scalar".to_string()
    })?;
    if text.is_empty() {
        return Err("linsolve: TRANSA cannot be empty".to_string());
    }
    match text.trim().to_ascii_uppercase().as_str() {
        "N" => Ok(TransposeMode::None),
        "T" => Ok(TransposeMode::Transpose),
        "C" => Ok(TransposeMode::Conjugate),
        other => Err(format!(
            "linsolve: TRANSA must be 'N', 'T', or 'C', got '{other}'"
        )),
    }
}

enum NumericInput {
    Real(Tensor),
    Complex(ComplexTensor),
}

enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}

fn coerce_numeric_pair(lhs: Value, rhs: Value) -> Result<NumericPair, String> {
    let lhs_num = coerce_numeric(lhs)?;
    let rhs_num = coerce_numeric(rhs)?;
    match (lhs_num, rhs_num) {
        (NumericInput::Real(lhs_r), NumericInput::Real(rhs_r)) => {
            Ok(NumericPair::Real(lhs_r, rhs_r))
        }
        (NumericInput::Complex(lhs_c), NumericInput::Complex(rhs_c)) => {
            Ok(NumericPair::Complex(lhs_c, rhs_c))
        }
        (NumericInput::Complex(lhs_c), NumericInput::Real(rhs_r)) => {
            let rhs_c = promote_real_tensor(&rhs_r)?;
            Ok(NumericPair::Complex(lhs_c, rhs_c))
        }
        (NumericInput::Real(lhs_r), NumericInput::Complex(rhs_c)) => {
            let lhs_c = promote_real_tensor(&lhs_r)?;
            Ok(NumericPair::Complex(lhs_c, rhs_c))
        }
    }
}

fn coerce_numeric(value: Value) -> Result<NumericInput, String> {
    match value {
        Value::Tensor(tensor) => {
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::LogicalArray(logical) => {
            let tensor = tensor::logical_to_tensor(&logical)?;
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Num(n) => {
            let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(|e| format!("{NAME}: {e}"))?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Int(i) => {
            let tensor =
                Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|e| format!("{NAME}: {e}"))?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Bool(b) => {
            let tensor = Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1])
                .map_err(|e| format!("{NAME}: {e}"))?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1])
                .map_err(|e| format!("{NAME}: {e}"))?;
            Ok(NumericInput::Complex(tensor))
        }
        Value::ComplexTensor(ct) => {
            ensure_matrix_shape(NAME, &ct.shape)?;
            Ok(NumericInput::Complex(ct))
        }
        Value::GpuTensor(handle) => {
            let tensor = gpu_helpers::gather_tensor(&handle)?;
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        other => Err(format!(
            "{NAME}: unsupported input type {:?}; convert to numeric values first",
            other
        )),
    }
}

fn contains_complex(value: &Value) -> bool {
    matches!(value, Value::Complex(_, _) | Value::ComplexTensor(_))
}

fn is_scalar_handle(handle: &GpuTensorHandle) -> bool {
    handle.shape.iter().copied().product::<usize>() == 1
}

struct PreparedOperand {
    handle: GpuTensorHandle,
    owned: bool,
}

impl PreparedOperand {
    fn borrowed(handle: &GpuTensorHandle) -> Self {
        Self {
            handle: handle.clone(),
            owned: false,
        }
    }

    fn owned(handle: GpuTensorHandle) -> Self {
        Self {
            handle,
            owned: true,
        }
    }

    fn handle(&self) -> &GpuTensorHandle {
        &self.handle
    }
}

fn prepare_gpu_operand(
    value: &Value,
    provider: &'static dyn AccelProvider,
) -> Result<Option<PreparedOperand>, String> {
    match value {
        Value::GpuTensor(handle) => {
            if is_scalar_handle(handle) {
                Ok(None)
            } else {
                Ok(Some(PreparedOperand::borrowed(handle)))
            }
        }
        Value::Tensor(tensor) => {
            if tensor::is_scalar_tensor(tensor) {
                Ok(None)
            } else {
                let uploaded = upload_tensor(provider, tensor)?;
                Ok(Some(PreparedOperand::owned(uploaded)))
            }
        }
        Value::LogicalArray(logical) => {
            if logical.data.len() == 1 {
                Ok(None)
            } else {
                let tensor = tensor::logical_to_tensor(logical)?;
                let uploaded = upload_tensor(provider, &tensor)?;
                Ok(Some(PreparedOperand::owned(uploaded)))
            }
        }
        _ => Ok(None),
    }
}

fn upload_tensor(
    provider: &'static dyn AccelProvider,
    tensor: &Tensor,
) -> Result<GpuTensorHandle, String> {
    let view = HostTensorView {
        data: &tensor.data,
        shape: &tensor.shape,
    };
    provider.upload(&view).map_err(|e| format!("{NAME}: {e}"))
}

fn release_operand(provider: &'static dyn AccelProvider, operand: &mut PreparedOperand) {
    if operand.owned {
        let _ = provider.free(&operand.handle);
        operand.owned = false;
    }
}

fn solve_real(lhs: Tensor, rhs: Tensor, options: &SolveOptions) -> Result<(Tensor, f64), String> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;

    if options.transposed {
        lhs_effective = transpose_tensor(&lhs_effective);
        if options.conjugate {
            conjugate_in_place(&mut lhs_effective);
        }
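        // Transposing moves the data to the opposite triangle, so the
        // structural hints swap with it.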
        if lower || upper {
            std::mem::swap(&mut lower, &mut upper);
        }
    }

    rhs_effective = normalize_rhs_tensor(rhs_effective, lhs_effective.rows())?;

    if lower {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = forward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    if upper {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = backward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    let (solution, rcond) = solve_general_real(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}

fn solve_complex(
    lhs: ComplexTensor,
    rhs: ComplexTensor,
    options: &SolveOptions,
) -> Result<(ComplexTensor, f64), String> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;

    if options.transposed {
        lhs_effective = transpose_complex(&lhs_effective);
        if options.conjugate {
            conjugate_complex_in_place(&mut lhs_effective);
        }
        if lower || upper {
            std::mem::swap(&mut lower, &mut upper);
        }
    }

    rhs_effective = normalize_rhs_complex(rhs_effective, lhs_effective.rows)?;

    if lower {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = forward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    if upper {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = backward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    let (solution, rcond) = solve_general_complex(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}

fn forward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> Result<(Tensor, f64), String> {
    let n = lhs.rows();
    let nrhs = rhs.data.len() / n;
    let mut solution = rhs.data.clone();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;

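    // Tensors are stored column-major: element (i, j) of the n-by-n
    // coefficient matrix lives at index i + j * n.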
    for col in 0..nrhs {
        for i in 0..n {
            let diag = lhs.data[i + i * n];
            let diag_abs = diag.abs();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err("linsolve: matrix is singular to working precision.".to_string());
            }
            let mut accum = 0.0;
            for j in 0..i {
                accum += lhs.data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }

    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = Tensor::new(solution, rhs.shape.clone()).map_err(|e| format!("{NAME}: {e}"))?;
    Ok((tensor, rcond))
}

fn backward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> Result<(Tensor, f64), String> {
    let n = lhs.rows();
    let nrhs = rhs.data.len() / n;
    let mut solution = rhs.data.clone();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;

    for col in 0..nrhs {
        for row_rev in 0..n {
            let i = n - 1 - row_rev;
            let diag = lhs.data[i + i * n];
            let diag_abs = diag.abs();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err("linsolve: matrix is singular to working precision.".to_string());
            }
            let mut accum = 0.0;
            for j in (i + 1)..n {
                accum += lhs.data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }

    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = Tensor::new(solution, rhs.shape.clone()).map_err(|e| format!("{NAME}: {e}"))?;
    Ok((tensor, rcond))
}

fn forward_substitution_complex(
    lhs: &ComplexTensor,
    rhs: &ComplexTensor,
) -> Result<(ComplexTensor, f64), String> {
    let n = lhs.rows;
    let nrhs = rhs.data.len() / n;
    let lhs_data: Vec<Complex64> = lhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut solution: Vec<Complex64> = rhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;

    for col in 0..nrhs {
        for i in 0..n {
            let diag = lhs_data[i + i * n];
            let diag_abs = diag.norm();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err("linsolve: matrix is singular to working precision.".to_string());
            }
            let mut accum = Complex64::new(0.0, 0.0);
            for j in 0..i {
                accum += lhs_data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }

    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = ComplexTensor::new(
        solution.iter().map(|c| (c.re, c.im)).collect(),
        rhs.shape.clone(),
    )
    .map_err(|e| format!("{NAME}: {e}"))?;
    Ok((tensor, rcond))
}

fn backward_substitution_complex(
    lhs: &ComplexTensor,
    rhs: &ComplexTensor,
) -> Result<(ComplexTensor, f64), String> {
    let n = lhs.rows;
    let nrhs = rhs.data.len() / n;
    let lhs_data: Vec<Complex64> = lhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut solution: Vec<Complex64> = rhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let mut min_diag = f64::INFINITY;
    let mut max_diag = 0.0_f64;

    for col in 0..nrhs {
        for row_rev in 0..n {
            let i = n - 1 - row_rev;
            let diag = lhs_data[i + i * n];
            let diag_abs = diag.norm();
            min_diag = min_diag.min(diag_abs);
            max_diag = max_diag.max(diag_abs);
            if diag_abs == 0.0 {
                return Err("linsolve: matrix is singular to working precision.".to_string());
            }
            let mut accum = Complex64::new(0.0, 0.0);
            for j in (i + 1)..n {
                accum += lhs_data[i + j * n] * solution[j + col * n];
            }
            let rhs_value = solution[i + col * n] - accum;
            solution[i + col * n] = rhs_value / diag;
        }
    }

    let rcond = diagonal_rcond(min_diag, max_diag);
    let tensor = ComplexTensor::new(
        solution.iter().map(|c| (c.re, c.im)).collect(),
        rhs.shape.clone(),
    )
    .map_err(|e| format!("{NAME}: {e}"))?;
    Ok((tensor, rcond))
}

fn solve_general_real(lhs: &Tensor, rhs: &Tensor) -> Result<(Tensor, f64), String> {
    let a = DMatrix::from_column_slice(lhs.rows(), lhs.cols(), &lhs.data);
    let b = DMatrix::from_column_slice(rhs.rows(), rhs.cols(), &rhs.data);
    let svd = SVD::new(a, true, true);
    let rcond = singular_value_rcond(svd.singular_values.as_slice());
    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows(), lhs.cols());
    let solution = svd.solve(&b, tol).map_err(|e| format!("{NAME}: {e}"))?;
    let tensor = matrix_real_to_tensor(solution)?;
    Ok((tensor, rcond))
}

fn solve_general_complex(
    lhs: &ComplexTensor,
    rhs: &ComplexTensor,
) -> Result<(ComplexTensor, f64), String> {
    let a_data: Vec<Complex64> = lhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let b_data: Vec<Complex64> = rhs
        .data
        .iter()
        .map(|&(re, im)| Complex64::new(re, im))
        .collect();
    let a = DMatrix::from_column_slice(lhs.rows, lhs.cols, &a_data);
    let b = DMatrix::from_column_slice(rhs.rows, rhs.cols, &b_data);
    let svd = SVD::new(a, true, true);
    let rcond = singular_value_rcond(svd.singular_values.as_slice());
    let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows, lhs.cols);
    let solution = svd.solve(&b, tol).map_err(|e| format!("{NAME}: {e}"))?;
    let tensor = matrix_complex_to_tensor(solution)?;
    Ok((tensor, rcond))
}

fn normalize_rhs_tensor(rhs: Tensor, expected_rows: usize) -> Result<Tensor, String> {
    if rhs.rows() == expected_rows {
        return Ok(rhs);
    }
    if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
        return Tensor::new(rhs.data, vec![expected_rows, 1]).map_err(|e| format!("{NAME}: {e}"));
    }
    if rhs.data.is_empty() && expected_rows == 0 {
        return Ok(rhs);
    }
    Err("Matrix dimensions must agree.".to_string())
}

fn normalize_rhs_complex(
    rhs: ComplexTensor,
    expected_rows: usize,
) -> Result<ComplexTensor, String> {
    if rhs.rows == expected_rows {
        return Ok(rhs);
    }
    if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
        return ComplexTensor::new(rhs.data, vec![expected_rows, 1])
            .map_err(|e| format!("{NAME}: {e}"));
    }
    if rhs.data.is_empty() && expected_rows == 0 {
        return Ok(rhs);
    }
    Err("Matrix dimensions must agree.".to_string())
}

fn enforce_rcond(options: &SolveOptions, rcond: f64) -> Result<(), String> {
    if let Some(threshold) = options.rcond {
        if rcond < threshold {
            return Err("linsolve: matrix is singular to working precision.".to_string());
        }
    }
    Ok(())
}

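/// Rank tolerance in the usual SVD form, `eps * max(rows, cols) * max(σ_max, 1)`,
/// with the floor guarding against all-tiny singular values.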
fn compute_svd_tolerance(singular_values: &[f64], rows: usize, cols: usize) -> f64 {
    let max_sv = singular_values
        .iter()
        .copied()
        .fold(0.0_f64, |acc, value| acc.max(value.abs()));
    let max_dim = rows.max(cols) as f64;
    f64::EPSILON * max_dim * max_sv.max(1.0)
}

fn matrix_real_to_tensor(matrix: DMatrix<f64>) -> Result<Tensor, String> {
    let rows = matrix.nrows();
    let cols = matrix.ncols();
    Tensor::new(matrix.as_slice().to_vec(), vec![rows, cols]).map_err(|e| format!("{NAME}: {e}"))
}

fn matrix_complex_to_tensor(matrix: DMatrix<Complex64>) -> Result<ComplexTensor, String> {
    let rows = matrix.nrows();
    let cols = matrix.ncols();
    let data: Vec<(f64, f64)> = matrix.as_slice().iter().map(|c| (c.re, c.im)).collect();
    ComplexTensor::new(data, vec![rows, cols]).map_err(|e| format!("{NAME}: {e}"))
}

fn promote_real_tensor(tensor: &Tensor) -> Result<ComplexTensor, String> {
    let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
    ComplexTensor::new(data, tensor.shape.clone()).map_err(|e| format!("{NAME}: {e}"))
}

fn ensure_matrix_shape(name: &str, shape: &[usize]) -> Result<(), String> {
    if is_effectively_matrix(shape) {
        Ok(())
    } else {
        Err(format!("{name}: inputs must be 2-D matrices or vectors"))
    }
}

fn is_effectively_matrix(shape: &[usize]) -> bool {
    match shape.len() {
        0..=2 => true,
        _ => shape.iter().skip(2).all(|&dim| dim == 1),
    }
}

fn ensure_square(rows: usize, cols: usize) -> Result<(), String> {
    if rows == cols {
        Ok(())
    } else {
        Err("linsolve: triangular solves require a square coefficient matrix.".to_string())
    }
}

fn transpose_tensor(tensor: &Tensor) -> Tensor {
    let rows = tensor.rows();
    let cols = tensor.cols();
    let mut data = vec![0.0; tensor.data.len()];
    for r in 0..rows {
        for c in 0..cols {
            data[c + r * cols] = tensor.data[r + c * rows];
        }
    }
    Tensor::new(data, vec![cols, rows]).expect("transpose_tensor valid")
}

fn transpose_complex(tensor: &ComplexTensor) -> ComplexTensor {
    let rows = tensor.rows;
    let cols = tensor.cols;
    let mut data = vec![(0.0, 0.0); tensor.data.len()];
    for r in 0..rows {
        for c in 0..cols {
            data[c + r * cols] = tensor.data[r + c * rows];
        }
    }
    ComplexTensor::new(data, vec![cols, rows]).expect("transpose_complex valid")
}

fn conjugate_in_place(_tensor: &mut Tensor) {
    // Real-valued matrices are unaffected by conjugation.
}

fn conjugate_complex_in_place(tensor: &mut ComplexTensor) {
    for value in &mut tensor.data {
        value.1 = -value.1;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use runmat_accelerate_api::HostTensorView;
    use runmat_builtins::{CharArray, StructValue, Tensor};

    fn approx_eq(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-12);
    }

    use crate::builtins::common::test_support;

    #[test]
    fn linsolve_basic_square() {
        let a = Tensor::new(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let result =
            linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new()).expect("linsolve");
        let t = test_support::gather(result).expect("gather");
        assert_eq!(t.shape, vec![2, 1]);
        approx_eq(t.data[0], 1.0);
        approx_eq(t.data[1], 2.0);
    }

    #[test]
    fn linsolve_lower_triangular_hint() {
        let a = Tensor::new(
            vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
            vec![3, 3],
        )
        .unwrap();
        let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
        let mut opts = StructValue::new();
        opts.fields.insert("LT".to_string(), Value::Bool(true));
        let result = linsolve_builtin(
            Value::Tensor(a),
            Value::Tensor(b),
            vec![Value::Struct(opts)],
        )
        .expect("linsolve");
        let tensor = test_support::gather(result).expect("gather");
        assert_eq!(tensor.shape, vec![3, 1]);
        approx_eq(tensor.data[0], 3.0);
        approx_eq(tensor.data[1], 2.0);
        approx_eq(tensor.data[2], 1.0);
    }

    #[test]
    fn linsolve_transposed_triangular_hint() {
        let a = Tensor::new(
            vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
            vec![3, 3],
        )
        .unwrap();
        let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
        let mut opts = StructValue::new();
        opts.fields.insert("LT".to_string(), Value::Bool(true));
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("T")),
        );

        let result = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            vec![Value::Struct(opts)],
        )
        .expect("linsolve");
        let tensor = test_support::gather(result).expect("gather");
        assert_eq!(tensor.shape, vec![3, 1]);

        let a_transposed = transpose_tensor(&a);
        let reference = super::evaluate(
            Value::Tensor(a_transposed.clone()),
            Value::Tensor(b.clone()),
            SolveOptions::default(),
        )
        .expect("reference");
        let expected_tensor = test_support::gather(reference.solution()).expect("gather ref");

        for (actual, expected) in tensor.data.iter().zip(expected_tensor.data.iter()) {
            approx_eq(*actual, *expected);
        }
    }

    #[test]
    fn linsolve_rcond_enforced() {
        let a = Tensor::new(vec![1.0, 1.0, 1.0, 1.0 + 1e-12], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![2.0, 2.0 + 1e-12], vec![2, 1]).unwrap();
        let mut opts = StructValue::new();
        opts.fields.insert("RCOND".to_string(), Value::Num(1e-3));
        let err = linsolve_builtin(
            Value::Tensor(a),
            Value::Tensor(b),
            vec![Value::Struct(opts)],
        )
        .expect_err("singular matrix must fail");
        assert!(
            err.contains("singular to working precision"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn linsolve_recovers_rcond_output() {
        let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
        let eval = super::evaluate_args(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
            .expect("evaluate");
        let solution_tensor = match eval.solution() {
            Value::Tensor(sol) => sol.clone(),
            Value::GpuTensor(handle) => {
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather solution")
            }
            other => panic!("unexpected solution value {other:?}"),
        };
        assert_eq!(solution_tensor.shape, vec![2, 1]);
        approx_eq(solution_tensor.data[0], 1.0);
        approx_eq(solution_tensor.data[1], 2.0);

        let rcond_value = match eval.reciprocal_condition() {
            Value::Num(r) => r,
            Value::GpuTensor(handle) => {
                let gathered =
                    test_support::gather(Value::GpuTensor(handle.clone())).expect("gather rcond");
                gathered.data[0]
            }
            other => panic!("unexpected rcond value {other:?}"),
        };
        approx_eq(rcond_value, 1.0);
    }

    #[test]
    fn gpu_round_trip_matches_cpu() {
        test_support::with_test_provider(|provider| {
            let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
            let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

            let cpu = linsolve_builtin(
                Value::Tensor(a.clone()),
                Value::Tensor(b.clone()),
                Vec::new(),
            )
            .expect("cpu linsolve");
            let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

            let view_a = HostTensorView {
                data: &a.data,
                shape: &a.shape,
            };
            let view_b = HostTensorView {
                data: &b.data,
                shape: &b.shape,
            };
            let ha = provider.upload(&view_a).expect("upload A");
            let hb = provider.upload(&view_b).expect("upload B");

            let gpu_value = linsolve_builtin(
                Value::GpuTensor(ha.clone()),
                Value::GpuTensor(hb.clone()),
                Vec::new(),
            )
            .expect("gpu linsolve");
            let gathered = test_support::gather(gpu_value).expect("gather");
            let _ = provider.free(&ha);
            let _ = provider.free(&hb);

            assert_eq!(gathered.shape, cpu_tensor.shape);
            for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
                assert!((gpu - cpu).abs() < 1e-12);
            }
        });
    }

    #[cfg(feature = "wgpu")]
    #[test]
    fn wgpu_round_trip_matches_cpu() {
        let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
            runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
        );
        let provider = runmat_accelerate_api::provider().expect("wgpu provider");
        let tol = match provider.precision() {
            runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
            runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
        };

        let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let ha = provider.upload(&view_a).expect("upload A");
        let hb = provider.upload(&view_b).expect("upload B");
        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");
        let gathered = test_support::gather(gpu_value).expect("gather");
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            assert!((gpu - cpu).abs() < tol);
        }
    }

    #[cfg(feature = "doc_export")]
    #[test]
    fn doc_examples_present() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}