1use nalgebra::{linalg::SVD, DMatrix};
4use num_complex::Complex64;
5use runmat_accelerate_api::{
6 AccelProvider, GpuTensorHandle, HostTensorView, ProviderLinsolveOptions, ProviderLinsolveResult,
7};
8use runmat_builtins::{
9 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
10 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
11 ComplexTensor, Tensor, Value,
12};
13use runmat_macros::runtime_builtin;
14
15use crate::builtins::common::spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ProviderHook, ReductionNaN, ResidencyPolicy, ScalarType, ShapeRequirements,
18};
19use crate::builtins::common::{
20 gpu_helpers,
21 linalg::{diagonal_rcond, singular_value_rcond},
22 tensor,
23};
24use crate::builtins::math::linalg::type_resolvers::left_divide_type;
25use crate::{build_runtime_error, BuiltinResult, RuntimeError};
26
/// Builtin name attached to every runtime error built in this module and used
/// as the prefix of several error messages.
const NAME: &str = "linsolve";
28
/// Output descriptors for the single-output signatures: only the solution `X`.
const LINSOLVE_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "X",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Solution to A * X = B.",
}];
36
/// Output descriptors for the two-output signatures: the solution `X` followed
/// by the reciprocal condition estimate `R`.
const LINSOLVE_OUTPUT_XR: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "X",
        ty: BuiltinParamType::NumericArray,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Solution to A * X = B.",
    },
    BuiltinParamDescriptor {
        name: "R",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Reciprocal condition estimate.",
    },
];
53
/// Input descriptors for the call forms without an options struct:
/// the coefficient matrix `A` and the right-hand side `B`.
const LINSOLVE_INPUTS_AB: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Coefficient matrix.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Right-hand side matrix or vector.",
    },
];
70
/// Input descriptors for the call forms that accept the optional `opts`
/// struct after `A` and `B`.
const LINSOLVE_INPUTS_AB_OPTS: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "A",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Coefficient matrix.",
    },
    BuiltinParamDescriptor {
        name: "B",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Right-hand side matrix or vector.",
    },
    BuiltinParamDescriptor {
        name: "opts",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Optional,
        default: None,
        description: "Structural options (LT, UT, RECT, SYM, POSDEF, TRANSA, RCOND).",
    },
];
94
/// The four advertised call signatures: one or two outputs, each with or
/// without the `opts` struct.
const LINSOLVE_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "X = linsolve(A, B)",
        inputs: &LINSOLVE_INPUTS_AB,
        outputs: &LINSOLVE_OUTPUT_X,
    },
    BuiltinSignatureDescriptor {
        label: "X = linsolve(A, B, opts)",
        inputs: &LINSOLVE_INPUTS_AB_OPTS,
        outputs: &LINSOLVE_OUTPUT_X,
    },
    BuiltinSignatureDescriptor {
        label: "[X, R] = linsolve(A, B)",
        inputs: &LINSOLVE_INPUTS_AB,
        outputs: &LINSOLVE_OUTPUT_XR,
    },
    BuiltinSignatureDescriptor {
        label: "[X, R] = linsolve(A, B, opts)",
        inputs: &LINSOLVE_INPUTS_AB_OPTS,
        outputs: &LINSOLVE_OUTPUT_XR,
    },
];
117
/// Error descriptor for malformed options, output counts, or extra arguments.
/// Its identifier is attached to errors built by `argument_error`.
const LINSOLVE_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LINSOLVE.INVALID_ARGUMENT",
    identifier: Some("RunMat:linsolve:InvalidArgument"),
    when: "Options/output count/auxiliary arguments are malformed or unsupported.",
    message: "linsolve: invalid argument",
};
124
/// Error descriptor for operands whose shape or type cannot be solved.
/// Its identifier is attached to errors built by `builtin_error`.
const LINSOLVE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LINSOLVE.INVALID_INPUT",
    identifier: Some("RunMat:linsolve:InvalidInput"),
    when: "Input shape/type cannot be solved under linsolve semantics.",
    message: "linsolve: invalid input",
};
131
/// Error descriptor for internal runtime failures.
///
/// NOTE(review): it is listed in `LINSOLVE_ERRORS`, but no helper in the
/// visible part of this file builds an error from it directly.
const LINSOLVE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.LINSOLVE.INTERNAL",
    identifier: Some("RunMat:linsolve:Internal"),
    when: "Runtime fails while solving or executing provider fallback paths.",
    message: "linsolve: internal runtime failure",
};
138
/// All error descriptors published through `LINSOLVE_DESCRIPTOR`.
const LINSOLVE_ERRORS: [BuiltinErrorDescriptor; 3] = [
    LINSOLVE_ERROR_INVALID_ARGUMENT,
    LINSOLVE_ERROR_INVALID_INPUT,
    LINSOLVE_ERROR_INTERNAL,
];
144
/// Public descriptor for the `linsolve` builtin: its signatures, its error
/// catalogue, and the fact that the output shape follows the requested output
/// count. Referenced from the `#[runtime_builtin]` attribute on
/// `linsolve_builtin`.
pub const LINSOLVE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &LINSOLVE_SIGNATURES,
    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &LINSOLVE_ERRORS,
};
151
/// GPU execution spec registered for `linsolve`.
///
/// Declares a custom "solve" operation routed through the provider's
/// `linsolve` hook, with no broadcasting and a freshly allocated result
/// handle. The `notes` field records which cases the WGPU provider handles
/// and when the host solver is used instead.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::linalg::solve::linsolve")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "linsolve",
    op_kind: GpuOpKind::Custom("solve"),
    supported_precisions: &[ScalarType::F32, ScalarType::F64],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[ProviderHook::Custom("linsolve")],
    constant_strategy: ConstantStrategy::UniformBuffer,
    residency: ResidencyPolicy::NewHandle,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prefers the provider linsolve hook; WGPU currently supports triangular solves, real F32 TRANSA='T'/'C' variants, a dedicated real F32 POSDEF/Cholesky path, and selected real F32 QR-backed square and rectangular solves, otherwise it gathers to the host solver and re-uploads the result.",
};
167
168fn linsolve_error_with_message(
169 message: impl Into<String>,
170 error: &'static BuiltinErrorDescriptor,
171) -> RuntimeError {
172 let mut builder = build_runtime_error(message).with_builtin(NAME);
173 if let Some(identifier) = error.identifier {
174 builder = builder.with_identifier(identifier);
175 }
176 builder.build()
177}
178
179fn builtin_error(message: impl Into<String>) -> RuntimeError {
180 linsolve_error_with_message(message, &LINSOLVE_ERROR_INVALID_INPUT)
181}
182
183fn argument_error(message: impl Into<String>) -> RuntimeError {
184 linsolve_error_with_message(message, &LINSOLVE_ERROR_INVALID_ARGUMENT)
185}
186
187fn map_control_flow(err: RuntimeError) -> RuntimeError {
188 let mut builder = build_runtime_error(err.message()).with_builtin(NAME);
189 if let Some(identifier) = err.identifier() {
190 builder = builder.with_identifier(identifier.to_string());
191 }
192 if let Some(task_id) = err.context.task_id.clone() {
193 builder = builder.with_task_id(task_id);
194 }
195 if !err.context.call_stack.is_empty() {
196 builder = builder.with_call_stack(err.context.call_stack.clone());
197 }
198 if let Some(phase) = err.context.phase.clone() {
199 builder = builder.with_phase(phase);
200 }
201 builder.with_source(err).build()
202}
203
/// Fusion spec registered for `linsolve`.
///
/// Declares neither an elementwise nor a reduction kernel; per `notes`, a
/// linear solve is a terminal operation that does not fuse with neighbours.
#[runmat_macros::register_fusion_spec(
    builtin_path = "crate::builtins::math::linalg::solve::linsolve"
)]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "linsolve",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::UniformBuffer,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Linear solves are terminal operations and do not fuse with surrounding kernels.",
};
216
217#[runtime_builtin(
218 name = "linsolve",
219 category = "math/linalg/solve",
220 summary = "Solve A * X = B with structural hints such as LT, UT, POSDEF, or TRANSA.",
221 keywords = "linsolve,linear system,triangular,gpu",
222 accel = "linsolve",
223 type_resolver(left_divide_type),
224 descriptor(crate::builtins::math::linalg::solve::linsolve::LINSOLVE_DESCRIPTOR),
225 builtin_path = "crate::builtins::math::linalg::solve::linsolve"
226)]
227async fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
228 let eval = evaluate_args(lhs, rhs, &rest).await?;
229 if let Some(out_count) = crate::output_count::current_output_count() {
230 if out_count == 0 {
231 return Ok(Value::OutputList(Vec::new()));
232 }
233 if out_count == 1 {
234 return Ok(Value::OutputList(vec![eval.solution()]));
235 }
236 if out_count == 2 {
237 return Ok(Value::OutputList(vec![
238 eval.solution(),
239 eval.reciprocal_condition(),
240 ]));
241 }
242 return Err(argument_error(
243 "linsolve currently supports at most two outputs",
244 ));
245 }
246 Ok(eval.solution())
247}
248
249pub async fn evaluate(
251 lhs: Value,
252 rhs: Value,
253 options: SolveOptions,
254) -> BuiltinResult<LinsolveEval> {
255 if let Some(eval) = try_gpu_linsolve(&lhs, &rhs, &options).await? {
256 return Ok(eval);
257 }
258
259 let lhs_host = crate::dispatcher::gather_if_needed_async(&lhs)
260 .await
261 .map_err(map_control_flow)?;
262 let rhs_host = crate::dispatcher::gather_if_needed_async(&rhs)
263 .await
264 .map_err(map_control_flow)?;
265 let pair = coerce_numeric_pair(lhs_host, rhs_host).await?;
266 match pair {
267 NumericPair::Real(lhs_r, rhs_r) => {
268 let (solution, rcond) = solve_real(lhs_r, rhs_r, &options)?;
269 Ok(LinsolveEval::new(
270 tensor::tensor_into_value(solution),
271 Some(rcond),
272 ))
273 }
274 NumericPair::Complex(lhs_c, rhs_c) => {
275 let (solution, rcond) = solve_complex(lhs_c, rhs_c, &options)?;
276 Ok(LinsolveEval::new(
277 Value::ComplexTensor(solution),
278 Some(rcond),
279 ))
280 }
281 }
282}
283
284pub fn linsolve_host_real_for_provider(
286 lhs: &Tensor,
287 rhs: &Tensor,
288 options: &ProviderLinsolveOptions,
289) -> BuiltinResult<(Tensor, f64)> {
290 let opts = SolveOptions::from(options);
291 solve_real(lhs.clone(), rhs.clone(), &opts)
292}
293
/// Result of one `linsolve` evaluation: the solution value and, when
/// available, the reciprocal condition estimate.
#[derive(Clone)]
pub struct LinsolveEval {
    // Solution `X`; either a host value or a GPU tensor handle.
    solution: Value,
    // Reciprocal condition estimate; `None` is reported to callers as NaN.
    rcond: Option<f64>,
}
300
301impl LinsolveEval {
302 fn new(solution: Value, rcond: Option<f64>) -> Self {
303 Self { solution, rcond }
304 }
305
306 pub fn solution(&self) -> Value {
308 self.solution.clone()
309 }
310
311 pub fn reciprocal_condition(&self) -> Value {
313 match self.rcond {
314 Some(r) => Value::Num(r),
315 None => Value::Num(f64::NAN),
316 }
317 }
318}
319
/// Parsed form of the `opts` struct accepted by `linsolve`.
///
/// The default value (all flags false, no threshold) corresponds to a call
/// without `opts`.
#[derive(Clone, Default)]
pub struct SolveOptions {
    // `LT`: coefficient matrix is lower triangular.
    lower: bool,
    // `UT`: coefficient matrix is upper triangular.
    upper: bool,
    // `RECT`: parsed and forwarded to the provider; the host solvers in this
    // file do not read it.
    rectangular: bool,
    // `TRANSA` is 'T' or 'C': solve with the transposed coefficient matrix.
    transposed: bool,
    // `TRANSA` is 'C': additionally conjugate the transposed matrix.
    conjugate: bool,
    // `SYM`: parsed and forwarded to the provider; not read by the host solvers here.
    symmetric: bool,
    // `POSDEF`: parsed and forwarded to the provider; not read by the host solvers here.
    posdef: bool,
    // `RCOND`: when set, a solve whose reciprocal condition estimate is below
    // this threshold is rejected.
    rcond: Option<f64>,
}
331
332impl From<&SolveOptions> for ProviderLinsolveOptions {
333 fn from(opts: &SolveOptions) -> Self {
334 Self {
335 lower: opts.lower,
336 upper: opts.upper,
337 rectangular: opts.rectangular,
338 transposed: opts.transposed,
339 conjugate: opts.conjugate,
340 symmetric: opts.symmetric,
341 posdef: opts.posdef,
342 need_rcond: false,
343 rcond: opts.rcond,
344 }
345 }
346}
347
348impl From<&ProviderLinsolveOptions> for SolveOptions {
349 fn from(opts: &ProviderLinsolveOptions) -> Self {
350 Self {
351 lower: opts.lower,
352 upper: opts.upper,
353 rectangular: opts.rectangular,
354 transposed: opts.transposed,
355 conjugate: opts.conjugate,
356 symmetric: opts.symmetric,
357 posdef: opts.posdef,
358 rcond: opts.rcond,
359 }
360 }
361}
362
363fn options_from_rest(rest: &[Value]) -> BuiltinResult<SolveOptions> {
364 match rest.len() {
365 0 => Ok(SolveOptions::default()),
366 1 => parse_options(&rest[0]),
367 _ => Err(argument_error("linsolve: too many input arguments")),
368 }
369}
370
371pub async fn evaluate_args(lhs: Value, rhs: Value, rest: &[Value]) -> BuiltinResult<LinsolveEval> {
373 let options = options_from_rest(rest)?;
374 evaluate(lhs, rhs, options).await
375}
376
377async fn try_gpu_linsolve(
378 lhs: &Value,
379 rhs: &Value,
380 options: &SolveOptions,
381) -> BuiltinResult<Option<LinsolveEval>> {
382 if matches!(crate::output_count::current_output_count(), Some(n) if n > 2) {
383 return Ok(None);
384 }
385 let provider = match runmat_accelerate_api::provider() {
386 Some(p) => p,
387 None => return Ok(None),
388 };
389
390 if contains_complex(lhs) || contains_complex(rhs) {
391 return Ok(None);
392 }
393
394 let mut lhs_operand = match prepare_gpu_operand(lhs, provider)? {
395 Some(op) => op,
396 None => return Ok(None),
397 };
398 let mut rhs_operand = match prepare_gpu_operand(rhs, provider)? {
399 Some(op) => op,
400 None => {
401 release_operand(provider, &mut lhs_operand);
402 return Ok(None);
403 }
404 };
405
406 if is_scalar_handle(lhs_operand.handle()) || is_scalar_handle(rhs_operand.handle()) {
407 release_operand(provider, &mut lhs_operand);
408 release_operand(provider, &mut rhs_operand);
409 return Ok(None);
410 }
411
412 let mut provider_opts: ProviderLinsolveOptions = options.into();
413 provider_opts.need_rcond =
414 matches!(crate::output_count::current_output_count(), Some(2)) || options.rcond.is_some();
415 let result = provider
416 .linsolve(lhs_operand.handle(), rhs_operand.handle(), &provider_opts)
417 .await
418 .ok();
419
420 release_operand(provider, &mut lhs_operand);
421 release_operand(provider, &mut rhs_operand);
422
423 if let Some(ProviderLinsolveResult {
424 solution,
425 reciprocal_condition,
426 }) = result
427 {
428 let eval = LinsolveEval::new(Value::GpuTensor(solution), Some(reciprocal_condition));
429 return Ok(Some(eval));
430 }
431
432 Ok(None)
433}
434
435fn parse_options(value: &Value) -> BuiltinResult<SolveOptions> {
436 let struct_val = match value {
437 Value::Struct(s) => s,
438 other => {
439 return Err(argument_error(format!(
440 "linsolve: opts must be a struct, got {other:?}"
441 )))
442 }
443 };
444 let mut opts = SolveOptions::default();
445 for (key, raw_value) in &struct_val.fields {
446 let name = key.to_ascii_uppercase();
447 match name.as_str() {
448 "LT" => opts.lower = parse_bool_field("LT", raw_value)?,
449 "UT" => opts.upper = parse_bool_field("UT", raw_value)?,
450 "RECT" => opts.rectangular = parse_bool_field("RECT", raw_value)?,
451 "SYM" => opts.symmetric = parse_bool_field("SYM", raw_value)?,
452 "POSDEF" => opts.posdef = parse_bool_field("POSDEF", raw_value)?,
453 "TRANSA" => {
454 let transa = parse_transa(raw_value)?;
455 opts.transposed = transa != TransposeMode::None;
456 opts.conjugate = transa == TransposeMode::Conjugate;
457 }
458 "RCOND" => {
459 let threshold = parse_scalar_f64("RCOND", raw_value)?;
460 if threshold < 0.0 {
461 return Err(argument_error("linsolve: RCOND must be non-negative"));
462 }
463 opts.rcond = Some(threshold);
464 }
465 other => {
466 return Err(argument_error(format!(
467 "linsolve: unknown option '{other}'"
468 )))
469 }
470 }
471 }
472 if opts.lower && opts.upper {
473 return Err(argument_error(
474 "linsolve: LT and UT are mutually exclusive.",
475 ));
476 }
477 Ok(opts)
478}
479
480fn parse_bool_field(name: &str, value: &Value) -> BuiltinResult<bool> {
481 match value {
482 Value::Bool(b) => Ok(*b),
483 Value::Int(i) => Ok(!i.is_zero()),
484 Value::Num(n) => Ok(*n != 0.0),
485 Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0] != 0.0),
486 Value::LogicalArray(arr) if arr.len() == 1 => Ok(arr.data[0] != 0),
487 other => Err(argument_error(format!(
488 "linsolve: option '{name}' must be logical or numeric, got {other:?}"
489 ))),
490 }
491}
492
493fn parse_scalar_f64(name: &str, value: &Value) -> BuiltinResult<f64> {
494 match value {
495 Value::Num(n) => Ok(*n),
496 Value::Int(i) => Ok(i.to_f64()),
497 Value::Tensor(t) if tensor::is_scalar_tensor(t) => Ok(t.data[0]),
498 other => Err(argument_error(format!(
499 "linsolve: option '{name}' must be a scalar numeric value, got {other:?}"
500 ))),
501 }
502}
503
/// Parsed value of the `TRANSA` option.
#[derive(Copy, Clone, PartialEq, Eq)]
enum TransposeMode {
    // 'N': use the coefficient matrix as given.
    None,
    // 'T': use its transpose.
    Transpose,
    // 'C': use its conjugate transpose.
    Conjugate,
}
510
511fn parse_transa(value: &Value) -> BuiltinResult<TransposeMode> {
512 let text = tensor::value_to_string(value).ok_or_else(|| {
513 argument_error("linsolve: TRANSA must be a character vector or string scalar")
514 })?;
515 if text.is_empty() {
516 return Err(argument_error("linsolve: TRANSA cannot be empty"));
517 }
518 match text.trim().to_ascii_uppercase().as_str() {
519 "N" => Ok(TransposeMode::None),
520 "T" => Ok(TransposeMode::Transpose),
521 "C" => Ok(TransposeMode::Conjugate),
522 other => Err(argument_error(format!(
523 "linsolve: TRANSA must be 'N', 'T', or 'C', got '{other}'"
524 ))),
525 }
526}
527
/// A single operand after coercion to a host tensor, either real or complex.
enum NumericInput {
    Real(Tensor),
    Complex(ComplexTensor),
}
532
/// Both operands in a common representation: real only when both are real,
/// complex as soon as either one is.
enum NumericPair {
    Real(Tensor, Tensor),
    Complex(ComplexTensor, ComplexTensor),
}
537
538async fn coerce_numeric_pair(lhs: Value, rhs: Value) -> BuiltinResult<NumericPair> {
539 let lhs_num = coerce_numeric(lhs).await?;
540 let rhs_num = coerce_numeric(rhs).await?;
541 match (lhs_num, rhs_num) {
542 (NumericInput::Real(lhs_r), NumericInput::Real(rhs_r)) => {
543 Ok(NumericPair::Real(lhs_r, rhs_r))
544 }
545 (NumericInput::Complex(lhs_c), NumericInput::Complex(rhs_c)) => {
546 Ok(NumericPair::Complex(lhs_c, rhs_c))
547 }
548 (NumericInput::Complex(lhs_c), NumericInput::Real(rhs_r)) => {
549 let rhs_c = promote_real_tensor(&rhs_r)?;
550 Ok(NumericPair::Complex(lhs_c, rhs_c))
551 }
552 (NumericInput::Real(lhs_r), NumericInput::Complex(rhs_c)) => {
553 let lhs_c = promote_real_tensor(&lhs_r)?;
554 Ok(NumericPair::Complex(lhs_c, rhs_c))
555 }
556 }
557}
558
559async fn coerce_numeric(value: Value) -> BuiltinResult<NumericInput> {
560 match value {
561 Value::Tensor(tensor) => {
562 ensure_matrix_shape(NAME, &tensor.shape)?;
563 Ok(NumericInput::Real(tensor))
564 }
565 Value::LogicalArray(logical) => {
566 let tensor = tensor::logical_to_tensor(&logical).map_err(builtin_error)?;
567 ensure_matrix_shape(NAME, &tensor.shape)?;
568 Ok(NumericInput::Real(tensor))
569 }
570 Value::Num(n) => {
571 let tensor = Tensor::new(vec![n], vec![1, 1]).map_err(builtin_error)?;
572 Ok(NumericInput::Real(tensor))
573 }
574 Value::Int(i) => {
575 let tensor = Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(builtin_error)?;
576 Ok(NumericInput::Real(tensor))
577 }
578 Value::Bool(b) => {
579 let tensor =
580 Tensor::new(vec![if b { 1.0 } else { 0.0 }], vec![1, 1]).map_err(builtin_error)?;
581 Ok(NumericInput::Real(tensor))
582 }
583 Value::Complex(re, im) => {
584 let tensor = ComplexTensor::new(vec![(re, im)], vec![1, 1]).map_err(builtin_error)?;
585 Ok(NumericInput::Complex(tensor))
586 }
587 Value::ComplexTensor(ct) => {
588 ensure_matrix_shape(NAME, &ct.shape)?;
589 Ok(NumericInput::Complex(ct))
590 }
591 Value::GpuTensor(handle) => {
592 let tensor = gpu_helpers::gather_tensor_async(&handle)
593 .await
594 .map_err(map_control_flow)?;
595 ensure_matrix_shape(NAME, &tensor.shape)?;
596 Ok(NumericInput::Real(tensor))
597 }
598 other => Err(builtin_error(format!(
599 "{NAME}: unsupported input type {:?}; convert to numeric values first",
600 other
601 ))),
602 }
603}
604
605fn contains_complex(value: &Value) -> bool {
606 matches!(value, Value::Complex(_, _) | Value::ComplexTensor(_))
607}
608
609fn is_scalar_handle(handle: &GpuTensorHandle) -> bool {
610 crate::builtins::common::shape::is_scalar_shape(&handle.shape)
611}
612
/// A GPU operand prepared for a provider call.
struct PreparedOperand {
    // Handle passed to the provider.
    handle: GpuTensorHandle,
    // True when this module uploaded the buffer itself, in which case
    // `release_operand` frees it; false for handles supplied by the caller.
    owned: bool,
}
617
618impl PreparedOperand {
619 fn borrowed(handle: &GpuTensorHandle) -> Self {
620 Self {
621 handle: handle.clone(),
622 owned: false,
623 }
624 }
625
626 fn owned(handle: GpuTensorHandle) -> Self {
627 Self {
628 handle,
629 owned: true,
630 }
631 }
632
633 fn handle(&self) -> &GpuTensorHandle {
634 &self.handle
635 }
636}
637
638fn prepare_gpu_operand(
639 value: &Value,
640 provider: &'static dyn AccelProvider,
641) -> BuiltinResult<Option<PreparedOperand>> {
642 match value {
643 Value::GpuTensor(handle) => {
644 if is_scalar_handle(handle) {
645 Ok(None)
646 } else {
647 Ok(Some(PreparedOperand::borrowed(handle)))
648 }
649 }
650 Value::Tensor(tensor) => {
651 if tensor::is_scalar_tensor(tensor) {
652 Ok(None)
653 } else {
654 let uploaded = upload_tensor(provider, tensor)?;
655 Ok(Some(PreparedOperand::owned(uploaded)))
656 }
657 }
658 Value::LogicalArray(logical) => {
659 if logical.data.len() == 1 {
660 Ok(None)
661 } else {
662 let tensor = tensor::logical_to_tensor(logical).map_err(builtin_error)?;
663 let uploaded = upload_tensor(provider, &tensor)?;
664 Ok(Some(PreparedOperand::owned(uploaded)))
665 }
666 }
667 _ => Ok(None),
668 }
669}
670
671fn upload_tensor(
672 provider: &'static dyn AccelProvider,
673 tensor: &Tensor,
674) -> BuiltinResult<GpuTensorHandle> {
675 let view = HostTensorView {
676 data: &tensor.data,
677 shape: &tensor.shape,
678 };
679 provider
680 .upload(&view)
681 .map_err(|e| builtin_error(format!("{NAME}: {e}")))
682}
683
684fn release_operand(provider: &'static dyn AccelProvider, operand: &mut PreparedOperand) {
685 if operand.owned {
686 let _ = provider.free(&operand.handle);
687 operand.owned = false;
688 }
689}
690
691fn solve_real(lhs: Tensor, rhs: Tensor, options: &SolveOptions) -> BuiltinResult<(Tensor, f64)> {
692 let mut lhs_effective = lhs;
693 let mut rhs_effective = rhs;
694 let mut lower = options.lower;
695 let mut upper = options.upper;
696
697 if options.transposed {
698 lhs_effective = transpose_tensor(&lhs_effective);
699 if options.conjugate {
700 conjugate_in_place(&mut lhs_effective);
701 }
702 if lower || upper {
703 std::mem::swap(&mut lower, &mut upper);
704 }
705 }
706
707 rhs_effective = normalize_rhs_tensor(rhs_effective, lhs_effective.rows())?;
708
709 if lower {
710 ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
711 let (solution, rcond) = forward_substitution_real(&lhs_effective, &rhs_effective)?;
712 enforce_rcond(options, rcond)?;
713 return Ok((solution, rcond));
714 }
715
716 if upper {
717 ensure_square(lhs_effective.rows(), lhs_effective.cols())?;
718 let (solution, rcond) = backward_substitution_real(&lhs_effective, &rhs_effective)?;
719 enforce_rcond(options, rcond)?;
720 return Ok((solution, rcond));
721 }
722
723 let (solution, rcond) = solve_general_real(&lhs_effective, &rhs_effective)?;
724 enforce_rcond(options, rcond)?;
725 Ok((solution, rcond))
726}
727
728fn solve_complex(
729 lhs: ComplexTensor,
730 rhs: ComplexTensor,
731 options: &SolveOptions,
732) -> BuiltinResult<(ComplexTensor, f64)> {
733 let mut lhs_effective = lhs;
734 let mut rhs_effective = rhs;
735 let mut lower = options.lower;
736 let mut upper = options.upper;
737
738 if options.transposed {
739 lhs_effective = transpose_complex(&lhs_effective);
740 if options.conjugate {
741 conjugate_complex_in_place(&mut lhs_effective);
742 }
743 if lower || upper {
744 std::mem::swap(&mut lower, &mut upper);
745 }
746 }
747
748 rhs_effective = normalize_rhs_complex(rhs_effective, lhs_effective.rows)?;
749
750 if lower {
751 ensure_square(lhs_effective.rows, lhs_effective.cols)?;
752 let (solution, rcond) = forward_substitution_complex(&lhs_effective, &rhs_effective)?;
753 enforce_rcond(options, rcond)?;
754 return Ok((solution, rcond));
755 }
756
757 if upper {
758 ensure_square(lhs_effective.rows, lhs_effective.cols)?;
759 let (solution, rcond) = backward_substitution_complex(&lhs_effective, &rhs_effective)?;
760 enforce_rcond(options, rcond)?;
761 return Ok((solution, rcond));
762 }
763
764 let (solution, rcond) = solve_general_complex(&lhs_effective, &rhs_effective)?;
765 enforce_rcond(options, rcond)?;
766 Ok((solution, rcond))
767}
768
769fn forward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
770 let n = lhs.rows();
771 let nrhs = rhs.data.len() / n;
772 let mut solution = rhs.data.clone();
773 let mut min_diag = f64::INFINITY;
774 let mut max_diag = 0.0_f64;
775
776 for col in 0..nrhs {
777 for i in 0..n {
778 let diag = lhs.data[i + i * n];
779 let diag_abs = diag.abs();
780 min_diag = min_diag.min(diag_abs);
781 max_diag = max_diag.max(diag_abs);
782 if diag_abs == 0.0 {
783 return Err(builtin_error(
784 "linsolve: matrix is singular to working precision.",
785 ));
786 }
787 let mut accum = 0.0;
788 for j in 0..i {
789 accum += lhs.data[i + j * n] * solution[j + col * n];
790 }
791 let rhs_value = solution[i + col * n] - accum;
792 solution[i + col * n] = rhs_value / diag;
793 }
794 }
795
796 let rcond = diagonal_rcond(min_diag, max_diag);
797 let tensor = Tensor::new(solution, rhs.shape.clone())
798 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
799 Ok((tensor, rcond))
800}
801
802fn backward_substitution_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
803 let n = lhs.rows();
804 let nrhs = rhs.data.len() / n;
805 let mut solution = rhs.data.clone();
806 let mut min_diag = f64::INFINITY;
807 let mut max_diag = 0.0_f64;
808
809 for col in 0..nrhs {
810 for row_rev in 0..n {
811 let i = n - 1 - row_rev;
812 let diag = lhs.data[i + i * n];
813 let diag_abs = diag.abs();
814 min_diag = min_diag.min(diag_abs);
815 max_diag = max_diag.max(diag_abs);
816 if diag_abs == 0.0 {
817 return Err(builtin_error(
818 "linsolve: matrix is singular to working precision.",
819 ));
820 }
821 let mut accum = 0.0;
822 for j in (i + 1)..n {
823 accum += lhs.data[i + j * n] * solution[j + col * n];
824 }
825 let rhs_value = solution[i + col * n] - accum;
826 solution[i + col * n] = rhs_value / diag;
827 }
828 }
829
830 let rcond = diagonal_rcond(min_diag, max_diag);
831 let tensor = Tensor::new(solution, rhs.shape.clone())
832 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
833 Ok((tensor, rcond))
834}
835
836fn forward_substitution_complex(
837 lhs: &ComplexTensor,
838 rhs: &ComplexTensor,
839) -> BuiltinResult<(ComplexTensor, f64)> {
840 let n = lhs.rows;
841 let nrhs = rhs.data.len() / n;
842 let lhs_data: Vec<Complex64> = lhs
843 .data
844 .iter()
845 .map(|&(re, im)| Complex64::new(re, im))
846 .collect();
847 let mut solution: Vec<Complex64> = rhs
848 .data
849 .iter()
850 .map(|&(re, im)| Complex64::new(re, im))
851 .collect();
852 let mut min_diag = f64::INFINITY;
853 let mut max_diag = 0.0_f64;
854
855 for col in 0..nrhs {
856 for i in 0..n {
857 let diag = lhs_data[i + i * n];
858 let diag_abs = diag.norm();
859 min_diag = min_diag.min(diag_abs);
860 max_diag = max_diag.max(diag_abs);
861 if diag_abs == 0.0 {
862 return Err(builtin_error(
863 "linsolve: matrix is singular to working precision.",
864 ));
865 }
866 let mut accum = Complex64::new(0.0, 0.0);
867 for j in 0..i {
868 accum += lhs_data[i + j * n] * solution[j + col * n];
869 }
870 let rhs_value = solution[i + col * n] - accum;
871 solution[i + col * n] = rhs_value / diag;
872 }
873 }
874
875 let rcond = diagonal_rcond(min_diag, max_diag);
876 let tensor = ComplexTensor::new(
877 solution.iter().map(|c| (c.re, c.im)).collect(),
878 rhs.shape.clone(),
879 )
880 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
881 Ok((tensor, rcond))
882}
883
884fn backward_substitution_complex(
885 lhs: &ComplexTensor,
886 rhs: &ComplexTensor,
887) -> BuiltinResult<(ComplexTensor, f64)> {
888 let n = lhs.rows;
889 let nrhs = rhs.data.len() / n;
890 let lhs_data: Vec<Complex64> = lhs
891 .data
892 .iter()
893 .map(|&(re, im)| Complex64::new(re, im))
894 .collect();
895 let mut solution: Vec<Complex64> = rhs
896 .data
897 .iter()
898 .map(|&(re, im)| Complex64::new(re, im))
899 .collect();
900 let mut min_diag = f64::INFINITY;
901 let mut max_diag = 0.0_f64;
902
903 for col in 0..nrhs {
904 for row_rev in 0..n {
905 let i = n - 1 - row_rev;
906 let diag = lhs_data[i + i * n];
907 let diag_abs = diag.norm();
908 min_diag = min_diag.min(diag_abs);
909 max_diag = max_diag.max(diag_abs);
910 if diag_abs == 0.0 {
911 return Err(builtin_error(
912 "linsolve: matrix is singular to working precision.",
913 ));
914 }
915 let mut accum = Complex64::new(0.0, 0.0);
916 for j in (i + 1)..n {
917 accum += lhs_data[i + j * n] * solution[j + col * n];
918 }
919 let rhs_value = solution[i + col * n] - accum;
920 solution[i + col * n] = rhs_value / diag;
921 }
922 }
923
924 let rcond = diagonal_rcond(min_diag, max_diag);
925 let tensor = ComplexTensor::new(
926 solution.iter().map(|c| (c.re, c.im)).collect(),
927 rhs.shape.clone(),
928 )
929 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
930 Ok((tensor, rcond))
931}
932
933fn solve_general_real(lhs: &Tensor, rhs: &Tensor) -> BuiltinResult<(Tensor, f64)> {
934 let a = DMatrix::from_column_slice(lhs.rows(), lhs.cols(), &lhs.data);
935 let b = DMatrix::from_column_slice(rhs.rows(), rhs.cols(), &rhs.data);
936 let svd = SVD::new(a.clone(), true, true);
937 let rcond = singular_value_rcond(svd.singular_values.as_slice());
938 let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows(), lhs.cols());
939 let solution = svd
940 .solve(&b, tol)
941 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
942 let tensor = matrix_real_to_tensor(solution)?;
943 Ok((tensor, rcond))
944}
945
946fn solve_general_complex(
947 lhs: &ComplexTensor,
948 rhs: &ComplexTensor,
949) -> BuiltinResult<(ComplexTensor, f64)> {
950 let a_data: Vec<Complex64> = lhs
951 .data
952 .iter()
953 .map(|&(re, im)| Complex64::new(re, im))
954 .collect();
955 let b_data: Vec<Complex64> = rhs
956 .data
957 .iter()
958 .map(|&(re, im)| Complex64::new(re, im))
959 .collect();
960 let a = DMatrix::from_column_slice(lhs.rows, lhs.cols, &a_data);
961 let b = DMatrix::from_column_slice(rhs.rows, rhs.cols, &b_data);
962 let svd = SVD::new(a.clone(), true, true);
963 let rcond = singular_value_rcond(svd.singular_values.as_slice());
964 let tol = compute_svd_tolerance(svd.singular_values.as_slice(), lhs.rows, lhs.cols);
965 let solution = svd
966 .solve(&b, tol)
967 .map_err(|e| builtin_error(format!("{NAME}: {e}")))?;
968 let tensor = matrix_complex_to_tensor(solution)?;
969 Ok((tensor, rcond))
970}
971
972fn normalize_rhs_tensor(rhs: Tensor, expected_rows: usize) -> BuiltinResult<Tensor> {
973 if rhs.rows() == expected_rows {
974 return Ok(rhs);
975 }
976 if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
977 return Tensor::new(rhs.data, vec![expected_rows, 1])
978 .map_err(|e| builtin_error(format!("{NAME}: {e}")));
979 }
980 if rhs.data.is_empty() && expected_rows == 0 {
981 return Ok(rhs);
982 }
983 Err(builtin_error("Matrix dimensions must agree."))
984}
985
986fn normalize_rhs_complex(rhs: ComplexTensor, expected_rows: usize) -> BuiltinResult<ComplexTensor> {
987 if rhs.rows == expected_rows {
988 return Ok(rhs);
989 }
990 if rhs.shape.len() == 1 && rhs.shape[0] == expected_rows {
991 return ComplexTensor::new(rhs.data, vec![expected_rows, 1])
992 .map_err(|e| builtin_error(format!("{NAME}: {e}")));
993 }
994 if rhs.data.is_empty() && expected_rows == 0 {
995 return Ok(rhs);
996 }
997 Err(builtin_error("Matrix dimensions must agree."))
998}
999
1000fn enforce_rcond(options: &SolveOptions, rcond: f64) -> BuiltinResult<()> {
1001 if let Some(threshold) = options.rcond {
1002 if rcond < threshold {
1003 return Err(builtin_error(
1004 "linsolve: matrix is singular to working precision.",
1005 ));
1006 }
1007 }
1008 Ok(())
1009}
1010
/// Computes the singular-value cutoff passed to the SVD solve.
///
/// The result is `f64::EPSILON * max(rows, cols) * max(largest |sv|, 1.0)`,
/// so the scale factor never drops below one. NaN singular values are ignored
/// because `f64::max` returns the other operand.
fn compute_svd_tolerance(singular_values: &[f64], rows: usize, cols: usize) -> f64 {
    let mut largest = 0.0_f64;
    for &sv in singular_values {
        largest = largest.max(sv.abs());
    }
    let scale = largest.max(1.0);
    let dimension = std::cmp::max(rows, cols) as f64;
    f64::EPSILON * dimension * scale
}
1019
1020fn matrix_real_to_tensor(matrix: DMatrix<f64>) -> BuiltinResult<Tensor> {
1021 let rows = matrix.nrows();
1022 let cols = matrix.ncols();
1023 Tensor::new(matrix.as_slice().to_vec(), vec![rows, cols])
1024 .map_err(|e| builtin_error(format!("{NAME}: {e}")))
1025}
1026
1027fn matrix_complex_to_tensor(matrix: DMatrix<Complex64>) -> BuiltinResult<ComplexTensor> {
1028 let rows = matrix.nrows();
1029 let cols = matrix.ncols();
1030 let data: Vec<(f64, f64)> = matrix.as_slice().iter().map(|c| (c.re, c.im)).collect();
1031 ComplexTensor::new(data, vec![rows, cols]).map_err(|e| builtin_error(format!("{NAME}: {e}")))
1032}
1033
1034fn promote_real_tensor(tensor: &Tensor) -> BuiltinResult<ComplexTensor> {
1035 let data: Vec<(f64, f64)> = tensor.data.iter().map(|&re| (re, 0.0)).collect();
1036 ComplexTensor::new(data, tensor.shape.clone())
1037 .map_err(|e| builtin_error(format!("{NAME}: {e}")))
1038}
1039
1040fn ensure_matrix_shape(name: &str, shape: &[usize]) -> BuiltinResult<()> {
1041 if is_effectively_matrix(shape) {
1042 Ok(())
1043 } else {
1044 Err(builtin_error(format!(
1045 "{name}: inputs must be 2-D matrices or vectors"
1046 )))
1047 }
1048}
1049
/// Reports whether `shape` has at most two non-singleton leading dimensions.
///
/// Shapes with zero, one, or two dimensions always qualify; longer shapes qualify
/// only when every dimension past the second equals one.
fn is_effectively_matrix(shape: &[usize]) -> bool {
    for &extent in shape.iter().skip(2) {
        if extent != 1 {
            return false;
        }
    }
    true
}
1056
1057fn ensure_square(rows: usize, cols: usize) -> BuiltinResult<()> {
1058 if rows == cols {
1059 Ok(())
1060 } else {
1061 Err(builtin_error(
1062 "linsolve: triangular solves require a square coefficient matrix.",
1063 ))
1064 }
1065}
1066
1067fn transpose_tensor(tensor: &Tensor) -> Tensor {
1068 let rows = tensor.rows();
1069 let cols = tensor.cols();
1070 let mut data = vec![0.0; tensor.data.len()];
1071 for r in 0..rows {
1072 for c in 0..cols {
1073 data[c + r * cols] = tensor.data[r + c * rows];
1074 }
1075 }
1076 Tensor::new(data, vec![cols, rows]).expect("transpose_tensor valid")
1077}
1078
1079fn transpose_complex(tensor: &ComplexTensor) -> ComplexTensor {
1080 let rows = tensor.rows;
1081 let cols = tensor.cols;
1082 let mut data = vec![(0.0, 0.0); tensor.data.len()];
1083 for r in 0..rows {
1084 for c in 0..cols {
1085 data[c + r * cols] = tensor.data[r + c * rows];
1086 }
1087 }
1088 ComplexTensor::new(data, vec![cols, rows]).expect("transpose_complex valid")
1089}
1090
1091fn conjugate_in_place(_tensor: &mut Tensor) {
1092 }
1094
1095fn conjugate_complex_in_place(tensor: &mut ComplexTensor) {
1096 for value in &mut tensor.data {
1097 value.1 = -value.1;
1098 }
1099}
1100
1101#[cfg(test)]
1102pub(crate) mod tests {
1103 use super::*;
1104 use futures::executor::block_on;
1105 use runmat_accelerate_api::HostTensorView;
1106 use runmat_builtins::{CharArray, ResolveContext, StructValue, Type};
1107 fn unwrap_error(err: crate::RuntimeError) -> crate::RuntimeError {
1108 err
1109 }
1110
/// Asserts that `actual` lies strictly within `1e-7` of `expected`.
///
/// # Panics
///
/// Panics when the absolute difference is not below the tolerance (including when
/// either value is NaN). The panic message reports both operands and the
/// difference so a failing test shows what was compared.
fn approx_eq(actual: f64, expected: f64) {
    let tolerance = 1e-7_f64;
    assert!(
        (actual - expected).abs() < tolerance,
        "actual={actual} expected={expected} |diff|={:e} tolerance={tolerance:e}",
        (actual - expected).abs()
    );
}
1114
1115 fn evaluate_args(a: Value, b: Value, rest: &[Value]) -> Result<LinsolveEval, RuntimeError> {
1116 block_on(super::evaluate_args(a, b, rest))
1117 }
1118
/// A 2x2 coefficient type combined with a 2x3 right-hand side must resolve to a 2x3 tensor.
#[test]
fn linsolve_type_uses_rhs_columns() {
    let coefficient = Type::Tensor {
        shape: Some(vec![Some(2), Some(2)]),
    };
    let right_hand_side = Type::Tensor {
        shape: Some(vec![Some(2), Some(3)]),
    };
    let expected = Type::Tensor {
        shape: Some(vec![Some(2), Some(3)]),
    };

    let context = ResolveContext::new(Vec::new());
    let resolved = left_divide_type(&[coefficient, right_hand_side], &context);

    assert_eq!(resolved, expected);
}
1139
/// The descriptor must advertise the one- and two-output call forms, with and without options.
#[test]
fn linsolve_descriptor_signatures_cover_core_forms() {
    let mut labels: Vec<&str> = Vec::new();
    for signature in LINSOLVE_DESCRIPTOR.signatures.iter() {
        labels.push(signature.label);
    }
    let has_label = |wanted: &str| labels.iter().any(|&known| known == wanted);

    assert!(has_label("X = linsolve(A, B)"));
    assert!(has_label("X = linsolve(A, B, opts)"));
    assert!(has_label("[X, R] = linsolve(A, B)"));
    assert!(has_label("[X, R] = linsolve(A, B, opts)"));
}
1152
/// The descriptor must expose the three `RM.LINSOLVE.*` error codes checked below.
#[test]
fn linsolve_descriptor_errors_have_stable_codes() {
    let mut codes: Vec<&str> = Vec::new();
    for descriptor in LINSOLVE_DESCRIPTOR.errors.iter() {
        codes.push(descriptor.code);
    }
    let has_code = |wanted: &str| codes.iter().any(|&known| known == wanted);

    assert!(has_code("RM.LINSOLVE.INVALID_ARGUMENT"));
    assert!(has_code("RM.LINSOLVE.INVALID_INPUT"));
    assert!(has_code("RM.LINSOLVE.INTERNAL"));
}
1164
1165 use crate::builtins::common::test_support;
1166 use runmat_accelerate_api::ProviderTelemetry;
1167
1168 fn linsolve_builtin(lhs: Value, rhs: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
1169 block_on(super::linsolve_builtin(lhs, rhs, rest))
1170 }
1171
1172 fn evaluate(lhs: Value, rhs: Value, options: SolveOptions) -> BuiltinResult<LinsolveEval> {
1173 block_on(super::evaluate(lhs, rhs, options))
1174 }
1175
1176 fn fallback_count(telemetry: &ProviderTelemetry, reason: &str) -> u64 {
1177 telemetry
1178 .solve_fallbacks
1179 .iter()
1180 .find(|entry| entry.reason == reason)
1181 .map(|entry| entry.count)
1182 .unwrap_or(0)
1183 }
1184
/// Counts the recorded kernel launches whose kernel name equals `kernel`.
#[cfg(feature = "wgpu")]
fn kernel_launch_count(telemetry: &ProviderTelemetry, kernel: &str) -> usize {
    let mut launches = 0;
    for entry in telemetry.kernel_launches.iter() {
        if entry.kernel == kernel {
            launches += 1;
        }
    }
    launches
}
1193
1194 fn clear_accel_provider_state() {
1195 runmat_accelerate_api::set_thread_provider(None);
1196 runmat_accelerate_api::clear_provider();
1197 }
1198
1199 fn host_linsolve_real(
1200 a: &Tensor,
1201 b: &Tensor,
1202 options: ProviderLinsolveOptions,
1203 ) -> (Tensor, f64) {
1204 super::linsolve_host_real_for_provider(a, b, &options).expect("host linsolve")
1205 }
1206
/// Solves a small 2x2 system on the host path and checks the solution is `[1; 2]`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_basic_square() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();

    let coefficients = Tensor::new(vec![2.0, 1.0, 1.0, 2.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

    let solved = linsolve_builtin(
        Value::Tensor(coefficients),
        Value::Tensor(right_hand_side),
        Vec::new(),
    )
    .expect("linsolve");
    let solution = test_support::gather(solved).expect("gather");

    assert_eq!(solution.shape, vec![2, 1]);
    approx_eq(solution.data[0], 1.0);
    approx_eq(solution.data[1], 2.0);
}
1221
/// Solves a 3x3 system with the `LT` option set and checks the solution is `[3; 2; 1]`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_lower_triangular_hint() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();

    let lower_data = vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0];
    let lower = Tensor::new(lower_data, vec![3, 3]).unwrap();
    let right_hand_side = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();

    let mut opts = StructValue::new();
    opts.fields.insert("LT".to_string(), Value::Bool(true));
    let extra_args = vec![Value::Struct(opts)];

    let solved = linsolve_builtin(
        Value::Tensor(lower),
        Value::Tensor(right_hand_side),
        extra_args,
    )
    .expect("linsolve");
    let solution = test_support::gather(solved).expect("gather");

    assert_eq!(solution.shape, vec![3, 1]);
    approx_eq(solution.data[0], 3.0);
    approx_eq(solution.data[1], 2.0);
    approx_eq(solution.data[2], 1.0);
}
1247
/// Solving with `LT` and `TRANSA = "T"` must match a plain host solve of the
/// explicitly transposed matrix.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_transposed_triangular_hint() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();

    let matrix_data = vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0];
    let matrix = Tensor::new(matrix_data, vec![3, 3]).unwrap();
    let right_hand_side = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();

    let mut opts = StructValue::new();
    opts.fields.insert("LT".to_string(), Value::Bool(true));
    let transpose_flag = Value::CharArray(CharArray::new_row("T"));
    opts.fields.insert("TRANSA".to_string(), transpose_flag);

    let solved = linsolve_builtin(
        Value::Tensor(matrix.clone()),
        Value::Tensor(right_hand_side.clone()),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let solution = test_support::gather(solved).expect("gather");
    assert_eq!(solution.shape, vec![3, 1]);

    // Reference: transpose by hand, then solve with default options.
    let transposed = transpose_tensor(&matrix);
    let (reference, _) = host_linsolve_real(
        &transposed,
        &right_hand_side,
        ProviderLinsolveOptions::default(),
    );

    solution
        .data
        .iter()
        .zip(reference.data.iter())
        .for_each(|(actual, expected)| approx_eq(*actual, *expected));
}
1283
/// Solves a complex 2x2 system and checks that the residual `A * X - B` is numerically zero.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_complex_inputs_match_residual() {
    let coefficients = ComplexTensor::new(
        vec![(2.0, 1.0), (-1.0, 0.0), (1.0, -2.0), (3.0, -2.0)],
        vec![2, 2],
    )
    .unwrap();
    let right_hand_side =
        ComplexTensor::new(vec![(1.0, 0.0), (4.0, 1.0)], vec![2, 1]).unwrap();

    let solved = linsolve_builtin(
        Value::ComplexTensor(coefficients.clone()),
        Value::ComplexTensor(right_hand_side.clone()),
        Vec::new(),
    )
    .expect("linsolve");
    let solution = match solved {
        Value::ComplexTensor(tensor) => tensor,
        _ => panic!("expected complex tensor result"),
    };

    // Re-pack a tensor's (re, im) pairs as a nalgebra matrix of the same dimensions.
    let as_matrix = |tensor: &ComplexTensor| {
        let entries: Vec<Complex64> = tensor
            .data
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        DMatrix::from_column_slice(tensor.rows, tensor.cols, &entries)
    };

    let residual =
        as_matrix(&coefficients) * as_matrix(&solution) - as_matrix(&right_hand_side);
    assert!(residual.norm() < 1e-10, "residual={}", residual.norm());
}
1324
/// Solving with `TRANSA = "C"` must match solving against an explicitly
/// transposed-and-conjugated coefficient matrix.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_complex_conjugate_transpose_matches_explicit_reference() {
    let coefficients = ComplexTensor::new(
        vec![(2.0, 1.0), (0.0, -1.0), (1.0, 2.0), (3.0, 0.5)],
        vec![2, 2],
    )
    .unwrap();
    let right_hand_side =
        ComplexTensor::new(vec![(1.0, -1.0), (2.0, 0.5)], vec![2, 1]).unwrap();

    let mut opts = StructValue::new();
    let conjugate_flag = Value::CharArray(CharArray::new_row("C"));
    opts.fields.insert("TRANSA".to_string(), conjugate_flag);

    let solved = linsolve_builtin(
        Value::ComplexTensor(coefficients.clone()),
        Value::ComplexTensor(right_hand_side.clone()),
        vec![Value::Struct(opts)],
    )
    .expect("linsolve");
    let out = match solved {
        Value::ComplexTensor(tensor) => tensor,
        _ => panic!("expected complex tensor result"),
    };

    // Build the conjugate transpose by hand and solve it with default options.
    let mut conjugate_transposed = transpose_complex(&coefficients);
    conjugate_complex_in_place(&mut conjugate_transposed);
    let reference = evaluate(
        Value::ComplexTensor(conjugate_transposed),
        Value::ComplexTensor(right_hand_side.clone()),
        SolveOptions::default(),
    )
    .expect("reference");
    let Value::ComplexTensor(expected) = reference.solution() else {
        panic!("expected complex tensor reference");
    };

    assert_eq!(out.shape, expected.shape);
    let paired = out.data.iter().zip(expected.data.iter());
    for (&(out_re, out_im), &(exp_re, exp_im)) in paired {
        let re_gap = (out_re - exp_re).abs();
        let im_gap = (out_im - exp_im).abs();
        assert!(re_gap < 1e-10, "out_re={out_re} exp_re={exp_re}");
        assert!(im_gap < 1e-10, "out_im={out_im} exp_im={exp_im}");
    }
}
1374
/// A nearly singular matrix must be rejected when `RCOND` is set to `1e-3`.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_rcond_enforced() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();

    let near_singular = Tensor::new(vec![1.0, 1.0, 1.0, 1.0 + 1e-12], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![2.0, 2.0 + 1e-12], vec![2, 1]).unwrap();

    let mut opts = StructValue::new();
    opts.fields.insert("RCOND".to_string(), Value::Num(1e-3));

    let outcome = linsolve_builtin(
        Value::Tensor(near_singular),
        Value::Tensor(right_hand_side),
        vec![Value::Struct(opts)],
    );
    let err = unwrap_error(outcome.expect_err("singular matrix must fail"));

    let mentions_singularity = err.message().contains("singular to working precision");
    assert!(mentions_singularity, "unexpected error message: {err}");
    assert_eq!(err.identifier(), LINSOLVE_ERROR_INVALID_INPUT.identifier);
}
1398
/// An unrecognised options field must fail with the invalid-argument identifier.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_unknown_option_identifier() {
    let identity = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();

    let mut opts = StructValue::new();
    opts.fields.insert("UNKNOWN".to_string(), Value::Bool(true));

    let outcome = linsolve_builtin(
        Value::Tensor(identity),
        Value::Tensor(right_hand_side),
        vec![Value::Struct(opts)],
    );
    let err = unwrap_error(outcome.expect_err("unknown option should fail"));

    assert!(err.message().contains("unknown option"));
    assert_eq!(err.identifier(), LINSOLVE_ERROR_INVALID_ARGUMENT.identifier);
}
1417
/// Requesting three outputs must fail with the invalid-argument identifier.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_output_count_limit_identifier() {
    let coefficient = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
    let right_hand_side = Tensor::new(vec![2.0], vec![1, 1]).unwrap();

    // Keep the guard alive for the duration of the call below.
    let _guard = crate::output_count::push_output_count(Some(3));
    let outcome = linsolve_builtin(
        Value::Tensor(coefficient),
        Value::Tensor(right_hand_side),
        Vec::new(),
    );
    let err = unwrap_error(outcome.expect_err("three outputs should fail"));

    assert!(err.message().contains("at most two outputs"));
    assert_eq!(err.identifier(), LINSOLVE_ERROR_INVALID_ARGUMENT.identifier);
}
1431
/// Evaluating against the identity must return the right-hand side as the solution
/// and a reciprocal condition estimate of one.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn linsolve_recovers_rcond_output() {
    let _accel_guard = test_support::accel_test_lock();
    clear_accel_provider_state();

    let identity = Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap();
    let eval = evaluate_args(
        Value::Tensor(identity.clone()),
        Value::Tensor(right_hand_side.clone()),
        &[],
    )
    .expect("evaluate");

    // The solution may come back as a host tensor or as a GPU handle to gather.
    let solution = match eval.solution() {
        Value::Tensor(sol) => sol.clone(),
        Value::GpuTensor(handle) => {
            test_support::gather(Value::GpuTensor(handle.clone())).expect("gather solution")
        }
        other => panic!("unexpected solution value {other:?}"),
    };
    assert_eq!(solution.shape, vec![2, 1]);
    approx_eq(solution.data[0], 1.0);
    approx_eq(solution.data[1], 2.0);

    // Likewise the estimate is either a plain number or a GPU tensor holding it.
    let estimate = match eval.reciprocal_condition() {
        Value::Num(r) => r,
        Value::GpuTensor(handle) => {
            let gathered =
                test_support::gather(Value::GpuTensor(handle.clone())).expect("gather rcond");
            gathered.data[0]
        }
        other => panic!("unexpected rcond value {other:?}"),
    };
    approx_eq(estimate, 1.0);
}
1463
/// Solving with uploaded (GPU-handle) inputs must agree with solving the same
/// system from host tensors, to within `1e-12` per element.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn gpu_round_trip_matches_cpu() {
    test_support::with_test_provider(|provider| {
        let a = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let b = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

        let cpu = linsolve_builtin(
            Value::Tensor(a.clone()),
            Value::Tensor(b.clone()),
            Vec::new(),
        )
        .expect("cpu linsolve");
        let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

        let view_a = HostTensorView {
            data: &a.data,
            shape: &a.shape,
        };
        let view_b = HostTensorView {
            data: &b.data,
            shape: &b.shape,
        };
        let ha = provider.upload(&view_a).expect("upload A");
        let hb = provider.upload(&view_b).expect("upload B");

        let gpu_value = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");
        let gathered = test_support::gather(gpu_value).expect("gather");
        // Release the uploaded handles before asserting; free errors are ignored.
        let _ = provider.free(&ha);
        let _ = provider.free(&hb);

        assert_eq!(gathered.shape, cpu_tensor.shape);
        for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
            // Report both values on mismatch, as the other GPU/CPU comparisons do.
            assert!((gpu - cpu).abs() < 1e-12, "gpu={gpu} cpu={cpu}");
        }
    });
}
1506
/// With the test provider active, a solve on host tensors must be recorded by the
/// provider: a linsolve call, a `linsolve:host_reupload` fallback, and non-zero
/// upload and download byte counts.
#[test]
fn host_inputs_auto_promote_into_provider_solve_path() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();

        let coefficients = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let right_hand_side = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();
        let _ = linsolve_builtin(
            Value::Tensor(coefficients),
            Value::Tensor(right_hand_side),
            Vec::new(),
        )
        .expect("host linsolve");

        let snapshot = provider.telemetry_snapshot();
        let reuploads = fallback_count(&snapshot, "linsolve:host_reupload");
        assert!(snapshot.linsolve.count >= 1);
        assert!(reuploads >= 1);
        assert!(snapshot.upload_bytes > 0);
        assert!(snapshot.download_bytes > 0);
    });
}
1522
/// With the test provider, a solve on uploaded handles must record exactly one
/// linsolve call and one `linsolve:host_reupload` fallback, plus byte traffic
/// in both directions.
#[test]
fn provider_telemetry_records_gpu_host_reupload_path() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();

        let coefficients = Tensor::new(vec![2.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
        let right_hand_side = Tensor::new(vec![4.0, 5.0], vec![2, 1]).unwrap();

        let coefficients_view = HostTensorView {
            data: &coefficients.data,
            shape: &coefficients.shape,
        };
        let ha = provider.upload(&coefficients_view).expect("upload A");
        let rhs_view = HostTensorView {
            data: &right_hand_side.data,
            shape: &right_hand_side.shape,
        };
        let hb = provider.upload(&rhs_view).expect("upload B");

        let _ = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("gpu linsolve");

        let snapshot = provider.telemetry_snapshot();
        assert_eq!(snapshot.linsolve.count, 1);
        assert!(snapshot.upload_bytes > 0);
        assert!(snapshot.download_bytes > 0);
        assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 1);

        let _ = provider.free(&ha);
        let _ = provider.free(&hb);
    });
}
1559
/// A 1x1 solve on uploaded handles must produce `3.0` without the provider
/// recording a linsolve call or a `linsolve:host_reupload` fallback.
#[test]
fn scalar_gpu_inputs_fall_back_without_provider_solve_dispatch() {
    test_support::with_test_provider(|provider| {
        provider.reset_telemetry();

        let coefficient = Tensor::new(vec![2.0], vec![1, 1]).unwrap();
        let right_hand_side = Tensor::new(vec![6.0], vec![1, 1]).unwrap();

        let coefficient_view = HostTensorView {
            data: &coefficient.data,
            shape: &coefficient.shape,
        };
        let ha = provider.upload(&coefficient_view).expect("upload A");
        let rhs_view = HostTensorView {
            data: &right_hand_side.data,
            shape: &right_hand_side.shape,
        };
        let hb = provider.upload(&rhs_view).expect("upload B");

        let solved = linsolve_builtin(
            Value::GpuTensor(ha.clone()),
            Value::GpuTensor(hb.clone()),
            Vec::new(),
        )
        .expect("fallback linsolve");
        let solution = test_support::gather(solved).expect("gather fallback");
        assert_eq!(solution.data, vec![3.0]);

        let snapshot = provider.telemetry_snapshot();
        assert_eq!(snapshot.linsolve.count, 0);
        assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 0);
        assert!(snapshot.download_bytes > 0);

        let _ = provider.free(&ha);
        let _ = provider.free(&hb);
    });
}
1597
/// With the wgpu provider in F32 precision, a square solve on uploaded handles
/// (one requested output) must match the reference solve to `1e-4`, record one
/// linsolve call, no `linsolve:host_reupload` fallback, no `linsolve_posdef_chol`
/// launch, and exactly one `linsolve_tall_qr` launch.
///
/// The test returns early when the provider does not report F32 precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    if provider.precision() != runmat_accelerate_api::ProviderPrecision::F32 {
        return;
    }
    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    // Discard telemetry from the reference solve so only the GPU call is counted.
    provider.reset_telemetry();

    let ha = provider
        .upload(&HostTensorView {
            data: &a.data,
            shape: &a.shape,
        })
        .expect("upload A");
    let hb = provider
        .upload(&HostTensorView {
            data: &b.data,
            shape: &b.shape,
        })
        .expect("upload B");

    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        // Report both values on mismatch, as the sibling wgpu tests do.
        assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
