1use nalgebra::{linalg::SVD, DMatrix};
4use num_complex::Complex64;
5use runmat_accelerate_api::{
6 AccelProvider, GpuTensorHandle, HostTensorView, ProviderLinsolveOptions, ProviderLinsolveResult,
7};
8use runmat_builtins::{ComplexTensor, Tensor, Value};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
14};
15use crate::builtins::common::{
16 gpu_helpers,
17 linalg::{diagonal_rcond, singular_value_rcond},
18 tensor,
19};
20use crate::builtins::math::linalg::type_resolvers::left_divide_type;
21use crate::{build_runtime_error, BuiltinResult, RuntimeError};
22
/// Builtin name used for error tagging and registry metadata.
const NAME: &str = "linsolve";
24
/// Device-dispatch metadata registered for `linsolve`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::solve::linsolve")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "linsolve",
    // Dedicated solve hook, not an elementwise/reduction kernel.
    op_kind: GpuOpKind::Custom("solve"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("linsolve")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    // The provider returns a freshly allocated device tensor.
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider linsolve hook; WGPU currently supports triangular solves, real F32 TRANSA='T'/'C' variants, a dedicated real F32 POSDEF/Cholesky path, and selected real F32 QR-backed square and rectangular solves, otherwise it gathers to the host solver and re-uploads the result.",
};
40
41fn builtin_error(message: impl Into<String>) -> RuntimeError {
42 build_runtime_error(message).with_builtin(NAME).build()
43}
44
45fn map_control_flow(err: RuntimeError) -> RuntimeError {
46 let mut builder = build_runtime_error(err.message()).with_builtin(NAME);
47 if let Some(identifier) = err.identifier() {
48 builder = builder.with_identifier(identifier.to_string());
49 }
50 if let Some(task_id) = err.context.task_id.clone() {
51 builder = builder.with_task_id(task_id);
52 }
53 if !err.context.call_stack.is_empty() {
54 builder = builder.with_call_stack(err.context.call_stack.clone());
55 }
56 if let Some(phase) = err.context.phase.clone() {
57 builder = builder.with_phase(phase);
58 }
59 builder.with_source(err).build()
60}
61
/// Fusion metadata: linsolve never participates in kernel fusion.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "linsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Linear solves are terminal operations and do not fuse with surrounding kernels.",
};
74
75#[runtime_builtin(
76 name = "linsolve",
77 category = "math/linalg/solve",
78 summary = "Solve A * X = B with structural hints such as LT, UT, POSDEF, or TRANSA.",
79 keywords = "linsolve,linear system,triangular,gpu",
80 accel = "linsolve",
81 type_resolver(left_divide_type),
82 builtin_path = "crate::builtins::math::linalg::solve::linsolve"
83)]
84async fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
85 let eval = evaluate_args(lhs, rhs, &rest).await?;
86 if let Some(out_count) = crate::output_count::current_output_count() {
87 if out_count == 0 {
88 return Ok(Value::OutputList(Vec::new()));
89 }
90 if out_count == 1 {
91 return Ok(Value::OutputList(vec![eval.solution()]));
92 }
93 if out_count == 2 {
94 return Ok(Value::OutputList(vec![
95 eval.solution(),
96 eval.reciprocal_condition(),
97 ]));
98 }
99 return Err(builtin_error(
100 "linsolve currently supports at most two outputs",
101 ));
102 }
103 Ok(eval.solution())
104}
105
106pub async fn evaluate(
108 lhs: Value,
109 rhs: Value,
110 options: SolveOptions,
111) -> BuiltinResult<LinsolveEval> {
112 if let Some(eval) = try_gpu_linsolve(&lhs, &rhs, &options).await? {
113 return Ok(eval);
114 }
115
116 let lhs_host = crate::dispatcher::gather_if_needed_async(&lhs)
117 .await
118 .map_err(map_control_flow)?;
119 let rhs_host = crate::dispatcher::gather_if_needed_async(&rhs)
120 .await
121 .map_err(map_control_flow)?;
122 let pair = coerce_numeric_pair(lhs_host, rhs_host).await?;
123 match pair {
124 NumericPair::Real(lhs_r, rhs_r) => {
125 let (solution, rcond) = solve_real(lhs_r, rhs_r, &options)?;
126 Ok(LinsolveEval::new(
127 tensor::tensor_into_value(solution),
128 Some(rcond),
129 ))
130 }
131 NumericPair::Complex(lhs_c, rhs_c) => {
132 let (solution, rcond) = solve_complex(lhs_c, rhs_c, &options)?;
133 Ok(LinsolveEval::new(
134 Value::ComplexTensor(solution),
135 Some(rcond),
136 ))
137 }
138 }
139}
140
141pub fn linsolve_host_real_for_provider(
143 lhs: &Tensor,
144 rhs: &Tensor,
145 options: &ProviderLinsolveOptions,
146) -> BuiltinResult<(Tensor, f64)> {
147 let opts = SolveOptions::from(options);
148 solve_real(lhs.clone(), rhs.clone(), &opts)
149}
150
/// Outcome of a linsolve evaluation: the solution plus the optional
/// reciprocal-condition estimate exposed as the second output.
#[derive(Clone)]
pub struct LinsolveEval {
    // Solution X (host tensor, complex tensor, or GPU handle).
    solution: Value,
    // Reciprocal condition estimate; None when not computed.
    rcond: Option<f64>,
}
157
158impl LinsolveEval {
159 fn new(solution: Value, rcond: Option<f64>) -> Self {
160 Self { solution, rcond }
161 }
162
163 pub fn solution(&self) -> Value {
165 self.solution.clone()
166 }
167
168 pub fn reciprocal_condition(&self) -> Value {
170 match self.rcond {
171 Some(r) => Value::Num(r),
172 None => Value::Num(f64::NAN),
173 }
174 }
175}
176
/// Structural hints parsed from the MATLAB-style options struct.
#[derive(Clone, Default)]
pub struct SolveOptions {
    // LT: coefficient matrix is lower triangular.
    lower: bool,
    // UT: coefficient matrix is upper triangular.
    upper: bool,
    // RECT: matrix may be rectangular (forwarded to providers; the
    // host path here always falls back to the general SVD solve).
    rectangular: bool,
    // TRANSA != 'N': solve with the transposed matrix.
    transposed: bool,
    // TRANSA == 'C': conjugate in addition to transposing.
    conjugate: bool,
    // SYM: symmetry hint (forwarded to providers; unused by the host
    // solver in this file).
    symmetric: bool,
    // POSDEF: positive-definite hint (likewise provider-only here).
    posdef: bool,
    // RCOND: fail when the estimate falls below this threshold.
    rcond: Option<f64>,
}
188
/// Forward host options to the provider hook. `need_rcond` starts
/// false; `try_gpu_linsolve` overwrites it from the output count.
impl From<&SolveOptions> for ProviderLinsolveOptions {
    fn from(opts: &SolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            need_rcond: false,
            rcond: opts.rcond,
        }
    }
}
204
/// Reconstruct host options from provider options (the provider-only
/// `need_rcond` flag has no host counterpart and is dropped).
impl From<&ProviderLinsolveOptions> for SolveOptions {
    fn from(opts: &ProviderLinsolveOptions) -> Self {
        Self {
            lower: opts.lower,
            upper: opts.upper,
            rectangular: opts.rectangular,
            transposed: opts.transposed,
            conjugate: opts.conjugate,
            symmetric: opts.symmetric,
            posdef: opts.posdef,
            rcond: opts.rcond,
        }
    }
}
219
220fn options_from_rest(rest: &[Value]) -> BuiltinResult<SolveOptions> {
221 match rest.len() {
222 0 => Ok(SolveOptions::default()),
223 1 => parse_options(&rest[0]),
224 _ => Err(builtin_error("linsolve: too many input arguments")),
225 }
226}
227
/// Parse trailing arguments into `SolveOptions` and run the solve.
pub async fn evaluate_args(lhs: Value, rhs: Value, rest: &[Value]) -> BuiltinResult<LinsolveEval> {
    let options = options_from_rest(rest)?;
    evaluate(lhs, rhs, options).await
}
233
/// Attempt the provider's dedicated `linsolve` hook.
///
/// Returns `Ok(None)` whenever the GPU path does not apply — no
/// provider registered, complex inputs, scalar operands, or more than
/// two requested outputs — so the caller falls back to the host solver.
/// Uploaded operands are always released before returning.
async fn try_gpu_linsolve(
    lhs: &Value,
    rhs: &Value,
    options: &SolveOptions,
) -> BuiltinResult<Option<LinsolveEval>> {
    // More than two outputs is rejected by the builtin wrapper anyway;
    // skip the device round-trip for it.
    if matches!(crate::output_count::current_output_count(), Some(n) if n > 2) {
        return Ok(None);
    }
    let provider = match runmat_accelerate_api::provider() {
        Some(p) => p,
        None => return Ok(None),
    };

    // The provider path handles real data only.
    if contains_complex(lhs) || contains_complex(rhs) {
        return Ok(None);
    }

    let mut lhs_operand = match prepare_gpu_operand(lhs, provider)? {
        Some(op) => op,
        None => return Ok(None),
    };
    let mut rhs_operand = match prepare_gpu_operand(rhs, provider)? {
        Some(op) => op,
        None => {
            // Bail out without leaking the already-uploaded lhs buffer.
            release_operand(provider, &mut lhs_operand);
            return Ok(None);
        }
    };

    // Scalar operands take the host path.
    if is_scalar_handle(lhs_operand.handle()) || is_scalar_handle(rhs_operand.handle()) {
        release_operand(provider, &mut lhs_operand);
        release_operand(provider, &mut rhs_operand);
        return Ok(None);
    }

    let mut provider_opts: ProviderLinsolveOptions = options.into();
    // Only ask the provider for an rcond estimate when the caller
    // requested the second output or supplied an RCOND threshold.
    provider_opts.need_rcond =
        matches!(crate::output_count::current_output_count(), Some(2)) || options.rcond.is_some();
    // Provider failures are swallowed (`.ok()`): they trigger the host
    // fallback rather than surfacing as user errors.
    let result = provider
        .linsolve(lhs_operand.handle(), rhs_operand.handle(), &provider_opts)
        .await
        .ok();

    // Release uploads before inspecting the result so both exits free them.
    release_operand(provider, &mut lhs_operand);
    release_operand(provider, &mut rhs_operand);

    if let Some(ProviderLinsolveResult {
        solution,
        reciprocal_condition,
    }) = result
    {
        let eval = LinsolveEval::new(Value::GpuTensor(solution), Some(reciprocal_condition));
        return Ok(Some(eval));
    }

    Ok(None)
}
291
292fn parse_options(value: &Value) -> BuiltinResult<SolveOptions> {
293 let struct_val = match value {
294 Value::Struct(s) => s,
295 other => {
296 return Err(builtin_error(format!(
297 "linsolve: opts must be a struct, got {other:?}"
298 )))
299 }
300 };
301 let mut opts = SolveOptions::default();
302 for (key, raw_value) in &struct_val.fields {
303 let name = key.to_ascii_uppercase();
304 match name.as_str() {
305 "LT" => opts.lower = parse_bool_field("LT", raw_value)?,
306 "UT" => opts.upper = parse_bool_field("UT", raw_value)?,
307 "RECT" => opts.rectangular = parse_bool_field("RECT", raw_value)?,
308 "SYM" => opts.symmetric = parse_bool_field("SYM", raw_value)?,
309 "POSDEF" => opts.posdef = parse_bool_field("POSDEF", raw_value)?,
310 "TRANSA" => {
311 let transa = parse_transa(raw_value)?;
312 opts.transposed = transa != TransposeMode::None;
313 opts.conjugate = transa == TransposeMode::Conjugate;
314 }
315 "RCOND" => {
316 let threshold = parse_scalar_f64("RCOND", raw_value)?;
317 if threshold < 0.0 {
318 return Err(builtin_error("linsolve: RCOND must be non-negative"));
319 }
320 opts.rcond = Some(threshold);
321 }
322 other => return Err(builtin_error(format!("linsolve: unknown option '{other}'"))),
323 }
324 }
325 if opts.lower && opts.upper {
326 return Err(builtin_error("linsolve: LT and UT are mutually exclusive."));
327 }
328 Ok(opts)
329}
330
331fn parse_bool_field(name: &str, value: &Value) -> BuiltinResult<bool> {
332 match value {
333 Value::Bool(b) => Ok(*b),
334 Value::Int(i) => Ok(!i.is_zero()),
335 Value::Num(n) => Ok(*n != 0.0),
336 Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0] != 0.0),
337 Value::LogicalArray(arr) if arr.len() == 1 => Ok(arr.data[0] != 0),
338 other => Err(builtin_error(format!(
339 "linsolve: option '{name}' must be logical or numeric, got {other:?}"
340 ))),
341 }
342}
343
344fn parse_scalar_f64(name: &str, value: &Value) -> BuiltinResult<f64> {
345 match value {
346 Value::Num(n) => Ok(*n),
347 Value::Int(i) => Ok(i.to_f64()),
348 Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0]),
349 other => Err(builtin_error(format!(
350 "linsolve: option '{name}' must be a scalar numeric value, got {other:?}"
351 ))),
352 }
353}
354
/// Interpretation of the TRANSA option.
#[derive(Copy, Clone, PartialEq, Eq)]
enum TransposeMode {
    // 'N': solve with A as given.
    None,
    // 'T': solve with A transposed.
    Transpose,
    // 'C': solve with the conjugate transpose of A.
    Conjugate,
}
361
362fn parse_transa(value: &Value) -> BuiltinResult<TransposeMode> {
363 let text = tensor::value_to_string(value).ok_or_else(|| {
364 builtin_error("linsolve: TRANSA must be a character vector or string scalar")
365 })?;
366 if text.is_empty() {
367 return Err(builtin_error("linsolve: TRANSA cannot be empty"));
368 }
369 match text.trim().to_ascii_uppercase().as_str() {
370 "N" => Ok(TransposeMode::None),
371 "T" => Ok(TransposeMode::Transpose),
372 "C" => Ok(TransposeMode::Conjugate),
373 other => Err(builtin_error(format!(
374 "linsolve: TRANSA must be 'N', 'T', or 'C', got '{other}'"
375 ))),
376 }
377}
378
/// A single operand coerced to a host numeric tensor.
enum NumericInput {
    Real(Tensor),
    Complex(ComplexTensor),
}
383
/// Both operands promoted to a common numeric domain.
enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}
388
389async fn coerce_numeric_pair(lhs: Value, rhs: Value) -> BuiltinResult<NumericPair> {
390 let lhs_num = coerce_numeric(lhs).await?;
391 let rhs_num = coerce_numeric(rhs).await?;
392 match (lhs_num, rhs_num) {
393 (NumericInput::Real(lhs_r), NumericInput::Real(rhs_r)) => {
394 Ok(NumericPair::Real(lhs_r, rhs_r))
395 }
396 (NumericInput::Complex(lhs_c), NumericInput::Complex(rhs_c)) => {
397 Ok(NumericPair::Complex(lhs_c, rhs_c))
398 }
399 (NumericInput::Complex(lhs_c), NumericInput::Real(rhs_r)) => {
400 let rhs_c = promote_real_tensor(&rhs_r)?;
401 Ok(NumericPair::Complex(lhs_c, rhs_c))
402 }
403 (NumericInput::Real(lhs_r), NumericInput::Complex(rhs_c)) => {
404 let lhs_c = promote_real_tensor(&lhs_r)?;
405 Ok(NumericPair::Complex(lhs_c, rhs_c))
406 }
407 }
408}
409
/// Coerce a single value into a host numeric tensor, gathering GPU
/// tensors and widening scalars/logicals to 1x1 tensors. Inputs with
/// more than two effective dimensions are rejected.
async fn coerce_numeric(value: Value) -> BuiltinResult<NumericInput> {
    match value {
        Value::Tensor(tensor) => {
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::LogicalArray(logical) => {
            // Logical arrays become 0.0/1.0 real tensors.
            let tensor = tensor::logical_to_tensor(&logical).map_err(builtin_error)?;
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Num(n) => {
            let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Int(i) => {
            let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Bool(b) => {
            let tensor =
                Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Real(tensor))
        }
        Value::Complex(re, im) => {
            let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(builtin_error)?;
            Ok(NumericInput::Complex(tensor))
        }
        Value::ComplexTensor(ct) => {
            ensure_matrix_shape(NAME, &ct.shape)?;
            Ok(NumericInput::Complex(ct))
        }
        Value::GpuTensor(handle) => {
            // Device residents are gathered to the host for the CPU solver.
            let tensor = gpu_helpers::gather_tensor_async(&handle)
                .await
                .map_err(map_control_flow)?;
            ensure_matrix_shape(NAME, &tensor.shape)?;
            Ok(NumericInput::Real(tensor))
        }
        other => Err(builtin_error(format!(
            "{NAME}: unsupported input type {:?}; convert to numeric values first",
            other
        ))),
    }
}
455
/// True for complex scalars or complex tensors; the GPU fast path
/// bails out for these and uses the host solver.
fn contains_complex(value: &Value) -> bool {
    matches!(value, Value::Complex(_, _) | Value::ComplexTensor(_))
}
459
/// True when a device tensor's shape describes a scalar.
fn is_scalar_handle(handle: &GpuTensorHandle) -> bool {
    crate::builtins::common::shape::is_scalar_shape(&handle.shape)
}
463
/// A device operand plus a flag recording whether this module uploaded
/// it (and must therefore free it after the provider call).
struct PreparedOperand {
    handle: GpuTensorHandle,
    owned: bool,
}
468
469impl PreparedOperand {
470 fn borrowed(handle: &GpuTensorHandle) -> Self {
471 Self {
472 handle: handle.clone(),
473 owned: false,
474 }
475 }
476
477 fn owned(handle: GpuTensorHandle) -> Self {
478 Self {
479 handle,
480 owned: true,
481 }
482 }
483
484 fn handle(&self) -> &GpuTensorHandle {
485 &self.handle
486 }
487}
488
489fn prepare_gpu_operand(
490 value: &Value,
491 provider: &'static dyn AccelProvider,
492) -> BuiltinResult<Option<PreparedOperand>> {
493 match value {
494 Value::GpuTensor(handle) => {
495 if is_scalar_handle(handle) {
496 Ok(None)
497 } else {
498 Ok(Some(PreparedOperand::borrowed(handle)))
499 }
500 }
501 Value::Tensor(tensor) => {
502 if tensor::is_scalar_tensor(tensor) {
503 Ok(None)
504 } else {
505 let uploaded = upload_tensor(provider, tensor)?;
506 Ok(Some(PreparedOperand::owned(uploaded)))
507 }
508 }
509 Value::LogicalArray(logical) => {
510 if logical.data.len() == 1 {
511 Ok(None)
512 } else {
513 let tensor = tensor::logical_to_tensor(logical).map_err(builtin_error)?;
514 let uploaded = upload_tensor(provider, &tensor)?;
515 Ok(Some(PreparedOperand::owned(uploaded)))
516 }
517 }
518 _ => Ok(None),
519 }
520}
521
522fn upload_tensor(
523 provider: &'static dyn AccelProvider,
524 tensor: &Tensor,
525) -> BuiltinResult<GpuTensorHandle> {
526 let view = HostTensorView {
527 data: &tensor.data,
528 shape: &tensor.shape,
529 };
530 provider
531 .upload(&view)
532 .map_err(|e| builtin_error(format!("{NAME}: {e}")))
533}
534
/// Free an operand's device buffer if this module uploaded it;
/// borrowed handles are untouched. Clearing `owned` makes repeated
/// calls safe.
fn release_operand(provider: &'static dyn AccelProvider, operand: &mut PreparedOperand) {
    if operand.owned {
        // Best-effort free: a failure here must not mask the solve result.
        let _ = provider.free(&operand.handle);
        operand.owned = false;
    }
}
541
/// Host real solve dispatcher: applies TRANSA, normalizes the RHS,
/// then routes to triangular substitution (LT/UT) or the general SVD
/// path, enforcing any caller-supplied RCOND threshold.
fn solve_real(lhs: Tensor, rhs: Tensor, options: &SolveOptions) -> BuiltinResult<(Tensor, f64)> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;

    if options.transposed {
        // Materialize op(A) up front so the paths below see a plain matrix.
        lhs_effective = transpose_tensor(&lhs_effective);
        if options.conjugate {
            // No-op for real data; mirrors the complex path.
            conjugate_in_place(&mut lhs_effective);
        }
        // Transposition flips the triangular structure (lower <-> upper).
        if lower || upper {
            std::mem::swap(&mut lower, &mut upper);
        }
    }

    rhs_effective = normalize_rhs_tensor(rhs_effective, lhs_effective.rows())?;

    if lower {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = forward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    if upper {
        ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
        let (solution, rcond) = backward_substitution_real(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    // SYM/POSDEF/RECT hints all fall through to the general SVD solve here.
    let (solution, rcond) = solve_general_real(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}
578
/// Complex counterpart of `solve_real`: applies TRANSA (with true
/// conjugation for 'C'), normalizes the RHS, then routes to triangular
/// substitution or the general SVD path.
fn solve_complex(
    lhs: ComplexTensor,
    rhs: ComplexTensor,
    options: &SolveOptions,
) -> BuiltinResult<(ComplexTensor, f64)> {
    let mut lhs_effective = lhs;
    let mut rhs_effective = rhs;
    let mut lower = options.lower;
    let mut upper = options.upper;

    if options.transposed {
        lhs_effective = transpose_complex(&lhs_effective);
        if options.conjugate {
            conjugate_complex_in_place(&mut lhs_effective);
        }
        // Transposition flips the triangular structure (lower <-> upper).
        if lower || upper {
            std::mem::swap(&mut lower, &mut upper);
        }
    }

    rhs_effective = normalize_rhs_complex(rhs_effective, lhs_effective.rows)?;

    if lower {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = forward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    if upper {
        ensure_square(lhs_effective.rows, lhs_effective.cols)?;
        let (solution, rcond) = backward_substitution_complex(&lhs_effective, &rhs_effective)?;
        enforce_rcond(options, rcond)?;
        return Ok((solution, rcond));
    }

    let (solution, rcond) = solve_general_complex(&lhs_effective, &rhs_effective)?;
    enforce_rcond(options, rcond)?;
    Ok((solution, rcond))
}
619
620fn forward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
621 let n = lhs.rows();
622 let nrhs = rhs.data.len() / n;
623 let mut solution = rhs.data.clone();
624 let mut min_diag = f64::INFINITY;
625 let mut max_diag = 0.0_f64;
626
627 for col in 0..nrhs {
628 for i in 0..n {
629 let diag = lhs.data[i + i * n];
630 let diag_abs = diag.abs();
631 min_diag = min_diag.min(diag_abs);
632 max_diag = max_diag.max(diag_abs);
633 if diag_abs == 0.0 {
634 return Err(builtin_error(
635 "linsolve: matrix is singular to working precision.",
636 ));
637 }
638 let mut accum = 0.0;
639 for j in 0..i {
640 accum += lhs.data[i + j * n] * solution[j + col * n];
641 }
642 let rhs_value = solution[i + col * n] - accum;
643 solution[i + col * n] = rhs_value / diag;
644 }
645 }
646
647 let rcond = diagonal_rcond(min_diag, max_diag);
648 let tensor = Tensor::new(solution, rhs.shape.clone())
649 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
650 Ok((tensor, rcond))
651}
652
653fn backward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
654 let n = lhs.rows();
655 let nrhs = rhs.data.len() / n;
656 let mut solution = rhs.data.clone();
657 let mut min_diag = f64::INFINITY;
658 let mut max_diag = 0.0_f64;
659
660 for col in 0..nrhs {
661 for row_rev in 0..n {
662 let i = n - 1 - row_rev;
663 let diag = lhs.data[i + i * n];
664 let diag_abs = diag.abs();
665 min_diag = min_diag.min(diag_abs);
666 max_diag = max_diag.max(diag_abs);
667 if diag_abs == 0.0 {
668 return Err(builtin_error(
669 "linsolve: matrix is singular to working precision.",
670 ));
671 }
672 let mut accum = 0.0;
673 for j in (i + 1)..n {
674 accum += lhs.data[i + j * n] * solution[j + col * n];
675 }
676 let rhs_value = solution[i + col * n] - accum;
677 solution[i + col * n] = rhs_value / diag;
678 }
679 }
680
681 let rcond = diagonal_rcond(min_diag, max_diag);
682 let tensor = Tensor::new(solution, rhs.shape.clone())
683 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
684 Ok((tensor, rcond))
685}
686
687fn forward_substitution_complex(
688 lhs: &ComplexTensor,
689 rhs: &ComplexTensor,
690) -> BuiltinResult<(ComplexTensor, f64)> {
691 let n = lhs.rows;
692 let nrhs = rhs.data.len() / n;
693 let lhs_data: Vec<Complex64> = lhs
694 .data
695 .iter()
696 .map(|&(re, im)| Complex64::new(re, im))
697 .collect();
698 let mut solution: Vec<Complex64> = rhs
699 .data
700 .iter()
701 .map(|&(re, im)| Complex64::new(re, im))
702 .collect();
703 let mut min_diag = f64::INFINITY;
704 let mut max_diag = 0.0_f64;
705
706 for col in 0..nrhs {
707 for i in 0..n {
708 let diag = lhs_data[i + i * n];
709 let diag_abs = diag.norm();
710 min_diag = min_diag.min(diag_abs);
711 max_diag = max_diag.max(diag_abs);
712 if diag_abs == 0.0 {
713 return Err(builtin_error(
714 "linsolve: matrix is singular to working precision.",
715 ));
716 }
717 let mut accum = Complex64::new(0.0, 0.0);
718 for j in 0..i {
719 accum += lhs_data[i + j * n] * solution[j + col * n];
720 }
721 let rhs_value = solution[i + col * n] - accum;
722 solution[i + col * n] = rhs_value / diag;
723 }
724 }
725
726 let rcond = diagonal_rcond(min_diag, max_diag);
727 let tensor = ComplexTensor::new(
728 solution.iter().map(|c| (c.re, c.im)).collect(),
729 rhs.shape.clone(),
730 )
731 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
732 Ok((tensor, rcond))
733}
734
735fn backward_substitution_complex(
736 lhs: &ComplexTensor,
737 rhs: &ComplexTensor,
738) -> BuiltinResult<(ComplexTensor, f64)> {
739 let n = lhs.rows;
740 let nrhs = rhs.data.len() / n;
741 let lhs_data: Vec<Complex64> = lhs
742 .data
743 .iter()
744 .map(|&(re, im)| Complex64::new(re, im))
745 .collect();
746 let mut solution: Vec<Complex64> = rhs
747 .data
748 .iter()
749 .map(|&(re, im)| Complex64::new(re, im))
750 .collect();
751 let mut min_diag = f64::INFINITY;
752 let mut max_diag = 0.0_f64;
753
754 for col in 0..nrhs {
755 for row_rev in 0..n {
756 let i = n - 1 - row_rev;
757 let diag = lhs_data[i + i * n];
758 let diag_abs = diag.norm();
759 min_diag = min_diag.min(diag_abs);
760 max_diag = max_diag.max(diag_abs);
761 if diag_abs == 0.0 {
762 return Err(builtin_error(
763 "linsolve: matrix is singular to working precision.",
764 ));
765 }
766 let mut accum = Complex64::new(0.0, 0.0);
767 for j in (i + 1)..n {
768 accum += lhs_data[i + j * n] * solution[j + col * n];
769 }
770 let rhs_value = solution[i + col * n] - accum;
771 solution[i + col * n] = rhs_value / diag;
772 }
773 }
774
775 let rcond = diagonal_rcond(min_diag, max_diag);
776 let tensor = ComplexTensor::new(
777 solution.iter().map(|c| (c.re, c.im)).collect(),
778 rhs.shape.clone(),
779 )
780 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
781 Ok((tensor, rcond))
782}
783
784fn solve_general_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
785 let a = DMatrix::from_column_slice(lhs.rows(), lhs.cols(), &lhs.data);
786 let b = DMatrix::from_column_slice(rhs.rows(), rhs.cols(), &rhs.data);
787 let svd = SVD::new(a.clone(), true, true);
788 let rcond = singular_value_rcond(svd.singular_values.as_slice());
789 let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows(), lhs.cols());
790 let solution = svd
791 .solve(&b, tol)
792 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
793 let tensor = matrix_real_to_tensor(solution)?;
794 Ok((tensor, rcond))
795}
796
797fn solve_general_complex(
798 lhs: &ComplexTensor,
799 rhs: &ComplexTensor,
800) -> BuiltinResult<(ComplexTensor, f64)> {
801 let a_data: Vec<Complex64> = lhs
802 .data
803 .iter()
804 .map(|&(re, im)| Complex64::new(re, im))
805 .collect();
806 let b_data: Vec<Complex64> = rhs
807 .data
808 .iter()
809 .map(|&(re, im)| Complex64::new(re, im))
810 .collect();
811 let a = DMatrix::from_column_slice(lhs.rows, lhs.cols, &a_data);
812 let b = DMatrix::from_column_slice(rhs.rows, rhs.cols, &b_data);
813 let svd = SVD::new(a.clone(), true, true);
814 let rcond = singular_value_rcond(svd.singular_values.as_slice());
815 let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows, lhs.cols);
816 let solution = svd
817 .solve(&b, tol)
818 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
819 let tensor = matrix_complex_to_tensor(solution)?;
820 Ok((tensor, rcond))
821}
822
823fn normalize_rhs_tensor(rhs: Tensor, expected_rows: usize) -> BuiltinResult<Tensor> {
824 if rhs.rows() == expected_rows {
825 return Ok(rhs);
826 }
827 if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
828 return Tensor::new(rhs.data, vec![expected_rows, 1])
829 .map_err(|e| builtin_error(format!("{NAME}: {e}")));
830 }
831 if rhs.data.is_empty() && expected_rows == 0 {
832 return Ok(rhs);
833 }
834 Err(builtin_error("Matrix dimensions must agree."))
835}
836
837fn normalize_rhs_complex(rhs: ComplexTensor, expected_rows: usize) -> BuiltinResult<ComplexTensor> {
838 if rhs.rows == expected_rows {
839 return Ok(rhs);
840 }
841 if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
842 return ComplexTensor::new(rhs.data, vec![expected_rows, 1])
843 .map_err(|e| builtin_error(format!("{NAME}: {e}")));
844 }
845 if rhs.data.is_empty() && expected_rows == 0 {
846 return Ok(rhs);
847 }
848 Err(builtin_error("Matrix dimensions must agree."))
849}
850
851fn enforce_rcond(options: &SolveOptions, rcond: f64) -> BuiltinResult<()> {
852 if let Some(threshold) = options.rcond {
853 if rcond < threshold {
854 return Err(builtin_error(
855 "linsolve: matrix is singular to working precision.",
856 ));
857 }
858 }
859 Ok(())
860}
861
/// SVD solve tolerance: eps * max(rows, cols) * max(sigma_max, 1).
fn compute_svd_tolerance(singular_values: &[f64], rows: usize, cols: usize) -> f64 {
    let mut max_sv = 0.0_f64;
    for &sv in singular_values {
        max_sv = max_sv.max(sv.abs());
    }
    let max_dim = rows.max(cols) as f64;
    f64::EPSILON * max_dim * max_sv.max(1.0)
}
870
/// Convert an nalgebra column-major matrix into a runtime `Tensor`.
fn matrix_real_to_tensor(matrix: DMatrix<f64>) -> BuiltinResult<Tensor> {
    let rows = matrix.nrows();
    let cols = matrix.ncols();
    Tensor::new(matrix.as_slice().to_vec(), vec![rows, cols])
        .map_err(|e| builtin_error(format!("{NAME}: {e}")))
}
877
/// Convert an nalgebra complex matrix into a runtime `ComplexTensor`
/// of (re, im) pairs.
fn matrix_complex_to_tensor(matrix: DMatrix<Complex64>) -> BuiltinResult<ComplexTensor> {
    let rows = matrix.nrows();
    let cols = matrix.ncols();
    let data: Vec<(f64, f64)> = matrix.as_slice().iter().map(|c| (c.re, c.im)).collect();
    ComplexTensor::new(data, vec![rows, cols]).map_err(|e| builtin_error(format!("{NAME}: {e}")))
}
884
885fn promote_real_tensor(tensor: &Tensor) -> BuiltinResult<ComplexTensor> {
886 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
887 ComplexTensor::new(data, tensor.shape.clone())
888 .map_err(|e| builtin_error(format!("{NAME}: {e}")))
889}
890
891fn ensure_matrix_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
892 if is_effectively_matrix(shape) {
893 Ok(())
894 } else {
895 Err(builtin_error(format!(
896 "{name}: inputs must be 2-D matrices or vectors"
897 )))
898 }
899}
900
/// A shape is matrix-like when it has at most two dimensions or every
/// dimension beyond the second is a singleton.
fn is_effectively_matrix(shape: &[usize]) -> bool {
    shape.len() <= 2 || shape[2..].iter().all(|&dim| dim == 1)
}
907
908fn ensure_square(rows: usize, cols: usize) -> BuiltinResult<()> {
909 if rows == cols {
910 Ok(())
911 } else {
912 Err(builtin_error(
913 "linsolve: triangular solves require a square coefficient matrix.",
914 ))
915 }
916}
917
918fn transpose_tensor(tensor: &Tensor) -> Tensor {
919 let rows = tensor.rows();
920 let cols = tensor.cols();
921 let mut data = vec![0.0; tensor.data.len()];
922 for r in 0..rows {
923 for c in 0..cols {
924 data[c + r * cols] = tensor.data[r + c * rows];
925 }
926 }
927 Tensor::new(data, vec![cols, rows]).expect("transpose_tensor valid")
928}
929
930fn transpose_complex(tensor: &ComplexTensor) -> ComplexTensor {
931 let rows = tensor.rows;
932 let cols = tensor.cols;
933 let mut data = vec![(0.0, 0.0); tensor.data.len()];
934 for r in 0..rows {
935 for c in 0..cols {
936 data[c + r * cols] = tensor.data[r + c * rows];
937 }
938 }
939 ComplexTensor::new(data, vec![cols, rows]).expect("transpose_complex valid")
940}
941
/// Conjugating a real tensor is the identity; this deliberate no-op
/// keeps the real and complex TRANSA='C' paths structurally parallel.
fn conjugate_in_place(_tensor: &mut Tensor) {
}
945
946fn conjugate_complex_in_place(tensor: &mut ComplexTensor) {
947 for value in &mut tensor.data {
948 value.1 = -value.1;
949 }
950}
951
952#[cfg(test)]
953pub(crate) mod tests {
954 use super::*;
955 use futures::executor::block_on;
956 use runmat_accelerate_api::HostTensorView;
957 use runmat_builtins::{CharArray, ResolveContext, StructValue, Type};
958 fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
959 err
960 }
961
962 fn approx_eq(actual: f64, expected: f64) {
963 assert!((actual - expected).abs() < 1e-7);
964 }
965
966 fn evaluate_args(a: Value, b: Value, rest: &[Value]) -> Result<LinsolveEval, RuntimeError> {
967 block_on(super::evaluate_args(a, b, rest))
968 }
969
970 #[test]
971 fn linsolve_type_uses_rhs_columns() {
972 let out = left_divide_type(
973 &[
974 Type::Tensor {
975 shape: Some(vec![Some(2), Some(2)]),
976 },
977 Type::Tensor {
978 shape: Some(vec![Some(2), Some(3)]),
979 },
980 ],
981 &ResolveContext::new(Vec::new()),
982 );
983 assert_eq!(
984 out,
985 Type::Tensor {
986 shape: Some(vec![Some(2), Some(3)])
987 }
988 );
989 }
990
991 use crate::builtins::common::test_support;
992 use runmat_accelerate_api::ProviderTelemetry;
993
994 fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
995 block_on(super::linsolve_builtin(lhs, rhs, rest))
996 }
997
998 fn evaluate(lhs: Value, rhs: Value, options: SolveOptions) -> BuiltinResult<LinsolveEval> {
999 block_on(super::evaluate(lhs, rhs, options))
1000 }
1001
1002 fn fallback_count(telemetry: &ProviderTelemetry, reason: &str) -> u64 {
1003 telemetry
1004 .solve_fallbacks
1005 .iter()
1006 .find(|entry| entry.reason == reason)
1007 .map(|entry| entry.count)
1008 .unwrap_or(0)
1009 }
1010
1011 #[cfg(feature = "wgpu")]
1012 fn kernel_launch_count(telemetry: &ProviderTelemetry, kernel: &str) -> usize {
1013 telemetry
1014 .kernel_launches
1015 .iter()
1016 .filter(|entry| entry.kernel == kernel)
1017 .count()
1018 }
1019
1020 fn clear_accel_provider_state() {
1021 runmat_accelerate_api::set_thread_provider(None);
1022 runmat_accelerate_api::clear_provider();
1023 }
1024
1025 fn host_linsolve_real(
1026 a: &Tensor,
1027 b: &Tensor,
1028 options: ProviderLinsolveOptions,
1029 ) -> (Tensor, f64) {
1030 super::linsolve_host_real_for_provider(a, b, &options).expect("host linsolve")
1031 }
1032
1033 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1034 #[test]
1035 fn linsolve_basic_square() {
1036 let _accel_guard = test_support::accel_test_lock();
1037 clear_accel_provider_state();
1038 let a = Tensor::new(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
1039 let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
1040 let result =
1041 linsolve_builtin(Value::Tensor(a), Value::Tensor(b), Vec::new()).expect("linsolve");
1042 let t = test_support::gather(result).expect("gather");
1043 assert_eq!(t.shape, vec![2, 1]);
1044 approx_eq(t.data[0], 1.0);
1045 approx_eq(t.data[1], 2.0);
1046 }
1047
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_lower_triangular_hint() {
    // Serialize provider access and clear any registered provider so this test
    // deterministically exercises the host path, matching the sibling tests
    // (e.g. `linsolve_basic_square`) which already take the same guard. Without
    // this, a provider registered by another test could leak into this one.
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    // Column-major 3x3 lower-triangular matrix.
    let a = Tensor::new(
        vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
    // Pass the LT hint so the solver may use a forward-substitution path.
    let mut opts = StructValue::new();
    opts.fields.insert("LT".to_string(), Value::Bool(true));
    let result = linsolve_builtin(
        Value::Tensor(a),
        Value::Tensor(b),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let tensor = test_support::gather(result).expect("gather");
    assert_eq!(tensor.shape, vec![3, 1]);
    // Forward substitution yields x = [3; 2; 1].
    approx_eq(tensor.data[0], 3.0);
    approx_eq(tensor.data[1], 2.0);
    approx_eq(tensor.data[2], 1.0);
}
1071
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_transposed_triangular_hint() {
    // Serialize provider access and reset to the host-only configuration.
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    let lower = Tensor::new(
        vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let rhs = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();

    // Ask for a lower-triangular solve of A.' \ b via the TRANSA hint.
    let mut options = StructValue::new();
    options.fields.insert("LT".to_string(), Value::Bool(true));
    options.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let solved = linsolve_builtin(
        Value::Tensor(lower.clone()),
        Value::Tensor(rhs.clone()),
        vec![Value::Struct(options)],
    )
    .expect("linsolve");
    let gathered = test_support::gather(solved).expect("gather");
    assert_eq!(gathered.shape, vec![3, 1]);

    // Reference: solve against the explicitly transposed matrix on the host.
    let transposed = transpose_tensor(&lower);
    let (expected, _) = host_linsolve_real(&transposed, &rhs, ProviderLinsolveOptions::default());
    for (actual, want) in gathered.data.iter().zip(expected.data.iter()) {
        approx_eq(*actual, *want);
    }
}
1107
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_complex_inputs_match_residual() {
    let a = ComplexTensor::new(
        vec![(2.0, 1.0), (-1.0, 0.0), (1.0, -2.0), (3.0, -2.0)],
        vec![2, 2],
    )
    .unwrap();
    let b = ComplexTensor::new(vec![(1.0, 0.0), (4.0, 1.0)], vec![2, 1]).unwrap();
    let result = linsolve_builtin(
        Value::ComplexTensor(a.clone()),
        Value::ComplexTensor(b.clone()),
        Vec::new(),
    )
    .expect("linsolve");
    let Value::ComplexTensor(out) = result else {
        panic!("expected complex tensor result");
    };

    // Convert a ComplexTensor into a column-major nalgebra matrix.
    let to_matrix = |t: &ComplexTensor| {
        let entries: Vec<Complex64> = t
            .data
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        DMatrix::from_column_slice(t.rows, t.cols, &entries)
    };

    // Instead of pinning exact solution values, verify A*x ≈ b.
    let residual = to_matrix(&a) * to_matrix(&out) - to_matrix(&b);
    assert!(residual.norm() < 1e-10, "residual={}", residual.norm());
}
1148
// The TRANSA='C' hint must produce the same result as explicitly forming the
// conjugate transpose of A and solving the plain system.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_complex_conjugate_transpose_matches_explicit_reference() {
    let a = ComplexTensor::new(
        vec![(2.0, 1.0), (0.0, -1.0), (1.0, 2.0), (3.0, 0.5)],
        vec![2, 2],
    )
    .unwrap();
    let b = ComplexTensor::new(vec![(1.0, -1.0), (2.0, 0.5)], vec![2, 1]).unwrap();

    // Request the conjugate-transpose solve via the options struct.
    let mut opts = StructValue::new();
    opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("C")),
    );
    let result = linsolve_builtin(
        Value::ComplexTensor(a.clone()),
        Value::ComplexTensor(b.clone()),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let Value::ComplexTensor(out) = result else {
        panic!("expected complex tensor result");
    };

    // Build the explicit reference: conj(A.') \ b with default options.
    let mut a_conj_t = transpose_complex(&a);
    conjugate_complex_in_place(&mut a_conj_t);
    let reference = evaluate(
        Value::ComplexTensor(a_conj_t),
        Value::ComplexTensor(b.clone()),
        SolveOptions::default(),
    )
    .expect("reference");
    let Value::ComplexTensor(expected) = reference.solution() else {
        panic!("expected complex tensor reference");
    };

    // Compare real and imaginary parts element-wise to tight tolerance.
    assert_eq!(out.shape, expected.shape);
    for ((out_re, out_im), (exp_re, exp_im)) in out.data.iter().zip(expected.data.iter()) {
        assert!(
            (out_re - exp_re).abs() < 1e-10,
            "out_re={out_re} exp_re={exp_re}"
        );
        assert!(
            (out_im - exp_im).abs() < 1e-10,
            "out_im={out_im} exp_im={exp_im}"
        );
    }
}
1198
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_rcond_enforced() {
    // Serialize provider access and run on the host path.
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    // Near-singular system: the two columns differ by only 1e-12, so its
    // reciprocal condition is far below the 1e-3 threshold requested below.
    let nearly_singular = Tensor::new(vec![1.0, 1.0, 1.0, 1.0 + 1e-12], vec![2, 2]).unwrap();
    let rhs = Tensor::new(vec![2.0, 2.0 + 1e-12], vec![2, 1]).unwrap();
    let mut options = StructValue::new();
    options.fields.insert("RCOND".to_string(), Value::Num(1e-3));
    let failure = linsolve_builtin(
        Value::Tensor(nearly_singular),
        Value::Tensor(rhs),
        vec![Value::Struct(options)],
    )
    .expect_err("singular matrix must fail");
    let err = unwrap_error(failure);
    assert!(
        err.message().contains("singular to working precision"),
        "unexpected error message: {err}"
    );
}
1221
// The evaluation API must expose both outputs: the solution and the
// reciprocal-condition estimate. For the identity matrix rcond is exactly 1.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_recovers_rcond_output() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();
    // Identity system: x == b and rcond == 1.0.
    let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
    let eval = evaluate_args(Value::Tensor(a.clone()), Value::Tensor(b.clone()), &[])
        .expect("evaluate");
    // The solution may come back host-resident or as a GPU handle; gather when
    // needed so both residencies are accepted.
    let solution_tensor = match eval.solution() {
        Value::Tensor(sol) => sol.clone(),
        Value::GpuTensor(handle) => {
            test_support::gather(Value::GpuTensor(handle.clone())).expect("gather solution")
        }
        other => panic!("unexpected solution value {other:?}"),
    };
    assert_eq!(solution_tensor.shape, vec![2, 1]);
    approx_eq(solution_tensor.data[0], 1.0);
    approx_eq(solution_tensor.data[1], 2.0);

    // Same residency tolerance for the scalar rcond output.
    let rcond_value = match eval.reciprocal_condition() {
        Value::Num(r) => r,
        Value::GpuTensor(handle) => {
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather rcond");
            gathered.data[0]
        }
        other => panic!("unexpected rcond value {other:?}"),
    };
    approx_eq(rcond_value, 1.0);
}
1253
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn gpu_round_trip_matches_cpu() {
    test_support::with_test_provider(|provider| {
        let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

        // Host reference solution for the same system.
        let host_result = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let host_tensor = test_support::gather(host_result).expect("cpu gather");

        // Repeat the solve with both operands resident on the device.
        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let handle_a = provider.upload(&view_a).expect("upload A");
        let handle_b = provider.upload(&view_b).expect("upload B");

        let device_result = linsolve_builtin(
            Value::GpuTensor(handle_a.clone()),
            Value::GpuTensor(handle_b.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");
        let device_tensor = test_support::gather(device_result).expect("gather");
        let _ = provider.free(&handle_a);
        let _ = provider.free(&handle_b);

        // Both paths must agree element-wise to double precision.
        assert_eq!(device_tensor.shape, host_tensor.shape);
        for (device, host) in device_tensor.data.iter().zip(host_tensor.data.iter()) {
            assert!((device - host).abs() < 1e-12);
        }
    });
}
1296
#[test]
fn host_inputs_auto_promote_into_provider_solve_path() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();
        let lhs = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let rhs = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let _ = linsolve_builtin(Value::Tensor(lhs), Value::Tensor(rhs), Vec::new())
            .expect("host linsolve");
        // Host tensors should be routed through the provider's linsolve hook
        // exactly once, with a host re-upload fallback recorded and traffic in
        // both directions.
        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);
        assert!(telemetry.upload_bytes > 0);
        assert!(telemetry.download_bytes > 0);
    });
}
1312
// GPU-resident inputs on the test provider still take the host solve with a
// re-upload; telemetry must record exactly one solve plus the fallback reason.
#[test]
fn provider_telemetry_records_gpu_host_reupload_path() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();
        let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let _ = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");

        // One solve dispatch, data moved both ways, and the host_reupload
        // fallback counted exactly once.
        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 1);
        assert!(telemetry.upload_bytes > 0);
        assert!(telemetry.download_bytes > 0);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 1);

        let _ = provider.free(&ha);
        let _ = provider.free(&hb);
    });
}
1349
// 1x1 systems bypass the provider solve hook entirely: the result must still be
// correct, but no linsolve dispatch or host_reupload fallback may be recorded.
#[test]
fn scalar_gpu_inputs_fall_back_without_provider_solve_dispatch() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();
        let a = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let b = Tensor::new(vec![6.0], vec![1, 1]).unwrap();
        let ha = provider
            .upload(&HostTensorView {
                data: &a.data,
                shape: &a.shape,
            })
            .expect("upload A");
        let hb = provider
            .upload(&HostTensorView {
                data: &b.data,
                shape: &b.shape,
            })
            .expect("upload B");

        let result = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("fallback linsolve");
        let gathered = test_support::gather(result).expect("gather fallback");
        // 2 x = 6  =>  x = 3.
        assert_eq!(gathered.data, vec![3.0]);

        // The scalar path must not have touched the provider solve hook; only
        // the download of the inputs should be visible in telemetry.
        let telemetry = provider.telemetry_snapshot();
        assert_eq!(telemetry.linsolve.count, 0);
        assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
        assert!(telemetry.download_bytes > 0);

        let _ = provider.free(&ha);
        let _ = provider.free(&hb);
    });
}
1387
// A square real solve on the wgpu provider should use the QR-backed device
// kernel (no host gather/re-upload fallback, no Cholesky kernel).
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // The device path under test only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    // CPU reference computed before resetting telemetry so only the GPU solve
    // shows up in the snapshot below.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Request a single output (solution only) for this call.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // F32 device math vs F64 host math: compare with a loose tolerance.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4);
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    // Square solves should route through the QR kernel, not the Cholesky one.
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
1451
// Same square-solve scenario as above, but without pushing an explicit output
// count: the device QR path must still be used and no fallback recorded.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_uses_device_path_without_output_count() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    // CPU reference before the telemetry reset.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // No output-count guard here: the result is gathered directly.
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let gathered = test_support::gather(gpu_value).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerance: F32 device result vs F64 host reference.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
1509
// With nargout = 2 the device path must return both the solution and a scalar
// reciprocal-condition estimate without falling back to a host re-upload.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_recovers_rcond_output_on_device() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    // Keep both host reference outputs so the device solution can be checked
    // element-wise (previously the reference solution was discarded and only
    // the shape of the gathered device result was asserted).
    let (cpu_solution, cpu_rcond) = host_linsolve_real(&a, &b, ProviderLinsolveOptions::default());
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Request two outputs: [x, rcond].
    let _output_guard = crate::output_count::push_output_count(Some(2));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let outputs = match gpu_value {
        Value::OutputList(outputs) => outputs,
        other => panic!("expected output list, got {other:?}"),
    };
    assert_eq!(outputs.len(), 2);
    let gathered = test_support::gather(outputs[0].clone()).expect("gather");
    let gpu_rcond = match &outputs[1] {
        Value::Num(value) => *value,
        other => panic!("unexpected gpu rcond {other:?}"),
    };
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    // The device solution must match the host reference value-for-value
    // (loose tolerance: F32 device math vs F64 host math).
    assert_eq!(gathered.shape, vec![2, 1]);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_solution.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    assert!(
        (gpu_rcond - cpu_rcond).abs() < 1e-4,
        "gpu={gpu_rcond} cpu={cpu_rcond}"
    );

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
1572
// Supplying an RCOND threshold must not force the solve off the device: the QR
// kernel should still run and no host re-upload fallback be recorded.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_with_rcond_option_stays_on_device() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }

    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
    // CPU reference with the same RCOND option, computed before the reset.
    let mut cpu_opts = StructValue::new();
    cpu_opts
        .fields
        .insert("RCOND".to_string(), Value::Num(0.05));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Single-output call with the same RCOND option on the GPU side.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts
        .fields
        .insert("RCOND".to_string(), Value::Num(0.05));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerance: F32 device result vs F64 host reference.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
