1use std::cell::Cell;
4use std::collections::HashMap;
5
6use runmat_builtins::{
7 Access, BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9 CellArray, CharArray, ClassDef, MethodDef, ObjectInstance, PropertyDef, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::control::type_resolvers::ss_type;
18use crate::{build_runtime_error, dispatcher, BuiltinResult, RuntimeError};
19
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "ss";
/// Class name of the state-space model objects created (and registered) by `ss`.
const SS_CLASS: &str = "ss";
22
thread_local! {
    /// Per-thread flag recording whether `ensure_ss_class_registered` has already registered
    /// the `ss` class definition; starts out `false`.
    static SS_CLASS_REGISTERED: Cell<bool> = const { Cell::new(false) };
}
26
/// Output descriptor shared by every `ss` signature: the constructed model object.
const SS_OUTPUT_SYS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "sys",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "State-space model object.",
}];
/// Required state matrix `A` (n-by-n).
const SS_PARAM_A: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "State matrix with shape n-by-n.",
};
/// Required input matrix `B` (n-by-nu).
const SS_PARAM_B: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input matrix with shape n-by-nu.",
};
/// Required output matrix `C` (ny-by-n).
const SS_PARAM_C: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "C",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Output matrix with shape ny-by-n.",
};
/// Required feedthrough matrix `D` (ny-by-nu).
const SS_PARAM_D: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "D",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Feedthrough matrix with shape ny-by-nu.",
};
/// Inputs of the `ss(A, B, C, D)` form.
const SS_INPUTS_ABCD: [BuiltinParamDescriptor; 4] =
    [SS_PARAM_A, SS_PARAM_B, SS_PARAM_C, SS_PARAM_D];
/// Inputs of the `ss(A, B, C, D, Ts)` form: the four matrices plus a positional sample time.
const SS_INPUTS_ABCD_TS: [BuiltinParamDescriptor; 5] = [
    SS_PARAM_A,
    SS_PARAM_B,
    SS_PARAM_C,
    SS_PARAM_D,
    BuiltinParamDescriptor {
        name: "Ts",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Optional,
        default: Some("0.0"),
        description: "Sample time (0 for continuous-time model).",
    },
];
/// Inputs of the name/value forms: the four matrices followed by variadic option
/// name/value pairs.
const SS_INPUTS_ABCD_NAMEVALUE: [BuiltinParamDescriptor; 6] = [
    SS_PARAM_A,
    SS_PARAM_B,
    SS_PARAM_C,
    SS_PARAM_D,
    BuiltinParamDescriptor {
        name: "name",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option name ('Ts' or 'SampleTime').",
    },
    BuiltinParamDescriptor {
        name: "value",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option value.",
    },
];
/// Call forms advertised for `ss`. The two name/value labels share the same input list; they
/// differ only in the example shown to the user.
const SS_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D)",
        inputs: &SS_INPUTS_ABCD,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, Ts)",
        inputs: &SS_INPUTS_ABCD_TS,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, \"Ts\", Ts)",
        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, name, value, ...)",
        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
        outputs: &SS_OUTPUT_SYS,
    },
];
// Error catalogue for `ss`. When an error is raised, the descriptor's `identifier` is attached
// to the runtime error (see `ss_error_with_message`). Its `message` is used either as the full
// message (`ss_error`) or as the prefix before a detail string (`ss_error_with_detail`).

