1use std::collections::HashMap;
4use std::sync::OnceLock;
5
6use runmat_builtins::{
7 Access, BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9 CellArray, CharArray, ClassDef, MethodDef, ObjectInstance, PropertyDef, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::control::type_resolvers::ss_type;
18use crate::{build_runtime_error, dispatcher, BuiltinResult, RuntimeError};
19
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "ss";
/// Class name under which state-space objects are registered and instantiated.
const SS_CLASS: &str = "ss";

/// One-shot guard so the `ss` class definition is registered at most once
/// (initialised by `ensure_ss_class_registered`).
static SS_CLASS_REGISTERED: OnceLock<()> = OnceLock::new();
24
/// Single output shared by every `ss` signature: the constructed model object.
const SS_OUTPUT_SYS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "sys",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "State-space model object.",
}];
/// `A` (state) matrix descriptor, reused by every signature.
const SS_PARAM_A: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "State matrix with shape n-by-n.",
};
/// `B` (input) matrix descriptor, reused by every signature.
const SS_PARAM_B: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input matrix with shape n-by-nu.",
};
/// `C` (output) matrix descriptor, reused by every signature.
const SS_PARAM_C: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "C",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Output matrix with shape ny-by-n.",
};
/// `D` (feedthrough) matrix descriptor, reused by every signature.
const SS_PARAM_D: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "D",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Feedthrough matrix with shape ny-by-nu.",
};
/// Inputs for the plain `ss(A, B, C, D)` form.
const SS_INPUTS_ABCD: [BuiltinParamDescriptor; 4] =
    [SS_PARAM_A, SS_PARAM_B, SS_PARAM_C, SS_PARAM_D];