1661
/// With the wgpu provider in F32 precision and no explicit output count, a square
/// solve on uploaded handles must match the reference solve to `1e-4`, record one
/// linsolve call, no `linsolve:host_reupload` fallback, and one `linsolve_tall_qr` launch.
///
/// The test returns early when the provider does not report F32 precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_uses_device_path_without_output_count() {
    let _accel_guard = test_support::accel_test_lock();
    let wgpu_options = runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(wgpu_options);
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let is_f32 = provider.precision() == runmat_accelerate_api::ProviderPrecision::F32;
    if !is_f32 {
        return;
    }

    let coefficients = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    let reference_value = linsolve_builtin(
        Value::Tensor(coefficients.clone()),
        Value::Tensor(right_hand_side.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let reference = test_support::gather(reference_value).expect("cpu gather");
    // Discard telemetry from the reference solve so only the GPU call is counted.
    provider.reset_telemetry();

    let coefficients_view = HostTensorView {
        data: &coefficients.data,
        shape: &coefficients.shape,
    };
    let ha = provider.upload(&coefficients_view).expect("upload A");
    let rhs_view = HostTensorView {
        data: &right_hand_side.data,
        shape: &right_hand_side.shape,
    };
    let hb = provider.upload(&rhs_view).expect("upload B");

    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let gathered = test_support::gather(device_value).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, reference.shape);
    for (gpu, cpu) in gathered.data.iter().zip(reference.data.iter()) {
        let gap = (gpu - cpu).abs();
        assert!(gap < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let snapshot = provider.telemetry_snapshot();
    assert_eq!(snapshot.linsolve.count, 1);
    assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&snapshot, "linsolve_tall_qr"), 1);
}
1719
/// With the wgpu provider in F32 precision and two requested outputs, a square solve
/// on uploaded handles must return a 2x1 solution plus a numeric reciprocal condition
/// estimate within `1e-4` of the host estimate, without a `linsolve:host_reupload` fallback.
///
/// The test returns early when the provider does not report F32 precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_recovers_rcond_output_on_device() {
    let _accel_guard = test_support::accel_test_lock();
    let wgpu_options = runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(wgpu_options);
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let is_f32 = provider.precision() == runmat_accelerate_api::ProviderPrecision::F32;
    if !is_f32 {
        return;
    }

    let coefficients = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    let (_, cpu_rcond) = host_linsolve_real(
        &coefficients,
        &right_hand_side,
        ProviderLinsolveOptions::default(),
    );
    provider.reset_telemetry();

    let coefficients_view = HostTensorView {
        data: &coefficients.data,
        shape: &coefficients.shape,
    };
    let ha = provider.upload(&coefficients_view).expect("upload A");
    let rhs_view = HostTensorView {
        data: &right_hand_side.data,
        shape: &right_hand_side.shape,
    };
    let hb = provider.upload(&rhs_view).expect("upload B");

    let _output_guard = crate::output_count::push_output_count(Some(2));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu square linsolve");
    let outputs = match device_value {
        Value::OutputList(outputs) => outputs,
        other => panic!("expected output list, got {other:?}"),
    };
    assert_eq!(outputs.len(), 2);

    let gathered = test_support::gather(outputs[0].clone()).expect("gather");
    let gpu_rcond = match &outputs[1] {
        Value::Num(value) => *value,
        other => panic!("unexpected gpu rcond {other:?}"),
    };
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, vec![2, 1]);
    let rcond_gap = (gpu_rcond - cpu_rcond).abs();
    assert!(rcond_gap < 1e-4, "gpu={gpu_rcond} cpu={cpu_rcond}");

    let snapshot = provider.telemetry_snapshot();
    assert_eq!(snapshot.linsolve.count, 1);
    assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&snapshot, "linsolve_tall_qr"), 1);
}
1782
/// With the wgpu provider in F32 precision and `RCOND = 0.05`, a square solve on
/// uploaded handles must match the reference solve to `1e-4`, record one linsolve
/// call, no `linsolve:host_reupload` fallback, and one `linsolve_tall_qr` launch.
///
/// The test returns early when the provider does not report F32 precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_square_linsolve_with_rcond_option_stays_on_device() {
    let _accel_guard = test_support::accel_test_lock();
    let wgpu_options = runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(wgpu_options);
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let is_f32 = provider.precision() == runmat_accelerate_api::ProviderPrecision::F32;
    if !is_f32 {
        return;
    }

    let coefficients = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    let mut reference_opts = StructValue::new();
    reference_opts
        .fields
        .insert("RCOND".to_string(), Value::Num(0.05));
    let reference_value = linsolve_builtin(
        Value::Tensor(coefficients.clone()),
        Value::Tensor(right_hand_side.clone()),
        vec![Value::Struct(reference_opts)],
    )
    .expect("cpu linsolve");
    let reference = test_support::gather(reference_value).expect("cpu gather");
    // Discard telemetry from the reference solve so only the GPU call is counted.
    provider.reset_telemetry();

    let coefficients_view = HostTensorView {
        data: &coefficients.data,
        shape: &coefficients.shape,
    };
    let ha = provider.upload(&coefficients_view).expect("upload A");
    let rhs_view = HostTensorView {
        data: &right_hand_side.data,
        shape: &right_hand_side.shape,
    };
    let hb = provider.upload(&rhs_view).expect("upload B");

    let _output_guard = crate::output_count::push_output_count(Some(1));
    let mut device_opts = StructValue::new();
    device_opts
        .fields
        .insert("RCOND".to_string(), Value::Num(0.05));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(device_opts)],
    )
    .expect("gpu square linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let device_solution = match device_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(device_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, reference.shape);
    for (gpu, cpu) in gathered.data.iter().zip(reference.data.iter()) {
        let gap = (gpu - cpu).abs();
        assert!(gap < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let snapshot = provider.telemetry_snapshot();
    assert_eq!(snapshot.linsolve.count, 1);
    assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&snapshot, "linsolve_tall_qr"), 1);
}
1853
/// With the wgpu provider in F32 precision, a 3x2 (tall) solve on uploaded handles
/// must match the reference solve to `1e-4`, record one linsolve call, and no
/// `linsolve:host_reupload` fallback.
///
/// The test returns early when the provider does not report F32 precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_tall_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let wgpu_options = runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(wgpu_options);
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let is_f32 = provider.precision() == runmat_accelerate_api::ProviderPrecision::F32;
    if !is_f32 {
        return;
    }

    let coefficients =
        Tensor::new(vec![1.0, 0.0, 1.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();

    let reference_value = linsolve_builtin(
        Value::Tensor(coefficients.clone()),
        Value::Tensor(right_hand_side.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let reference = test_support::gather(reference_value).expect("cpu gather");
    // Discard telemetry from the reference solve so only the GPU call is counted.
    provider.reset_telemetry();

    let coefficients_view = HostTensorView {
        data: &coefficients.data,
        shape: &coefficients.shape,
    };
    let ha = provider.upload(&coefficients_view).expect("upload A");
    let rhs_view = HostTensorView {
        data: &right_hand_side.data,
        shape: &right_hand_side.shape,
    };
    let hb = provider.upload(&rhs_view).expect("upload B");

    let _output_guard = crate::output_count::push_output_count(Some(1));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu tall linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let device_solution = match device_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(device_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, reference.shape);
    for (gpu, cpu) in gathered.data.iter().zip(reference.data.iter()) {
        let gap = (gpu - cpu).abs();
        assert!(gap < 1e-4, "gpu={gpu} cpu={cpu}");
    }

    let snapshot = provider.telemetry_snapshot();
    assert_eq!(snapshot.linsolve.count, 1);
    assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 0);
}
1915
/// With the wgpu provider in F32 precision and `POSDEF = true`, a solve on uploaded
/// handles (two requested outputs) must match the reference solution and host
/// reciprocal condition estimate to `1e-4`, record one linsolve call, no
/// `linsolve:host_reupload` fallback, one `linsolve_posdef_chol` launch, and no
/// `linsolve_tall_qr` launch.
///
/// The test returns early when the provider does not report F32 precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_posdef_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    let wgpu_options = runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default();
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(wgpu_options);
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let is_f32 = provider.precision() == runmat_accelerate_api::ProviderPrecision::F32;
    if !is_f32 {
        return;
    }

    let coefficients = Tensor::new(vec![4.0, 1.0, 1.0, 3.0], vec![2, 2]).unwrap();
    let right_hand_side = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    let mut reference_opts = StructValue::new();
    reference_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    let reference_value = linsolve_builtin(
        Value::Tensor(coefficients.clone()),
        Value::Tensor(right_hand_side.clone()),
        vec![Value::Struct(reference_opts)],
    )
    .expect("cpu linsolve");
    let reference = test_support::gather(reference_value).expect("cpu gather");

    let host_options = ProviderLinsolveOptions {
        posdef: true,
        ..Default::default()
    };
    let (_, cpu_rcond) = host_linsolve_real(&coefficients, &right_hand_side, host_options);
    // Discard telemetry from the reference solves so only the GPU call is counted.
    provider.reset_telemetry();

    let coefficients_view = HostTensorView {
        data: &coefficients.data,
        shape: &coefficients.shape,
    };
    let ha = provider.upload(&coefficients_view).expect("upload A");
    let rhs_view = HostTensorView {
        data: &right_hand_side.data,
        shape: &right_hand_side.shape,
    };
    let hb = provider.upload(&rhs_view).expect("upload B");

    let _output_guard = crate::output_count::push_output_count(Some(2));
    let mut device_opts = StructValue::new();
    device_opts
        .fields
        .insert("POSDEF".to_string(), Value::Bool(true));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        vec![Value::Struct(device_opts)],
    )
    .expect("gpu posdef linsolve");
    let mut outputs = match device_value {
        Value::OutputList(outputs) => outputs,
        other => panic!("expected output list, got {other:?}"),
    };
    // Take the estimate (index 1) first so the solution is still at index 0.
    let gpu_rcond = match outputs.remove(1) {
        Value::Num(value) => value,
        other => panic!("unexpected rcond value {other:?}"),
    };
    let device_solution = outputs.remove(0);
    let gathered = test_support::gather(device_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, reference.shape);
    for (gpu, cpu) in gathered.data.iter().zip(reference.data.iter()) {
        let gap = (gpu - cpu).abs();
        assert!(gap < 1e-4, "gpu={gpu} cpu={cpu}");
    }
    let rcond_gap = (gpu_rcond - cpu_rcond).abs();
    assert!(rcond_gap < 1e-4, "gpu={gpu_rcond} cpu={cpu_rcond}");

    let snapshot = provider.telemetry_snapshot();
    assert_eq!(snapshot.linsolve.count, 1);
    assert_eq!(fallback_count(&snapshot, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&snapshot, "linsolve_posdef_chol"), 1);
    assert_eq!(kernel_launch_count(&snapshot, "linsolve_tall_qr"), 0);
}
2004
/// Runs `linsolve` with `POSDEF = true` and `TRANSA = "T"` once on host tensors and once
/// on uploaded GPU tensors, then checks that both solutions agree to within `1e-4`.
///
/// Also checks the provider telemetry recorded after the reset: exactly one `linsolve`
/// call, one `linsolve_posdef_chol` kernel launch, no `linsolve_tall_qr` launch and no
/// `linsolve:host_reupload` fallback. Returns early unless the provider precision is F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_posdef_linsolve_uses_cholesky_path() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    if !matches!(
        provider.precision(),
        runmat_accelerate_api::ProviderPrecision::F32
    ) {
        return;
    }

    let a = Tensor::new(vec![6.0, 2.0, 2.0, 5.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![8.0, 9.0], vec![2, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert("POSDEF".to_string(), Value::Bool(true));
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("T")),
        );
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let host_value = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(host_value).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu transposed posdef linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match device_value {
        Value::OutputList(mut list) => list.remove(0),
        single => single,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    gathered
        .data
        .iter()
        .zip(cpu_tensor.data.iter())
        .for_each(|(gpu, cpu)| assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}"));

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_posdef_chol"), 1);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 0);
}
2084
/// Runs `linsolve` with `SYM = true` once on host tensors and once on uploaded GPU
/// tensors, then checks that both solutions agree to within `1e-4`.
///
/// Also checks that the telemetry recorded after the reset shows exactly one `linsolve`
/// call and no `linsolve:host_reupload` fallback. Returns early unless the provider
/// precision is F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_symmetric_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    if !matches!(
        provider.precision(),
        runmat_accelerate_api::ProviderPrecision::F32
    ) {
        return;
    }

    let a = Tensor::new(vec![5.0, 2.0, 2.0, 6.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![9.0, 8.0], vec![2, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert("SYM".to_string(), Value::Bool(true));
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let host_value = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(host_value).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu symmetric linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match device_value {
        Value::OutputList(mut list) => list.remove(0),
        single => single,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    gathered
        .data
        .iter()
        .zip(cpu_tensor.data.iter())
        .for_each(|(gpu, cpu)| assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}"));

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
2150
/// Runs `linsolve` with `TRANSA = "T"` on a 2x2 system once on host tensors and once on
/// uploaded GPU tensors, then checks that both solutions agree to within `1e-4`.
///
/// Also checks that the telemetry recorded after the reset shows exactly one `linsolve`
/// call and no `linsolve:host_reupload` fallback. Returns early unless the provider
/// precision is F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_square_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    if !matches!(
        provider.precision(),
        runmat_accelerate_api::ProviderPrecision::F32
    ) {
        return;
    }

    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("T")),
        );
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let host_value = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(host_value).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu transposed square linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match device_value {
        Value::OutputList(mut list) => list.remove(0),
        single => single,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    gathered
        .data
        .iter()
        .zip(cpu_tensor.data.iter())
        .for_each(|(gpu, cpu)| assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}"));

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
2222
/// Runs `linsolve` with `TRANSA = "C"` on real 2x2 inputs once on host tensors and once
/// on uploaded GPU tensors, then checks that both solutions agree to within `1e-4`.
///
/// Also checks that the telemetry recorded after the reset shows exactly one `linsolve`
/// call, no `linsolve:host_reupload` fallback and one `linsolve_tall_qr` kernel launch.
/// Returns early unless the provider precision is F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_conjugate_square_linsolve_avoids_host_reupload_fallback_for_real_inputs() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    if !matches!(
        provider.precision(),
        runmat_accelerate_api::ProviderPrecision::F32
    ) {
        return;
    }

    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![5.0, 14.0], vec![2, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("C")),
        );
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let host_value = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(host_value).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu conjugate square linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match device_value {
        Value::OutputList(mut list) => list.remove(0),
        single => single,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    gathered
        .data
        .iter()
        .zip(cpu_tensor.data.iter())
        .for_each(|(gpu, cpu)| assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}"));

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
    assert_eq!(kernel_launch_count(&telemetry, "linsolve_tall_qr"), 1);
}
2295
/// Runs `linsolve` with `TRANSA = "T"` and `RECT = true` on a 2x3 coefficient tensor and
/// a 3x1 right-hand side, once on host tensors and once on uploaded GPU tensors, then
/// checks that both solutions agree to within `1e-4`.
///
/// Also checks that the telemetry recorded after the reset shows exactly one `linsolve`
/// call and no `linsolve:host_reupload` fallback. Returns early unless the provider
/// precision is F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_rectangular_linsolve_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    if !matches!(
        provider.precision(),
        runmat_accelerate_api::ProviderPrecision::F32
    ) {
        return;
    }

    let a = Tensor::new(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![2, 3]).unwrap();
    let b = Tensor::new(vec![1.0, 2.0, 2.0], vec![3, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("T")),
        );
        opts.fields.insert("RECT".to_string(), Value::Bool(true));
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let host_value = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(host_value).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let device_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu transposed rectangular linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match device_value {
        Value::OutputList(mut list) => list.remove(0),
        single => single,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    gathered
        .data
        .iter()
        .zip(cpu_tensor.data.iter())
        .for_each(|(gpu, cpu)| assert!((gpu - cpu).abs() < 1e-4, "gpu={gpu} cpu={cpu}"));

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
2373
/// Runs `linsolve` with `LT = true` on a 3x3 system once on host tensors and once on
/// uploaded GPU tensors, then checks that both solutions agree to within `1e-5`.
///
/// Also checks that the telemetry recorded after the reset shows exactly one `linsolve`
/// call and no `linsolve:host_reupload` fallback. Unlike the neighbouring tests, this
/// one does not return early based on provider precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_triangular_hint_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");

    let a = Tensor::new(
        vec![3.0, -1.0, 4.0, 0.0, 2.0, 1.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![9.0, 1.0, 19.0], vec![3, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert("LT".to_string(), Value::Bool(true));
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu triangular linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        // Report both values on failure, matching the sibling wgpu tests.
        assert!((gpu - cpu).abs() < 1e-5, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
2438
/// Runs `linsolve` with `LT = true` and `TRANSA = "T"` on a 3x3 system once on host
/// tensors and once on uploaded GPU tensors, then checks that both solutions agree to
/// within `1e-5`.
///
/// Also checks that the telemetry recorded after the reset shows exactly one `linsolve`
/// call and no `linsolve:host_reupload` fallback. Unlike the neighbouring tests, this
/// one does not return early based on provider precision.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_transposed_triangular_hint_avoids_host_reupload_fallback() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");

    let a = Tensor::new(
        vec![3.0, 1.0, 0.0, 0.0, 4.0, 2.0, 0.0, 0.0, 5.0],
        vec![3, 3],
    )
    .unwrap();
    let b = Tensor::new(vec![5.0, 14.0, 23.0], vec![3, 1]).unwrap();

    // The host and device runs take identical options, so build them in one place.
    let solve_options = || {
        let mut opts = StructValue::new();
        opts.fields.insert("LT".to_string(), Value::Bool(true));
        opts.fields.insert(
            "TRANSA".to_string(),
            Value::CharArray(CharArray::new_row("T")),
        );
        vec![Value::Struct(opts)]
    };
    let to_device = |tensor: &Tensor, what: &str| {
        provider
            .upload(&HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            })
            .expect(what)
    };

    // Host reference run; telemetry is reset afterwards so the counters checked at the
    // end start from this point.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        solve_options(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");
    provider.reset_telemetry();

    let ha = to_device(&a, "upload A");
    let hb = to_device(&b, "upload B");

    // Request a single output for the device run only.
    let _output_guard = crate::output_count::push_output_count(Some(1));
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        solve_options(),
    )
    .expect("gpu transposed triangular linsolve");
    // Accept either a bare value or an output list whose first entry is the solution.
    let gpu_solution = match gpu_value {
        Value::OutputList(mut outputs) => outputs.remove(0),
        other => other,
    };
    let gathered = test_support::gather(gpu_solution).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        // Report both values on failure, matching the sibling wgpu tests.
        assert!((gpu - cpu).abs() < 1e-5, "gpu={gpu} cpu={cpu}");
    }

    let telemetry = provider.telemetry_snapshot();
    assert_eq!(telemetry.linsolve.count, 1);
    assert_eq!(fallback_count(&telemetry, "linsolve:host_reupload"), 0);
}
2513
/// Runs `linsolve` without options on a 2x2 system once on host tensors and once on
/// uploaded GPU tensors, then checks that both solutions have the same shape and agree
/// element-wise.
///
/// The comparison tolerance depends on the provider precision: `1e-12` for F64 and
/// `1e-5` for F32.
#[cfg(feature = "wgpu")]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn wgpu_round_trip_matches_cpu() {
    let _accel_guard = test_support::accel_test_lock();
    // The registration result is discarded; only the provider lookup below must succeed.
    let _ = runmat_accelerate::backend::wgpu::provider::register_wgpu_provider(
        runmat_accelerate::backend::wgpu::provider::WgpuProviderOptions::default(),
    );
    let provider = runmat_accelerate_api::provider().expect("wgpu provider");
    let tol = match provider.precision() {
        runmat_accelerate_api::ProviderPrecision::F64 => 1e-12,
        runmat_accelerate_api::ProviderPrecision::F32 => 1e-5,
    };

    let a = Tensor::new(vec![3.0, 1.0, 2.0, 4.0], vec![2, 2]).unwrap();
    let b = Tensor::new(vec![7.0, 8.0], vec![2, 1]).unwrap();

    // Host reference solution.
    let cpu = linsolve_builtin(
        Value::Tensor(a.clone()),
        Value::Tensor(b.clone()),
        Vec::new(),
    )
    .expect("cpu linsolve");
    let cpu_tensor = test_support::gather(cpu).expect("cpu gather");

    // Upload both operands and solve again through the provider handles.
    let view_a = HostTensorView {
        data: &a.data,
        shape: &a.shape,
    };
    let view_b = HostTensorView {
        data: &b.data,
        shape: &b.shape,
    };
    let ha = provider.upload(&view_a).expect("upload A");
    let hb = provider.upload(&view_b).expect("upload B");
    let gpu_value = linsolve_builtin(
        Value::GpuTensor(ha.clone()),
        Value::GpuTensor(hb.clone()),
        Vec::new(),
    )
    .expect("gpu linsolve");
    let gathered = test_support::gather(gpu_value).expect("gather");
    let _ = provider.free(&ha);
    let _ = provider.free(&hb);

    assert_eq!(gathered.shape, cpu_tensor.shape);
    for (gpu, cpu) in gathered.data.iter().zip(cpu_tensor.data.iter()) {
        // Report the values and the precision-dependent tolerance on failure.
        assert!((gpu - cpu).abs() < tol, "gpu={gpu} cpu={cpu} tol={tol}");
    }
}
2564}