/// Wrong number or kind of optional arguments (e.g. unpaired name/value list).
const SS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_ARGUMENT",
    identifier: Some("RunMat:ss:InvalidArgument"),
    when: "Arguments do not match supported ss invocation forms.",
    message: "ss: invalid argument",
};
/// Unknown option name in a name/value pair.
const SS_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_OPTION",
    identifier: Some("RunMat:ss:InvalidOption"),
    when: "A name/value option token is unsupported or malformed.",
    message: "ss: invalid option",
};
/// Sample time that is not a finite, non-negative scalar.
const SS_ERROR_INVALID_SAMPLE_TIME: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_SAMPLE_TIME",
    identifier: Some("RunMat:ss:InvalidSampleTime"),
    when: "Sample time is not a finite non-negative scalar.",
    message: "ss: sample time must be a finite non-negative scalar",
};
/// Matrix shapes that are not 2-D or do not form a consistent A/B/C/D model.
const SS_ERROR_INVALID_DIMENSIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_DIMENSIONS",
    identifier: Some("RunMat:ss:InvalidDimensions"),
    when: "A, B, C, and D dimensions do not define a consistent state-space model.",
    message: "ss: invalid state-space matrix dimensions",
};
/// Matrix data of an unsupported kind (complex, non-numeric, or non-finite).
const SS_ERROR_UNSUPPORTED_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.UNSUPPORTED_INPUT",
    identifier: Some("RunMat:ss:UnsupportedInput"),
    when: "An input is complex, sparse, logical, or another unsupported model form.",
    message: "ss: unsupported input",
};
/// Failure to build an internal tensor or cell value.
const SS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INTERNAL",
    identifier: Some("RunMat:ss:Internal"),
    when: "Internal tensor/object construction failed.",
    message: "ss: internal error",
};
/// All error descriptors published for `ss`.
const SS_ERRORS: [BuiltinErrorDescriptor; 6] = [
    SS_ERROR_INVALID_ARGUMENT,
    SS_ERROR_INVALID_OPTION,
    SS_ERROR_INVALID_SAMPLE_TIME,
    SS_ERROR_INVALID_DIMENSIONS,
    SS_ERROR_UNSUPPORTED_INPUT,
    SS_ERROR_INTERNAL,
];
/// Public descriptor for the `ss` builtin: its call signatures, a fixed single output, public
/// completion, and the error catalogue. Referenced by the `descriptor(...)` argument of the
/// `ss_builtin` registration.
pub const SS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SS_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SS_ERRORS,
};
169
/// GPU metadata for `ss`. No GPU kernel or provider hook exists (`provider_hooks` is empty):
/// the object is built on the host, and GPU-resident matrix inputs are gathered first
/// (`ResidencyPolicy::GatherImmediately`).
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::ss")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ss",
    op_kind: GpuOpKind::Custom("state-space-model-constructor"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Object construction runs on the host. gpuArray matrix inputs are gathered before validating and storing the state-space metadata.",
};
185
/// Fusion metadata for `ss`. The builtin is neither elementwise nor a reduction (both are
/// `None`), so it ends any numeric fusion chain that feeds into it.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::ss")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ss",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "State-space construction is metadata-only and terminates numeric fusion chains.",
};
196
197fn ss_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
198 ss_error_with_message(error.message, error)
199}
200
201fn ss_error_with_detail(
202 error: &'static BuiltinErrorDescriptor,
203 detail: impl AsRef<str>,
204) -> RuntimeError {
205 ss_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
206}
207
208fn ss_error_with_message(
209 message: impl Into<String>,
210 error: &'static BuiltinErrorDescriptor,
211) -> RuntimeError {
212 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
213 if let Some(identifier) = error.identifier {
214 builder = builder.with_identifier(identifier);
215 }
216 builder.build()
217}
218
219fn ensure_ss_class_registered() {
220 SS_CLASS_REGISTERED.with(|registered| {
221 if registered.get() {
222 return;
223 }
224 let mut properties = HashMap::new();
225 for name in [
226 "A",
227 "B",
228 "C",
229 "D",
230 "Ts",
231 "InputDelay",
232 "OutputDelay",
233 "StateName",
234 "InputName",
235 "OutputName",
236 ] {
237 properties.insert(
238 name.to_string(),
239 PropertyDef {
240 name: name.to_string(),
241 is_static: false,
242 is_constant: false,
243 is_dependent: false,
244 get_access: Access::Public,
245 set_access: Access::Public,
246 default_value: None,
247 },
248 );
249 }
250
251 let methods: HashMap<String, MethodDef> = HashMap::new();
252 runmat_builtins::register_class(ClassDef {
253 name: SS_CLASS.to_string(),
254 parent: None,
255 properties,
256 methods,
257 });
258 registered.set(true);
259 });
260}
261
262#[runtime_builtin(
263 name = "ss",
264 category = "control",
265 summary = "Create state-space model objects from A, B, C, and D matrices.",
266 keywords = "ss,state space,control system,model,matrices",
267 type_resolver(ss_type),
268 descriptor(crate::builtins::control::ss::SS_DESCRIPTOR),
269 builtin_path = "crate::builtins::control::ss"
270)]
271async fn ss_builtin(
272 a: Value,
273 b: Value,
274 c: Value,
275 d: Value,
276 rest: Vec<Value>,
277) -> BuiltinResult<Value> {
278 let options = SsOptions::parse(&rest)?;
279 let a = RealMatrix::parse("A", a).await?;
280 let b = RealMatrix::parse("B", b).await?;
281 let c = RealMatrix::parse("C", c).await?;
282 let d = RealMatrix::parse("D", d).await?;
283
284 validate_state_space_dimensions(&a, &b, &c, &d)?;
285
286 let state_count = a.rows;
287 let input_count = b.cols;
288 let output_count = c.rows;
289
290 ensure_ss_class_registered();
291 let mut object = ObjectInstance::new(SS_CLASS.to_string());
292 object.properties.insert("A".to_string(), a.into_value());
293 object.properties.insert("B".to_string(), b.into_value());
294 object.properties.insert("C".to_string(), c.into_value());
295 object.properties.insert("D".to_string(), d.into_value());
296 object
297 .properties
298 .insert("Ts".to_string(), Value::Num(options.sample_time));
299 object.properties.insert(
300 "InputDelay".to_string(),
301 zero_tensor_value(vec![input_count, 1])?,
302 );
303 object.properties.insert(
304 "OutputDelay".to_string(),
305 zero_tensor_value(vec![output_count, 1])?,
306 );
307 object.properties.insert(
308 "StateName".to_string(),
309 empty_name_cell_value(state_count, 1)?,
310 );
311 object.properties.insert(
312 "InputName".to_string(),
313 empty_name_cell_value(input_count, 1)?,
314 );
315 object.properties.insert(
316 "OutputName".to_string(),
317 empty_name_cell_value(output_count, 1)?,
318 );
319 Ok(Value::Object(object))
320}
321
322#[derive(Clone)]
323struct SsOptions {
324 sample_time: f64,
325}
326
327impl SsOptions {
328 fn parse(rest: &[Value]) -> BuiltinResult<Self> {
329 let mut options = Self { sample_time: 0.0 };
330
331 match rest {
332 [] => {}
333 [sample_time] => options.sample_time = parse_sample_time(sample_time)?,
334 _ => {
335 if !rest.len().is_multiple_of(2) {
336 return Err(ss_error_with_detail(
337 &SS_ERROR_INVALID_ARGUMENT,
338 "optional arguments must be name-value pairs or a scalar sample time",
339 ));
340 }
341 let mut idx = 0;
342 while idx < rest.len() {
343 let name = scalar_text(&rest[idx], "option name")?;
344 let lowered = name.trim().to_ascii_lowercase();
345 let value = &rest[idx + 1];
346 match lowered.as_str() {
347 "ts" | "sampletime" => options.sample_time = parse_sample_time(value)?,
348 _ => {
349 return Err(ss_error_with_detail(
350 &SS_ERROR_INVALID_OPTION,
351 format!("unsupported option '{name}'"),
352 ));
353 }
354 }
355 idx += 2;
356 }
357 }
358 }
359
360 Ok(options)
361 }
362}
363
364fn parse_sample_time(value: &Value) -> BuiltinResult<f64> {
365 let sample_time = match value {
366 Value::Num(n) => *n,
367 Value::Int(i) => i.to_f64(),
368 other => {
369 return Err(ss_error_with_detail(
370 &SS_ERROR_INVALID_SAMPLE_TIME,
371 format!("expected non-negative scalar, got {other:?}"),
372 ))
373 }
374 };
375 if !sample_time.is_finite() || sample_time < 0.0 {
376 return Err(ss_error(&SS_ERROR_INVALID_SAMPLE_TIME));
377 }
378 Ok(sample_time)
379}
380
381fn scalar_text(value: &Value, context: &str) -> BuiltinResult<String> {
382 match value {
383 Value::String(text) => Ok(text.clone()),
384 Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
385 Value::CharArray(array) if array.rows == 1 => Ok(array.data.iter().collect()),
386 other => Err(ss_error_with_detail(
387 &SS_ERROR_INVALID_ARGUMENT,
388 format!("{context} must be a string scalar or character vector, got {other:?}"),
389 )),
390 }
391}
392
393#[derive(Clone)]
394struct RealMatrix {
395 tensor: Tensor,
396 rows: usize,
397 cols: usize,
398}
399
400impl RealMatrix {
401 async fn parse(label: &str, value: Value) -> BuiltinResult<Self> {
402 let gathered = dispatcher::gather_if_needed_async(&value).await?;
403 let tensor = match gathered {
404 Value::Tensor(tensor) => tensor,
405 Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|err| {
406 ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
407 })?,
408 Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|err| {
409 ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
410 })?,
411 Value::Complex(_, _) | Value::ComplexTensor(_) => {
412 return Err(ss_error_with_detail(
413 &SS_ERROR_UNSUPPORTED_INPUT,
414 format!(
415 "{label} must be finite real numeric data; complex input is unsupported"
416 ),
417 ));
418 }
419 other => {
420 return Err(ss_error_with_detail(
421 &SS_ERROR_UNSUPPORTED_INPUT,
422 format!("{label} must be a finite real numeric matrix, got {other:?}"),
423 ));
424 }
425 };
426
427 if tensor.shape.len() > 2 {
428 return Err(ss_error_with_detail(
429 &SS_ERROR_INVALID_DIMENSIONS,
430 format!("{label} must be a 2-D matrix, got shape {:?}", tensor.shape),
431 ));
432 }
433 if tensor.data.iter().any(|value| !value.is_finite()) {
434 return Err(ss_error_with_detail(
435 &SS_ERROR_UNSUPPORTED_INPUT,
436 format!("{label} must contain only finite real values"),
437 ));
438 }
439
440 Ok(Self {
441 rows: tensor.rows,
442 cols: tensor.cols,
443 tensor,
444 })
445 }
446
447 fn into_value(self) -> Value {
448 Value::Tensor(self.tensor)
449 }
450}
451
452fn validate_state_space_dimensions(
453 a: &RealMatrix,
454 b: &RealMatrix,
455 c: &RealMatrix,
456 d: &RealMatrix,
457) -> BuiltinResult<()> {
458 if a.rows != a.cols {
459 return Err(ss_error_with_detail(
460 &SS_ERROR_INVALID_DIMENSIONS,
461 format!("A must be square, got {}x{}", a.rows, a.cols),
462 ));
463 }
464
465 let state_count = a.rows;
466 if b.rows != state_count {
467 return Err(ss_error_with_detail(
468 &SS_ERROR_INVALID_DIMENSIONS,
469 format!(
470 "B must have {} rows to match A, got {}x{}",
471 state_count, b.rows, b.cols
472 ),
473 ));
474 }
475 if c.cols != state_count {
476 return Err(ss_error_with_detail(
477 &SS_ERROR_INVALID_DIMENSIONS,
478 format!(
479 "C must have {} columns to match A, got {}x{}",
480 state_count, c.rows, c.cols
481 ),
482 ));
483 }
484 if d.rows != c.rows || d.cols != b.cols {
485 return Err(ss_error_with_detail(
486 &SS_ERROR_INVALID_DIMENSIONS,
487 format!(
488 "D must have shape {}x{} to match C outputs and B inputs, got {}x{}",
489 c.rows, b.cols, d.rows, d.cols
490 ),
491 ));
492 }
493
494 Ok(())
495}
496
497fn zero_tensor_value(shape: Vec<usize>) -> BuiltinResult<Value> {
498 let len = shape.iter().product();
499 Tensor::new(vec![0.0; len], shape)
500 .map(Value::Tensor)
501 .map_err(|err| {
502 ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
503 })
504}
505
506fn empty_name_cell_value(rows: usize, cols: usize) -> BuiltinResult<Value> {
507 let len = rows * cols;
508 let values = (0..len)
509 .map(|_| Value::CharArray(CharArray::new_row("")))
510 .collect();
511 CellArray::new(values, rows, cols)
512 .map(Value::Cell)
513 .map_err(|err| {
514 ss_error_with_detail(
515 &SS_ERROR_INTERNAL,
516 format!("failed to build cell array: {err}"),
517 )
518 })
519}
520
#[cfg(test)]
mod tests {
    //! Unit tests for the `ss` builtin: descriptor coverage, object construction for SISO and
    //! MIMO models, sample-time parsing (positional and name/value), rejection of inconsistent
    //! dimensions, invalid sample times and complex input, and gathering of GPU inputs.
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    /// Runs the async builtin to completion on the current thread.
    fn run_ss(a: Value, b: Value, c: Value, d: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(ss_builtin(a, b, c, d, rest))
    }