1643
// Rectangular (tall, 3x2) least-squares solve: the wgpu provider should handle
// it on-device without the host gather/re-upload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_tall_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Overdetermined system: 3 equations, 2 unknowns.
    let a = Tensor::new(vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();

    // CPU reference before the telemetry reset.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Single-output call (solution only).
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu tall linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerance: F32 device result vs F64 host reference.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
1705
// POSDEF hint on a symmetric positive-definite system: the dedicated Cholesky
// kernel should run on-device (QR kernel must NOT), both outputs returned, and
// no host re-upload fallback recorded.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_posdef_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Symmetric positive-definite 2x2 system.
    let a = Tensor::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    // CPU reference solution with the same POSDEF hint.
    let mut cpu_opts = StructValue::new();
    cpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Reference rcond from the provider-facing host path with posdef enabled.
    let (_, cpu_rcond) = host_linsolve_real(
        &a,
        &b,
        ProviderLinsolveOptions {
            posdef: true,
            ..Default::default()
        },
    );
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Two outputs requested: [x, rcond].
    let _output_guard = crate::output_count::push_output_count(Some(2));
    let mut gpu_opts = StructValue::new();
    gpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu posdef linsolve");
    let mut outputs = match gpu_value {
        Value::OutputList(outputs) => outputs,
        other => panic!("expected output list, got {other:?}"),
    };
    // Pull rcond (index 1) first so the solution remains at index 0.
    let gpu_rcond = match outputs.remove(1) {
        Value::Num(value) => value,
        other => panic!("unexpected rcond value {other:?}"),
    };
    let gpu_solution = outputs.remove(0);
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerances: F32 device math vs F64 host math.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    assert!(
        (gpu_rcond - cpu_rcond).abs() < 1e-4,
        "gpu={gpu_rcond} cpu={cpu_rcond}"
    );

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    // The POSDEF hint must route to Cholesky, not the QR kernel.
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
}
1794
// POSDEF combined with TRANSA='T': since A is symmetric positive-definite, the
// transposed solve should still take the on-device Cholesky path.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_posdef_linsolve_uses_cholesky_path() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Symmetric positive-definite 2x2 system.
    let a = Tensor::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![8.0, 9.0], vec![2, 1]).unwrap();

    // CPU reference with the same POSDEF + TRANSA options.
    let mut cpu_opts = StructValue::new();
    cpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Single-output call with identical options on the GPU side.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu transposed posdef linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerance: F32 device result vs F64 host reference.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    // The Cholesky kernel must run; the QR kernel must not.
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
}
1874
// SYM hint on a symmetric system: the wgpu provider should solve on-device
// without the host gather/re-upload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_symmetric_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    // Symmetric 2x2 system.
    let a = Tensor::new(vec![5.0, 2.0, 2.0, 6.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![9.0, 8.0], vec![2, 1]).unwrap();

    // CPU reference with the same SYM hint, before the telemetry reset.
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Single-output call with the same SYM hint on the GPU side.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert("SYM".to_string(), Value::Bool(true));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu symmetric linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerance: F32 device result vs F64 host reference.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
1940
// TRANSA='T' on a square real system: the transposed solve should also stay
// on-device without the host gather/re-upload fallback.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_square_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    // Device path only exists for F32 providers; skip otherwise.
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();

    // CPU reference with the same TRANSA option, before the telemetry reset.
    let mut cpu_opts = StructValue::new();
    cpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        vec![Value::Struct(cpu_opts)],
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    // Single-output call with the same TRANSA option on the GPU side.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut gpu_opts = StructValue::new();
    gpu_opts.fields.insert(
        "TRANSA".to_string(),
        Value::CharArray(CharArray::new_row("T")),
    );
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(gpu_opts)],
    )
    .expect("gpu transposed square linsolve");
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    // Loose tolerance: F32 device result vs F64 host reference.
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
2012
2013 #[cfg(feature = "wgpu")]
2014 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2015 #[test]
2016 fn wgpu_conjugate_square_linsolve_avoids_host_reupload_fallback_for_real_inputs() {
2017 let _accel_guard = test_support::accel_test_lock();
2018 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2019 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2020 );
2021 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2022 if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2023 return;
2024 }
2025
2026 let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2027 let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();
2028 let mut cpu_opts = StructValue::new();
2029 cpu_opts.fields.insert(
2030 "TRANSA".to_string(),
2031 Value::CharArray(CharArray::new_row("C")),
2032 );
2033 let cpu = linsolve_builtin(
2034 Value::Tensor(a.clone()),
2035 Value::Tensor(b.clone()),
2036 vec![Value::Struct(cpu_opts)],
2037 )
2038 .expect("cpu linsolve");
2039 let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2040 provider.reset_telemetry();
2041
2042 let ha = provider
2043 .upload(&HostTensorView {
2044 data: &a.data,
2045 shape: &a.shape,
2046 })
2047 .expect("upload A");
2048 let hb = provider
2049 .upload(&HostTensorView {
2050 data: &b.data,
2051 shape: &b.shape,
2052 })
2053 .expect("upload B");
2054
2055 let _output_guard = crate::output_count::push_output_count(Some(1));
2056 let mut gpu_opts = StructValue::new();
2057 gpu_opts.fields.insert(
2058 "TRANSA".to_string(),
2059 Value::CharArray(CharArray::new_row("C")),
2060 );
2061 let gpu_value = linsolve_builtin(
2062 Value::GpuTensor(ha.clone()),
2063 Value::GpuTensor(hb.clone()),
2064 vec![Value::Struct(gpu_opts)],
2065 )
2066 .expect("gpu conjugate square linsolve");
2067 let gpu_solution = match gpu_value {
2068 Value::OutputList(mut outputs) => outputs.remove(0),
2069 other => other,
2070 };
2071 let gathered = test_support::gather(gpu_solution).expect("gather");
2072 let _ = provider.free(&ha);
2073 let _ = provider.free(&hb);
2074
2075 assert_eq!(gathered.shape, cpu_tensor.shape);
2076 for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2077 assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2078 }
2079
2080 let telemetry = provider.telemetry_snapshot();
2081 assert_eq!(telemetry.linsolve.count, 1);
2082 assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2083 assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
2084 }
2085
2086 #[cfg(feature = "wgpu")]
2087 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2088 #[test]
2089 fn wgpu_transposed_rectangular_linsolve_avoids_host_reupload_fallback() {
2090 let _accel_guard = test_support::accel_test_lock();
2091 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2092 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2093 );
2094 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2095 if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
2096 return;
2097 }
2098 let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![2, 3]).unwrap();
2099 let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();
2100
2101 let mut cpu_opts = StructValue::new();
2102 cpu_opts.fields.insert(
2103 "TRANSA".to_string(),
2104 Value::CharArray(CharArray::new_row("T")),
2105 );
2106 cpu_opts
2107 .fields
2108 .insert("RECT".to_string(), Value::Bool(true));
2109 let cpu = linsolve_builtin(
2110 Value::Tensor(a.clone()),
2111 Value::Tensor(b.clone()),
2112 vec![Value::Struct(cpu_opts)],
2113 )
2114 .expect("cpu linsolve");
2115 let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2116 provider.reset_telemetry();
2117
2118 let ha = provider
2119 .upload(&HostTensorView {
2120 data: &a.data,
2121 shape: &a.shape,
2122 })
2123 .expect("upload A");
2124 let hb = provider
2125 .upload(&HostTensorView {
2126 data: &b.data,
2127 shape: &b.shape,
2128 })
2129 .expect("upload B");
2130
2131 let _output_guard = crate::output_count::push_output_count(Some(1));
2132 let mut gpu_opts = StructValue::new();
2133 gpu_opts.fields.insert(
2134 "TRANSA".to_string(),
2135 Value::CharArray(CharArray::new_row("T")),
2136 );
2137 gpu_opts
2138 .fields
2139 .insert("RECT".to_string(), Value::Bool(true));
2140 let gpu_value = linsolve_builtin(
2141 Value::GpuTensor(ha.clone()),
2142 Value::GpuTensor(hb.clone()),
2143 vec![Value::Struct(gpu_opts)],
2144 )
2145 .expect("gpu transposed rectangular linsolve");
2146 let gpu_solution = match gpu_value {
2147 Value::OutputList(mut outputs) => outputs.remove(0),
2148 other => other,
2149 };
2150 let gathered = test_support::gather(gpu_solution).expect("gather");
2151 let _ = provider.free(&ha);
2152 let _ = provider.free(&hb);
2153
2154 assert_eq!(gathered.shape, cpu_tensor.shape);
2155 for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2156 assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
2157 }
2158
2159 let telemetry = provider.telemetry_snapshot();
2160 assert_eq!(telemetry.linsolve.count, 1);
2161 assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2162 }
2163
2164 #[cfg(feature = "wgpu")]
2165 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2166 #[test]
2167 fn wgpu_triangular_hint_avoids_host_reupload_fallback() {
2168 let _accel_guard = test_support::accel_test_lock();
2169 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2170 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2171 );
2172 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2173 let a = Tensor::new(
2174 vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
2175 vec![3, 3],
2176 )
2177 .unwrap();
2178 let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();
2179
2180 let cpu = linsolve_builtin(Value::Tensor(a.clone()), Value::Tensor(b.clone()), {
2181 let mut opts = StructValue::new();
2182 opts.fields.insert("LT".to_string(), Value::Bool(true));
2183 vec![Value::Struct(opts)]
2184 })
2185 .expect("cpu linsolve");
2186 let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2187 provider.reset_telemetry();
2188
2189 let ha = provider
2190 .upload(&HostTensorView {
2191 data: &a.data,
2192 shape: &a.shape,
2193 })
2194 .expect("upload A");
2195 let hb = provider
2196 .upload(&HostTensorView {
2197 data: &b.data,
2198 shape: &b.shape,
2199 })
2200 .expect("upload B");
2201
2202 let _output_guard = crate::output_count::push_output_count(Some(1));
2203 let mut opts = StructValue::new();
2204 opts.fields.insert("LT".to_string(), Value::Bool(true));
2205 let gpu_value = linsolve_builtin(
2206 Value::GpuTensor(ha.clone()),
2207 Value::GpuTensor(hb.clone()),
2208 vec![Value::Struct(opts)],
2209 )
2210 .expect("gpu triangular linsolve");
2211 let gpu_solution = match gpu_value {
2212 Value::OutputList(mut outputs) => outputs.remove(0),
2213 other => other,
2214 };
2215 let gathered = test_support::gather(gpu_solution).expect("gather");
2216 let _ = provider.free(&ha);
2217 let _ = provider.free(&hb);
2218
2219 assert_eq!(gathered.shape, cpu_tensor.shape);
2220 for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2221 assert!((gpu - cpu).abs() < 1e-5);
2222 }
2223
2224 let telemetry = provider.telemetry_snapshot();
2225 assert_eq!(telemetry.linsolve.count, 1);
2226 assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2227 }
2228
2229 #[cfg(feature = "wgpu")]
2230 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2231 #[test]
2232 fn wgpu_transposed_triangular_hint_avoids_host_reupload_fallback() {
2233 let _accel_guard = test_support::accel_test_lock();
2234 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2235 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2236 );
2237 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2238 let a = Tensor::new(
2239 vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
2240 vec![3, 3],
2241 )
2242 .unwrap();
2243 let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();
2244
2245 let mut cpu_opts = StructValue::new();
2246 cpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
2247 cpu_opts.fields.insert(
2248 "TRANSA".to_string(),
2249 Value::CharArray(CharArray::new_row("T")),
2250 );
2251 let cpu = linsolve_builtin(
2252 Value::Tensor(a.clone()),
2253 Value::Tensor(b.clone()),
2254 vec![Value::Struct(cpu_opts)],
2255 )
2256 .expect("cpu linsolve");
2257 let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2258 provider.reset_telemetry();
2259
2260 let ha = provider
2261 .upload(&HostTensorView {
2262 data: &a.data,
2263 shape: &a.shape,
2264 })
2265 .expect("upload A");
2266 let hb = provider
2267 .upload(&HostTensorView {
2268 data: &b.data,
2269 shape: &b.shape,
2270 })
2271 .expect("upload B");
2272
2273 let _output_guard = crate::output_count::push_output_count(Some(1));
2274 let mut gpu_opts = StructValue::new();
2275 gpu_opts.fields.insert("LT".to_string(), Value::Bool(true));
2276 gpu_opts.fields.insert(
2277 "TRANSA".to_string(),
2278 Value::CharArray(CharArray::new_row("T")),
2279 );
2280 let gpu_value = linsolve_builtin(
2281 Value::GpuTensor(ha.clone()),
2282 Value::GpuTensor(hb.clone()),
2283 vec![Value::Struct(gpu_opts)],
2284 )
2285 .expect("gpu transposed triangular linsolve");
2286 let gpu_solution = match gpu_value {
2287 Value::OutputList(mut outputs) => outputs.remove(0),
2288 other => other,
2289 };
2290 let gathered = test_support::gather(gpu_solution).expect("gather");
2291 let _ = provider.free(&ha);
2292 let _ = provider.free(&hb);
2293
2294 assert_eq!(gathered.shape, cpu_tensor.shape);
2295 for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2296 assert!((gpu - cpu).abs() < 1e-5);
2297 }
2298
2299 let telemetry = provider.telemetry_snapshot();
2300 assert_eq!(telemetry.linsolve.count, 1);
2301 assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
2302 }
2303
2304 #[cfg(feature = "wgpu")]
2305 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
2306 #[test]
2307 fn wgpu_round_trip_matches_cpu() {
2308 let _accel_guard = test_support::accel_test_lock();
2309 let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
2310 runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
2311 );
2312 let provider = runmat_accelerate_api::provider().expect("wgpu provider");
2313 let tol = match provider.precision() {
2314 runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
2315 runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
2316 };
2317
2318 let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
2319 let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();
2320
2321 let cpu = linsolve_builtin(
2322 Value::Tensor(a.clone()),
2323 Value::Tensor(b.clone()),
2324 Vec::new(),
2325 )
2326 .expect("cpu linsolve");
2327 let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
2328
2329 let view_a = HostTensorView {
2330 data: &a.data,
2331 shape: &a.shape,
2332 };
2333 let view_b = HostTensorView {
2334 data: &b.data,
2335 shape: &b.shape,
2336 };
2337 let ha = provider.upload(&view_a).expect("upload A");
2338 let hb = provider.upload(&view_b).expect("upload B");
2339 let gpu_value = linsolve_builtin(
2340 Value::GpuTensor(ha.clone()),
2341 Value::GpuTensor(hb.clone()),
2342 Vec::new(),
2343 )
2344 .expect("gpu linsolve");
2345 let gathered = test_support::gather(gpu_value).expect("gather");
2346 let _ = provider.free(&ha);
2347 let _ = provider.free(&hb);
2348
2349 assert_eq!(gathered.shape, cpu_tensor.shape);
2350 for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
2351 assert!((gpu - cpu).abs() < tol);
2352 }
2353 }
2354}