/// Inputs for the positional sample-time form `ss(A, B, C, D, Ts)`.
const SS_INPUTS_ABCD_TS: [BuiltinParamDescriptor; 5] = [
    SS_PARAM_A,
    SS_PARAM_B,
    SS_PARAM_C,
    SS_PARAM_D,
    BuiltinParamDescriptor {
        name: "Ts",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Optional,
        default: Some("0.0"),
        description: "Sample time (0 for continuous-time model).",
    },
];
/// Inputs for the name/value forms; the trailing pair is variadic.
const SS_INPUTS_ABCD_NAMEVALUE: [BuiltinParamDescriptor; 6] = [
    SS_PARAM_A,
    SS_PARAM_B,
    SS_PARAM_C,
    SS_PARAM_D,
    BuiltinParamDescriptor {
        name: "name",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option name ('Ts' or 'SampleTime').",
    },
    BuiltinParamDescriptor {
        name: "value",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option value.",
    },
];
/// Call forms advertised for `ss`; the last two share the name/value input list.
const SS_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D)",
        inputs: &SS_INPUTS_ABCD,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, Ts)",
        inputs: &SS_INPUTS_ABCD_TS,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, \"Ts\", Ts)",
        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, name, value, ...)",
        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
        outputs: &SS_OUTPUT_SYS,
    },
];
/// Malformed argument list (e.g. an unpaired option or a non-text option name).
const SS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_ARGUMENT",
    identifier: Some("RunMat:ss:InvalidArgument"),
    when: "Arguments do not match supported ss invocation forms.",
    message: "ss: invalid argument",
};
/// Unrecognised name/value option name.
const SS_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_OPTION",
    identifier: Some("RunMat:ss:InvalidOption"),
    when: "A name/value option token is unsupported or malformed.",
    message: "ss: invalid option",
};
/// Sample time that is not a finite, non-negative scalar.
const SS_ERROR_INVALID_SAMPLE_TIME: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_SAMPLE_TIME",
    identifier: Some("RunMat:ss:InvalidSampleTime"),
    when: "Sample time is not a finite non-negative scalar.",
    message: "ss: sample time must be a finite non-negative scalar",
};
/// Matrix shapes that are not 2-D or do not agree with each other.
const SS_ERROR_INVALID_DIMENSIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_DIMENSIONS",
    identifier: Some("RunMat:ss:InvalidDimensions"),
    when: "A, B, C, and D dimensions do not define a consistent state-space model.",
    message: "ss: invalid state-space matrix dimensions",
};
/// Input value kinds this constructor does not accept (complex, non-finite, non-numeric, ...).
const SS_ERROR_UNSUPPORTED_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.UNSUPPORTED_INPUT",
    identifier: Some("RunMat:ss:UnsupportedInput"),
    when: "An input is complex, sparse, logical, or another unsupported model form.",
    message: "ss: unsupported input",
};
/// Failures while building tensors/cells for the object's properties.
const SS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INTERNAL",
    identifier: Some("RunMat:ss:Internal"),
    when: "Internal tensor/object construction failed.",
    message: "ss: internal error",
};
/// Every error `ss` can raise, published through `SS_DESCRIPTOR`.
const SS_ERRORS: [BuiltinErrorDescriptor; 6] = [
    SS_ERROR_INVALID_ARGUMENT,
    SS_ERROR_INVALID_OPTION,
    SS_ERROR_INVALID_SAMPLE_TIME,
    SS_ERROR_INVALID_DIMENSIONS,
    SS_ERROR_UNSUPPORTED_INPUT,
    SS_ERROR_INTERNAL,
];
/// Public descriptor for `ss`: its signatures, fixed single output, and error catalogue.
pub const SS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SS_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SS_ERRORS,
};
167
/// GPU metadata for `ss`: no device kernels; construction happens on the host,
/// so any device-resident inputs are gathered immediately.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::ss")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ss",
    op_kind: GpuOpKind::Custom("state-space-model-constructor"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Object construction runs on the host. gpuArray matrix inputs are gathered before validating and storing the state-space metadata.",
};
183
/// Fusion metadata for `ss`: neither elementwise nor a reduction, so it acts
/// as a boundary for fused numeric chains.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::ss")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ss",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "State-space construction is metadata-only and terminates numeric fusion chains.",
};
194
195fn ss_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
196 ss_error_with_message(error.message, error)
197}
198
199fn ss_error_with_detail(
200 error: &'static BuiltinErrorDescriptor,
201 detail: impl AsRef<str>,
202) -> RuntimeError {
203 ss_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
204}
205
206fn ss_error_with_message(
207 message: impl Into<String>,
208 error: &'static BuiltinErrorDescriptor,
209) -> RuntimeError {
210 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
211 if let Some(identifier) = error.identifier {
212 builder = builder.with_identifier(identifier);
213 }
214 builder.build()
215}
216
217fn ensure_ss_class_registered() {
218 SS_CLASS_REGISTERED.get_or_init(|| {
219 let mut properties = HashMap::new();
220 for name in [
221 "A",
222 "B",
223 "C",
224 "D",
225 "Ts",
226 "InputDelay",
227 "OutputDelay",
228 "StateName",
229 "InputName",
230 "OutputName",
231 ] {
232 properties.insert(
233 name.to_string(),
234 PropertyDef {
235 name: name.to_string(),
236 is_static: false,
237 is_constant: false,
238 is_dependent: false,
239 get_access: Access::Public,
240 set_access: Access::Public,
241 default_value: None,
242 },
243 );
244 }
245
246 let methods: HashMap<String, MethodDef> = HashMap::new();
247 runmat_builtins::register_class(ClassDef {
248 name: SS_CLASS.to_string(),
249 parent: None,
250 properties,
251 methods,
252 });
253 });
254}
255
256#[runtime_builtin(
257 name = "ss",
258 category = "control",
259 summary = "Create state-space model objects from A, B, C, and D matrices.",
260 keywords = "ss,state space,control system,model,matrices",
261 type_resolver(ss_type),
262 descriptor(crate::builtins::control::ss::SS_DESCRIPTOR),
263 builtin_path = "crate::builtins::control::ss"
264)]
265async fn ss_builtin(
266 a: Value,
267 b: Value,
268 c: Value,
269 d: Value,
270 rest: Vec<Value>,
271) -> BuiltinResult<Value> {
272 let options = SsOptions::parse(&rest)?;
273 let a = RealMatrix::parse("A", a).await?;
274 let b = RealMatrix::parse("B", b).await?;
275 let c = RealMatrix::parse("C", c).await?;
276 let d = RealMatrix::parse("D", d).await?;
277
278 validate_state_space_dimensions(&a, &b, &c, &d)?;
279
280 let state_count = a.rows;
281 let input_count = b.cols;
282 let output_count = c.rows;
283
284 ensure_ss_class_registered();
285 let mut object = ObjectInstance::new(SS_CLASS.to_string());
286 object.properties.insert("A".to_string(), a.into_value());
287 object.properties.insert("B".to_string(), b.into_value());
288 object.properties.insert("C".to_string(), c.into_value());
289 object.properties.insert("D".to_string(), d.into_value());
290 object
291 .properties
292 .insert("Ts".to_string(), Value::Num(options.sample_time));
293 object.properties.insert(
294 "InputDelay".to_string(),
295 zero_tensor_value(vec![input_count, 1])?,
296 );
297 object.properties.insert(
298 "OutputDelay".to_string(),
299 zero_tensor_value(vec![output_count, 1])?,
300 );
301 object.properties.insert(
302 "StateName".to_string(),
303 empty_name_cell_value(state_count, 1)?,
304 );
305 object.properties.insert(
306 "InputName".to_string(),
307 empty_name_cell_value(input_count, 1)?,
308 );
309 object.properties.insert(
310 "OutputName".to_string(),
311 empty_name_cell_value(output_count, 1)?,
312 );
313 Ok(Value::Object(object))
314}
315
316#[derive(Clone)]
317struct SsOptions {
318 sample_time: f64,
319}
320
321impl SsOptions {
322 fn parse(rest: &[Value]) -> BuiltinResult<Self> {
323 let mut options = Self { sample_time: 0.0 };
324
325 match rest {
326 [] => {}
327 [sample_time] => options.sample_time = parse_sample_time(sample_time)?,
328 _ => {
329 if !rest.len().is_multiple_of(2) {
330 return Err(ss_error_with_detail(
331 &SS_ERROR_INVALID_ARGUMENT,
332 "optional arguments must be name-value pairs or a scalar sample time",
333 ));
334 }
335 let mut idx = 0;
336 while idx < rest.len() {
337 let name = scalar_text(&rest[idx], "option name")?;
338 let lowered = name.trim().to_ascii_lowercase();
339 let value = &rest[idx + 1];
340 match lowered.as_str() {
341 "ts" | "sampletime" => options.sample_time = parse_sample_time(value)?,
342 _ => {
343 return Err(ss_error_with_detail(
344 &SS_ERROR_INVALID_OPTION,
345 format!("unsupported option '{name}'"),
346 ));
347 }
348 }
349 idx += 2;
350 }
351 }
352 }
353
354 Ok(options)
355 }
356}
357
358fn parse_sample_time(value: &Value) -> BuiltinResult<f64> {
359 let sample_time = match value {
360 Value::Num(n) => *n,
361 Value::Int(i) => i.to_f64(),
362 other => {
363 return Err(ss_error_with_detail(
364 &SS_ERROR_INVALID_SAMPLE_TIME,
365 format!("expected non-negative scalar, got {other:?}"),
366 ))
367 }
368 };
369 if !sample_time.is_finite() || sample_time < 0.0 {
370 return Err(ss_error(&SS_ERROR_INVALID_SAMPLE_TIME));
371 }
372 Ok(sample_time)
373}
374
375fn scalar_text(value: &Value, context: &str) -> BuiltinResult<String> {
376 match value {
377 Value::String(text) => Ok(text.clone()),
378 Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
379 Value::CharArray(array) if array.rows == 1 => Ok(array.data.iter().collect()),
380 other => Err(ss_error_with_detail(
381 &SS_ERROR_INVALID_ARGUMENT,
382 format!("{context} must be a string scalar or character vector, got {other:?}"),
383 )),
384 }
385}
386
387#[derive(Clone)]
388struct RealMatrix {
389 tensor: Tensor,
390 rows: usize,
391 cols: usize,
392}
393
394impl RealMatrix {
395 async fn parse(label: &str, value: Value) -> BuiltinResult<Self> {
396 let gathered = dispatcher::gather_if_needed_async(&value).await?;
397 let tensor = match gathered {
398 Value::Tensor(tensor) => tensor,
399 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|err| {
400 ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
401 })?,
402 Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|err| {
403 ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
404 })?,
405 Value::Complex(_, _) | Value::ComplexTensor(_) => {
406 return Err(ss_error_with_detail(
407 &SS_ERROR_UNSUPPORTED_INPUT,
408 format!(
409 "{label} must be finite real numeric data; complex input is unsupported"
410 ),
411 ));
412 }
413 other => {
414 return Err(ss_error_with_detail(
415 &SS_ERROR_UNSUPPORTED_INPUT,
416 format!("{label} must be a finite real numeric matrix, got {other:?}"),
417 ));
418 }
419 };
420
421 if tensor.shape.len() > 2 {
422 return Err(ss_error_with_detail(
423 &SS_ERROR_INVALID_DIMENSIONS,
424 format!("{label} must be a 2-D matrix, got shape {:?}", tensor.shape),
425 ));
426 }
427 if tensor.data.iter().any(|value| !value.is_finite()) {
428 return Err(ss_error_with_detail(
429 &SS_ERROR_UNSUPPORTED_INPUT,
430 format!("{label} must contain only finite real values"),
431 ));
432 }
433
434 Ok(Self {
435 rows: tensor.rows,
436 cols: tensor.cols,
437 tensor,
438 })
439 }
440
441 fn into_value(self) -> Value {
442 Value::Tensor(self.tensor)
443 }
444}
445
446fn validate_state_space_dimensions(
447 a: &RealMatrix,
448 b: &RealMatrix,
449 c: &RealMatrix,
450 d: &RealMatrix,
451) -> BuiltinResult<()> {
452 if a.rows != a.cols {
453 return Err(ss_error_with_detail(
454 &SS_ERROR_INVALID_DIMENSIONS,
455 format!("A must be square, got {}x{}", a.rows, a.cols),
456 ));
457 }
458
459 let state_count = a.rows;
460 if b.rows != state_count {
461 return Err(ss_error_with_detail(
462 &SS_ERROR_INVALID_DIMENSIONS,
463 format!(
464 "B must have {} rows to match A, got {}x{}",
465 state_count, b.rows, b.cols
466 ),
467 ));
468 }
469 if c.cols != state_count {
470 return Err(ss_error_with_detail(
471 &SS_ERROR_INVALID_DIMENSIONS,
472 format!(
473 "C must have {} columns to match A, got {}x{}",
474 state_count, c.rows, c.cols
475 ),
476 ));
477 }
478 if d.rows != c.rows || d.cols != b.cols {
479 return Err(ss_error_with_detail(
480 &SS_ERROR_INVALID_DIMENSIONS,
481 format!(
482 "D must have shape {}x{} to match C outputs and B inputs, got {}x{}",
483 c.rows, b.cols, d.rows, d.cols
484 ),
485 ));
486 }
487
488 Ok(())
489}
490
491fn zero_tensor_value(shape: Vec<usize>) -> BuiltinResult<Value> {
492 let len = shape.iter().product();
493 Tensor::new(vec![0.0; len], shape)
494 .map(Value::Tensor)
495 .map_err(|err| {
496 ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
497 })
498}
499
500fn empty_name_cell_value(rows: usize, cols: usize) -> BuiltinResult<Value> {
501 let len = rows * cols;
502 let values = (0..len)
503 .map(|_| Value::CharArray(CharArray::new_row("")))
504 .collect();
505 CellArray::new(values, rows, cols)
506 .map(Value::Cell)
507 .map_err(|err| {
508 ss_error_with_detail(
509 &SS_ERROR_INTERNAL,
510 format!("failed to build cell array: {err}"),
511 )
512 })
513}
514
/// Unit tests for the `ss` builtin: descriptor coverage, object construction,
/// sample-time handling, dimension/input validation, and GPU gathering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    /// Drives the async builtin to completion on the current thread.
    fn run_ss(a: Value, b: Value, c: Value, d: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(ss_builtin(a, b, c, d, rest))
    }

    /// Returns property `name` of an object value, panicking if it is absent.
    fn property<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Object(object) = value else {
            panic!("expected object, got {value:?}");
        };
        object
            .properties
            .get(name)
            .unwrap_or_else(|| panic!("missing property {name}"))
    }

    /// Asserts that `value` is a real tensor with exactly this shape and data.
    fn assert_tensor(value: &Value, shape: &[usize], data: &[f64]) {
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, shape);
                assert_eq!(tensor.data, data);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    #[test]
    fn ss_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = SS_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"sys = ss(A, B, C, D)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, \"Ts\", Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, name, value, ...)"));
    }

    // SISO system with two states; checks every stored matrix and the default Ts.
    #[test]
    fn ss_constructs_continuous_state_space_object() {
        let sys = run_ss(
            Value::Tensor(Tensor::new(vec![0.0, -2.0, 1.0, -3.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect("ss");

        let Value::Object(object) = &sys else {
            panic!("expected object");
        };
        assert_eq!(object.class_name, "ss");
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
        assert_tensor(property(&sys, "A"), &[2, 2], &[0.0, -2.0, 1.0, -3.0]);
        assert_tensor(property(&sys, "B"), &[2, 1], &[0.0, 1.0]);
        assert_tensor(property(&sys, "C"), &[1, 2], &[1.0, 0.0]);
        assert_tensor(property(&sys, "D"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "InputDelay"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[1, 1], &[0.0]);
    }

    // Two inputs and two outputs: delay vectors must be sized by B columns / C rows.
    #[test]
    fn ss_preserves_matrix_orientation_for_mimo_systems() {
        let sys = run_ss(
            Value::Num(-1.0),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.1, 0.2, 0.3], vec![2, 2]).unwrap()),
            Vec::new(),
        )
        .expect("ss");

        assert_tensor(property(&sys, "A"), &[1, 1], &[-1.0]);
        assert_tensor(property(&sys, "B"), &[1, 2], &[1.0, 2.0]);
        assert_tensor(property(&sys, "C"), &[2, 1], &[3.0, 4.0]);
        assert_tensor(property(&sys, "D"), &[2, 2], &[0.0, 0.1, 0.2, 0.3]);
        assert_tensor(property(&sys, "InputDelay"), &[2, 1], &[0.0, 0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[2, 1], &[0.0, 0.0]);
    }

    // Integer scalar matrices plus a positional sample time.
    #[test]
    fn ss_accepts_discrete_sample_time() {
        let sys = run_ss(
            Value::Int(IntValue::I32(1)),
            Value::Int(IntValue::I32(2)),
            Value::Int(IntValue::I32(3)),
            Value::Int(IntValue::I32(4)),
            vec![Value::Num(0.25)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.25));
    }

    #[test]
    fn ss_accepts_sample_time_name_value_options() {
        let sys = run_ss(
            Value::Num(1.0),
            Value::Num(2.0),
            Value::Num(3.0),
            Value::Num(4.0),
            vec![Value::from("SampleTime"), Value::Num(0.5)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.5));
    }

    #[test]
    fn ss_rejects_nonsquare_a_matrix() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("nonsquare A should fail");
        assert!(err.message().contains("A must be square"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    #[test]
    fn ss_rejects_b_row_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("B mismatch should fail");
        assert!(err.message().contains("B must have 2 rows"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    #[test]
    fn ss_rejects_d_shape_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("D mismatch should fail");
        assert!(err.message().contains("D must have shape 1x1"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    #[test]
    fn ss_rejects_invalid_sample_time() {
        let err = run_ss(
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            vec![Value::Num(-0.1)],
        )
        .expect_err("negative Ts should fail");
        assert_eq!(err.identifier(), SS_ERROR_INVALID_SAMPLE_TIME.identifier);
    }

    #[test]
    fn ss_rejects_complex_inputs() {
        let err = run_ss(
            Value::Complex(1.0, 1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect_err("complex A should fail");
        assert!(err.message().contains("complex input is unsupported"));
        assert_eq!(err.identifier(), SS_ERROR_UNSUPPORTED_INPUT.identifier);
    }

    // A device-resident A matrix must be gathered and stored as a host tensor.
    #[test]
    fn ss_gpu_matrix_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let sys = run_ss(
                Value::GpuTensor(handle),
                Value::Num(2.0),
                Value::Num(3.0),
                Value::Num(4.0),
                Vec::new(),
            )
            .expect("ss");

            assert_tensor(property(&sys, "A"), &[1, 1], &[1.0]);
        });
    }
}