    /// Returns the named property of an object value, panicking if `value` is not an object or
    /// the property is missing.
    fn property<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Object(object) = value else {
            panic!("expected object, got {value:?}");
        };
        object
            .properties
            .get(name)
            .unwrap_or_else(|| panic!("missing property {name}"))
    }

    /// Asserts that `value` is a tensor with exactly the given shape and data.
    fn assert_tensor(value: &Value, shape: &[usize], data: &[f64]) {
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, shape);
                assert_eq!(tensor.data, data);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Every call form advertised in the descriptor must be present.
    #[test]
    fn ss_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = SS_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"sys = ss(A, B, C, D)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, \"Ts\", Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, name, value, ...)"));
    }

    // Two-state SISO model: matrices are stored unchanged, Ts defaults to 0, and the delay
    // vectors are sized by the input/output counts.
    #[test]
    fn ss_constructs_continuous_state_space_object() {
        let sys = run_ss(
            Value::Tensor(Tensor::new(vec![0.0, -2.0, 1.0, -3.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect("ss");

        let Value::Object(object) = &sys else {
            panic!("expected object");
        };
        assert_eq!(object.class_name, "ss");
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
        assert_tensor(property(&sys, "A"), &[2, 2], &[0.0, -2.0, 1.0, -3.0]);
        assert_tensor(property(&sys, "B"), &[2, 1], &[0.0, 1.0]);
        assert_tensor(property(&sys, "C"), &[1, 2], &[1.0, 0.0]);
        assert_tensor(property(&sys, "D"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "InputDelay"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[1, 1], &[0.0]);
    }

    // One-state, two-input, two-output model: B/C/D keep their row/column orientation.
    #[test]
    fn ss_preserves_matrix_orientation_for_mimo_systems() {
        let sys = run_ss(
            Value::Num(-1.0),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.1, 0.2, 0.3], vec![2, 2]).unwrap()),
            Vec::new(),
        )
        .expect("ss");

        assert_tensor(property(&sys, "A"), &[1, 1], &[-1.0]);
        assert_tensor(property(&sys, "B"), &[1, 2], &[1.0, 2.0]);
        assert_tensor(property(&sys, "C"), &[2, 1], &[3.0, 4.0]);
        assert_tensor(property(&sys, "D"), &[2, 2], &[0.0, 0.1, 0.2, 0.3]);
        assert_tensor(property(&sys, "InputDelay"), &[2, 1], &[0.0, 0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[2, 1], &[0.0, 0.0]);
    }

    // Integer matrices plus a positional sample time.
    #[test]
    fn ss_accepts_discrete_sample_time() {
        let sys = run_ss(
            Value::Int(IntValue::I32(1)),
            Value::Int(IntValue::I32(2)),
            Value::Int(IntValue::I32(3)),
            Value::Int(IntValue::I32(4)),
            vec![Value::Num(0.25)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.25));
    }

    // The 'SampleTime' option name is accepted as an alias of 'Ts'.
    #[test]
    fn ss_accepts_sample_time_name_value_options() {
        let sys = run_ss(
            Value::Num(1.0),
            Value::Num(2.0),
            Value::Num(3.0),
            Value::Num(4.0),
            vec![Value::from("SampleTime"), Value::Num(0.5)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.5));
    }

    #[test]
    fn ss_rejects_nonsquare_a_matrix() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("nonsquare A should fail");
        assert!(err.message().contains("A must be square"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    #[test]
    fn ss_rejects_b_row_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("B mismatch should fail");
        assert!(err.message().contains("B must have 2 rows"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    #[test]
    fn ss_rejects_d_shape_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("D mismatch should fail");
        assert!(err.message().contains("D must have shape 1x1"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    #[test]
    fn ss_rejects_invalid_sample_time() {
        let err = run_ss(
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            vec![Value::Num(-0.1)],
        )
        .expect_err("negative Ts should fail");
        assert_eq!(err.identifier(), SS_ERROR_INVALID_SAMPLE_TIME.identifier);
    }

    #[test]
    fn ss_rejects_complex_inputs() {
        let err = run_ss(
            Value::Complex(1.0, 1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect_err("complex A should fail");
        assert!(err.message().contains("complex input is unsupported"));
        assert_eq!(err.identifier(), SS_ERROR_UNSUPPORTED_INPUT.identifier);
    }

    // A GPU-resident A is uploaded through the test provider and must come back as a host
    // tensor on the object.
    #[test]
    fn ss_gpu_matrix_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let sys = run_ss(
                Value::GpuTensor(handle),
                Value::Num(2.0),
                Value::Num(3.0),
                Value::Num(4.0),
                Vec::new(),
            )
            .expect("ss");

            assert_tensor(property(&sys, "A"), &[1, 1], &[1.0]);
        });
    }
}