Skip to main content

runmat_runtime/builtins/control/
ss.rs

1//! MATLAB-compatible `ss` state-space model constructor for RunMat.
2
3use std::collections::HashMap;
4use std::sync::OnceLock;
5
6use runmat_builtins::{
7    Access, BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9    CellArray, CharArray, ClassDef, MethodDef, ObjectInstance, PropertyDef, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::control::type_resolvers::ss_type;
18use crate::{build_runtime_error, dispatcher, BuiltinResult, RuntimeError};
19
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "ss";
/// Class name of the objects returned by `ss`.
const SS_CLASS: &str = "ss";

/// Guards the one-time registration of the `ss` class definition
/// (see `ensure_ss_class_registered`).
static SS_CLASS_REGISTERED: OnceLock<()> = OnceLock::new();
24
/// Output descriptor shared by every `ss` signature: the constructed model object.
const SS_OUTPUT_SYS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "sys",
    ty: BuiltinParamType::Any,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "State-space model object.",
}];
/// State matrix `A`; must be square (n-by-n, n = number of states).
const SS_PARAM_A: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "A",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "State matrix with shape n-by-n.",
};
/// Input matrix `B`; its row count must equal the number of states.
const SS_PARAM_B: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "B",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Input matrix with shape n-by-nu.",
};
/// Output matrix `C`; its column count must equal the number of states.
const SS_PARAM_C: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "C",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Output matrix with shape ny-by-n.",
};
/// Feedthrough matrix `D`; must be (rows of `C`)-by-(columns of `B`).
const SS_PARAM_D: BuiltinParamDescriptor = BuiltinParamDescriptor {
    name: "D",
    ty: BuiltinParamType::NumericArray,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Feedthrough matrix with shape ny-by-nu.",
};
/// Inputs for `ss(A, B, C, D)`.
const SS_INPUTS_ABCD: [BuiltinParamDescriptor; 4] =
    [SS_PARAM_A, SS_PARAM_B, SS_PARAM_C, SS_PARAM_D];
/// Inputs for `ss(A, B, C, D, Ts)`; `Ts` is advertised as optional with default `0.0`.
const SS_INPUTS_ABCD_TS: [BuiltinParamDescriptor; 5] = [
    SS_PARAM_A,
    SS_PARAM_B,
    SS_PARAM_C,
    SS_PARAM_D,
    BuiltinParamDescriptor {
        name: "Ts",
        ty: BuiltinParamType::NumericScalar,
        arity: BuiltinParamArity::Optional,
        default: Some("0.0"),
        description: "Sample time (0 for continuous-time model).",
    },
];
/// Inputs for the name-value forms; `name` and `value` are variadic and are
/// consumed in pairs by `SsOptions::parse`.
const SS_INPUTS_ABCD_NAMEVALUE: [BuiltinParamDescriptor; 6] = [
    SS_PARAM_A,
    SS_PARAM_B,
    SS_PARAM_C,
    SS_PARAM_D,
    BuiltinParamDescriptor {
        name: "name",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option name ('Ts' or 'SampleTime').",
    },
    BuiltinParamDescriptor {
        name: "value",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Variadic,
        default: None,
        description: "Option value.",
    },
];
/// Call forms advertised for `ss`. The last two entries share
/// `SS_INPUTS_ABCD_NAMEVALUE`; the `"Ts", Ts` label spells out the
/// sample-time option explicitly, while the final label is the generic form.
const SS_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D)",
        inputs: &SS_INPUTS_ABCD,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, Ts)",
        inputs: &SS_INPUTS_ABCD_TS,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, \"Ts\", Ts)",
        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
        outputs: &SS_OUTPUT_SYS,
    },
    BuiltinSignatureDescriptor {
        label: "sys = ss(A, B, C, D, name, value, ...)",
        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
        outputs: &SS_OUTPUT_SYS,
    },
];
/// Raised for malformed optional arguments, e.g. an odd number of trailing
/// arguments or an option name that is not text.
const SS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_ARGUMENT",
    identifier: Some("RunMat:ss:InvalidArgument"),
    when: "Arguments do not match supported ss invocation forms.",
    message: "ss: invalid argument",
};
/// Raised for option names other than `Ts` / `SampleTime`.
const SS_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_OPTION",
    identifier: Some("RunMat:ss:InvalidOption"),
    when: "A name/value option token is unsupported or malformed.",
    message: "ss: invalid option",
};
/// Raised when a sample time is not a finite, non-negative real scalar.
const SS_ERROR_INVALID_SAMPLE_TIME: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_SAMPLE_TIME",
    identifier: Some("RunMat:ss:InvalidSampleTime"),
    when: "Sample time is not a finite non-negative scalar.",
    message: "ss: sample time must be a finite non-negative scalar",
};
/// Raised when a matrix has more than two dimensions or the A/B/C/D sizes
/// are mutually inconsistent.
const SS_ERROR_INVALID_DIMENSIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_DIMENSIONS",
    identifier: Some("RunMat:ss:InvalidDimensions"),
    when: "A, B, C, and D dimensions do not define a consistent state-space model.",
    message: "ss: invalid state-space matrix dimensions",
};
/// Raised for complex, non-numeric, or non-finite matrix inputs.
const SS_ERROR_UNSUPPORTED_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.UNSUPPORTED_INPUT",
    identifier: Some("RunMat:ss:UnsupportedInput"),
    when: "An input is complex, sparse, logical, or another unsupported model form.",
    message: "ss: unsupported input",
};
/// Raised when building an internal tensor or cell array fails.
const SS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INTERNAL",
    identifier: Some("RunMat:ss:Internal"),
    when: "Internal tensor/object construction failed.",
    message: "ss: internal error",
};
/// Error catalogue published through `SS_DESCRIPTOR`.
const SS_ERRORS: [BuiltinErrorDescriptor; 6] = [
    SS_ERROR_INVALID_ARGUMENT,
    SS_ERROR_INVALID_OPTION,
    SS_ERROR_INVALID_SAMPLE_TIME,
    SS_ERROR_INVALID_DIMENSIONS,
    SS_ERROR_UNSUPPORTED_INPUT,
    SS_ERROR_INTERNAL,
];
/// Public descriptor for `ss`: its call signatures, a fixed output mode, public
/// completion, and the error catalogue. Referenced by the `runtime_builtin`
/// attribute on `ss_builtin`.
pub const SS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SS_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SS_ERRORS,
};
167
// GPU metadata for `ss`: no supported precisions and no provider hooks, so there is
// no device kernel; residency is `GatherImmediately`, i.e. inputs are brought to the host.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::ss")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ss",
    op_kind: GpuOpKind::Custom("state-space-model-constructor"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Object construction runs on the host. gpuArray matrix inputs are gathered before validating and storing the state-space metadata.",
};
183
// Fusion metadata for `ss`: it declares no elementwise or reduction template;
// per its notes, `ss` terminates numeric fusion chains.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::ss")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ss",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "State-space construction is metadata-only and terminates numeric fusion chains.",
};
194
195fn ss_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
196    ss_error_with_message(error.message, error)
197}
198
199fn ss_error_with_detail(
200    error: &'static BuiltinErrorDescriptor,
201    detail: impl AsRef<str>,
202) -> RuntimeError {
203    ss_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
204}
205
206fn ss_error_with_message(
207    message: impl Into<String>,
208    error: &'static BuiltinErrorDescriptor,
209) -> RuntimeError {
210    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
211    if let Some(identifier) = error.identifier {
212        builder = builder.with_identifier(identifier);
213    }
214    builder.build()
215}
216
217fn ensure_ss_class_registered() {
218    SS_CLASS_REGISTERED.get_or_init(|| {
219        let mut properties = HashMap::new();
220        for name in [
221            "A",
222            "B",
223            "C",
224            "D",
225            "Ts",
226            "InputDelay",
227            "OutputDelay",
228            "StateName",
229            "InputName",
230            "OutputName",
231        ] {
232            properties.insert(
233                name.to_string(),
234                PropertyDef {
235                    name: name.to_string(),
236                    is_static: false,
237                    is_constant: false,
238                    is_dependent: false,
239                    get_access: Access::Public,
240                    set_access: Access::Public,
241                    default_value: None,
242                },
243            );
244        }
245
246        let methods: HashMap<String, MethodDef> = HashMap::new();
247        runmat_builtins::register_class(ClassDef {
248            name: SS_CLASS.to_string(),
249            parent: None,
250            properties,
251            methods,
252        });
253    });
254}
255
256#[runtime_builtin(
257    name = "ss",
258    category = "control",
259    summary = "Create state-space model objects from A, B, C, and D matrices.",
260    keywords = "ss,state space,control system,model,matrices",
261    type_resolver(ss_type),
262    descriptor(crate::builtins::control::ss::SS_DESCRIPTOR),
263    builtin_path = "crate::builtins::control::ss"
264)]
265async fn ss_builtin(
266    a: Value,
267    b: Value,
268    c: Value,
269    d: Value,
270    rest: Vec<Value>,
271) -> BuiltinResult<Value> {
272    let options = SsOptions::parse(&rest)?;
273    let a = RealMatrix::parse("A", a).await?;
274    let b = RealMatrix::parse("B", b).await?;
275    let c = RealMatrix::parse("C", c).await?;
276    let d = RealMatrix::parse("D", d).await?;
277
278    validate_state_space_dimensions(&a, &b, &c, &d)?;
279
280    let state_count = a.rows;
281    let input_count = b.cols;
282    let output_count = c.rows;
283
284    ensure_ss_class_registered();
285    let mut object = ObjectInstance::new(SS_CLASS.to_string());
286    object.properties.insert("A".to_string(), a.into_value());
287    object.properties.insert("B".to_string(), b.into_value());
288    object.properties.insert("C".to_string(), c.into_value());
289    object.properties.insert("D".to_string(), d.into_value());
290    object
291        .properties
292        .insert("Ts".to_string(), Value::Num(options.sample_time));
293    object.properties.insert(
294        "InputDelay".to_string(),
295        zero_tensor_value(vec![input_count, 1])?,
296    );
297    object.properties.insert(
298        "OutputDelay".to_string(),
299        zero_tensor_value(vec![output_count, 1])?,
300    );
301    object.properties.insert(
302        "StateName".to_string(),
303        empty_name_cell_value(state_count, 1)?,
304    );
305    object.properties.insert(
306        "InputName".to_string(),
307        empty_name_cell_value(input_count, 1)?,
308    );
309    object.properties.insert(
310        "OutputName".to_string(),
311        empty_name_cell_value(output_count, 1)?,
312    );
313    Ok(Value::Object(object))
314}
315
316#[derive(Clone)]
317struct SsOptions {
318    sample_time: f64,
319}
320
321impl SsOptions {
322    fn parse(rest: &[Value]) -> BuiltinResult<Self> {
323        let mut options = Self { sample_time: 0.0 };
324
325        match rest {
326            [] => {}
327            [sample_time] => options.sample_time = parse_sample_time(sample_time)?,
328            _ => {
329                if !rest.len().is_multiple_of(2) {
330                    return Err(ss_error_with_detail(
331                        &SS_ERROR_INVALID_ARGUMENT,
332                        "optional arguments must be name-value pairs or a scalar sample time",
333                    ));
334                }
335                let mut idx = 0;
336                while idx < rest.len() {
337                    let name = scalar_text(&rest[idx], "option name")?;
338                    let lowered = name.trim().to_ascii_lowercase();
339                    let value = &rest[idx + 1];
340                    match lowered.as_str() {
341                        "ts" | "sampletime" => options.sample_time = parse_sample_time(value)?,
342                        _ => {
343                            return Err(ss_error_with_detail(
344                                &SS_ERROR_INVALID_OPTION,
345                                format!("unsupported option '{name}'"),
346                            ));
347                        }
348                    }
349                    idx += 2;
350                }
351            }
352        }
353
354        Ok(options)
355    }
356}
357
358fn parse_sample_time(value: &Value) -> BuiltinResult<f64> {
359    let sample_time = match value {
360        Value::Num(n) => *n,
361        Value::Int(i) => i.to_f64(),
362        other => {
363            return Err(ss_error_with_detail(
364                &SS_ERROR_INVALID_SAMPLE_TIME,
365                format!("expected non-negative scalar, got {other:?}"),
366            ))
367        }
368    };
369    if !sample_time.is_finite() || sample_time < 0.0 {
370        return Err(ss_error(&SS_ERROR_INVALID_SAMPLE_TIME));
371    }
372    Ok(sample_time)
373}
374
375fn scalar_text(value: &Value, context: &str) -> BuiltinResult<String> {
376    match value {
377        Value::String(text) => Ok(text.clone()),
378        Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
379        Value::CharArray(array) if array.rows == 1 => Ok(array.data.iter().collect()),
380        other => Err(ss_error_with_detail(
381            &SS_ERROR_INVALID_ARGUMENT,
382            format!("{context} must be a string scalar or character vector, got {other:?}"),
383        )),
384    }
385}
386
387#[derive(Clone)]
388struct RealMatrix {
389    tensor: Tensor,
390    rows: usize,
391    cols: usize,
392}
393
394impl RealMatrix {
395    async fn parse(label: &str, value: Value) -> BuiltinResult<Self> {
396        let gathered = dispatcher::gather_if_needed_async(&value).await?;
397        let tensor = match gathered {
398            Value::Tensor(tensor) => tensor,
399            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|err| {
400                ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
401            })?,
402            Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|err| {
403                ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
404            })?,
405            Value::Complex(_, _) | Value::ComplexTensor(_) => {
406                return Err(ss_error_with_detail(
407                    &SS_ERROR_UNSUPPORTED_INPUT,
408                    format!(
409                        "{label} must be finite real numeric data; complex input is unsupported"
410                    ),
411                ));
412            }
413            other => {
414                return Err(ss_error_with_detail(
415                    &SS_ERROR_UNSUPPORTED_INPUT,
416                    format!("{label} must be a finite real numeric matrix, got {other:?}"),
417                ));
418            }
419        };
420
421        if tensor.shape.len() > 2 {
422            return Err(ss_error_with_detail(
423                &SS_ERROR_INVALID_DIMENSIONS,
424                format!("{label} must be a 2-D matrix, got shape {:?}", tensor.shape),
425            ));
426        }
427        if tensor.data.iter().any(|value| !value.is_finite()) {
428            return Err(ss_error_with_detail(
429                &SS_ERROR_UNSUPPORTED_INPUT,
430                format!("{label} must contain only finite real values"),
431            ));
432        }
433
434        Ok(Self {
435            rows: tensor.rows,
436            cols: tensor.cols,
437            tensor,
438        })
439    }
440
441    fn into_value(self) -> Value {
442        Value::Tensor(self.tensor)
443    }
444}
445
446fn validate_state_space_dimensions(
447    a: &RealMatrix,
448    b: &RealMatrix,
449    c: &RealMatrix,
450    d: &RealMatrix,
451) -> BuiltinResult<()> {
452    if a.rows != a.cols {
453        return Err(ss_error_with_detail(
454            &SS_ERROR_INVALID_DIMENSIONS,
455            format!("A must be square, got {}x{}", a.rows, a.cols),
456        ));
457    }
458
459    let state_count = a.rows;
460    if b.rows != state_count {
461        return Err(ss_error_with_detail(
462            &SS_ERROR_INVALID_DIMENSIONS,
463            format!(
464                "B must have {} rows to match A, got {}x{}",
465                state_count, b.rows, b.cols
466            ),
467        ));
468    }
469    if c.cols != state_count {
470        return Err(ss_error_with_detail(
471            &SS_ERROR_INVALID_DIMENSIONS,
472            format!(
473                "C must have {} columns to match A, got {}x{}",
474                state_count, c.rows, c.cols
475            ),
476        ));
477    }
478    if d.rows != c.rows || d.cols != b.cols {
479        return Err(ss_error_with_detail(
480            &SS_ERROR_INVALID_DIMENSIONS,
481            format!(
482                "D must have shape {}x{} to match C outputs and B inputs, got {}x{}",
483                c.rows, b.cols, d.rows, d.cols
484            ),
485        ));
486    }
487
488    Ok(())
489}
490
491fn zero_tensor_value(shape: Vec<usize>) -> BuiltinResult<Value> {
492    let len = shape.iter().product();
493    Tensor::new(vec![0.0; len], shape)
494        .map(Value::Tensor)
495        .map_err(|err| {
496            ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
497        })
498}
499
500fn empty_name_cell_value(rows: usize, cols: usize) -> BuiltinResult<Value> {
501    let len = rows * cols;
502    let values = (0..len)
503        .map(|_| Value::CharArray(CharArray::new_row("")))
504        .collect();
505    CellArray::new(values, rows, cols)
506        .map(Value::Cell)
507        .map_err(|err| {
508            ss_error_with_detail(
509                &SS_ERROR_INTERNAL,
510                format!("failed to build cell array: {err}"),
511            )
512        })
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    // Drives the async builtin to completion synchronously via `block_on`.
    fn run_ss(a: Value, b: Value, c: Value, d: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(ss_builtin(a, b, c, d, rest))
    }

    // Returns the named property of an object value; panics if `value` is not an
    // object or the property is missing.
    fn property<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Object(object) = value else {
            panic!("expected object, got {value:?}");
        };
        object
            .properties
            .get(name)
            .unwrap_or_else(|| panic!("missing property {name}"))
    }

    // Asserts that `value` is a tensor with exactly the given shape and data.
    fn assert_tensor(value: &Value, shape: &[usize], data: &[f64]) {
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, shape);
                assert_eq!(tensor.data, data);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Every advertised call form must be present in the public descriptor.
    #[test]
    fn ss_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = SS_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"sys = ss(A, B, C, D)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, \"Ts\", Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, name, value, ...)"));
    }

    // Two-state, single-input, single-output model: matrices are stored as given,
    // Ts defaults to 0, and the delay vectors are 1x1 zeros.
    #[test]
    fn ss_constructs_continuous_state_space_object() {
        let sys = run_ss(
            Value::Tensor(Tensor::new(vec![0.0, -2.0, 1.0, -3.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect("ss");

        let Value::Object(object) = &sys else {
            panic!("expected object");
        };
        assert_eq!(object.class_name, "ss");
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
        assert_tensor(property(&sys, "A"), &[2, 2], &[0.0, -2.0, 1.0, -3.0]);
        assert_tensor(property(&sys, "B"), &[2, 1], &[0.0, 1.0]);
        assert_tensor(property(&sys, "C"), &[1, 2], &[1.0, 0.0]);
        assert_tensor(property(&sys, "D"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "InputDelay"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[1, 1], &[0.0]);
    }

    // One state, two inputs, two outputs: B (1x2) and C (2x1) keep their orientation,
    // and each delay vector has one entry per input/output.
    #[test]
    fn ss_preserves_matrix_orientation_for_mimo_systems() {
        let sys = run_ss(
            Value::Num(-1.0),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.1, 0.2, 0.3], vec![2, 2]).unwrap()),
            Vec::new(),
        )
        .expect("ss");

        assert_tensor(property(&sys, "A"), &[1, 1], &[-1.0]);
        assert_tensor(property(&sys, "B"), &[1, 2], &[1.0, 2.0]);
        assert_tensor(property(&sys, "C"), &[2, 1], &[3.0, 4.0]);
        assert_tensor(property(&sys, "D"), &[2, 2], &[0.0, 0.1, 0.2, 0.3]);
        assert_tensor(property(&sys, "InputDelay"), &[2, 1], &[0.0, 0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[2, 1], &[0.0, 0.0]);
    }

    // Integer scalars are accepted as matrices, and a positional Ts sets the sample time.
    #[test]
    fn ss_accepts_discrete_sample_time() {
        let sys = run_ss(
            Value::Int(IntValue::I32(1)),
            Value::Int(IntValue::I32(2)),
            Value::Int(IntValue::I32(3)),
            Value::Int(IntValue::I32(4)),
            vec![Value::Num(0.25)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.25));
    }

    // 'SampleTime' is accepted as a name-value alias for the sample time.
    #[test]
    fn ss_accepts_sample_time_name_value_options() {
        let sys = run_ss(
            Value::Num(1.0),
            Value::Num(2.0),
            Value::Num(3.0),
            Value::Num(4.0),
            vec![Value::from("SampleTime"), Value::Num(0.5)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.5));
    }

    // A 1x2 state matrix is rejected before any other dimension check.
    #[test]
    fn ss_rejects_nonsquare_a_matrix() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("nonsquare A should fail");
        assert!(err.message().contains("A must be square"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    // B has one row while A is 2x2.
    #[test]
    fn ss_rejects_b_row_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("B mismatch should fail");
        assert!(err.message().contains("B must have 2 rows"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    // One output (C is 1x2) and one input (B is 2x1) require D to be 1x1, not 1x2.
    #[test]
    fn ss_rejects_d_shape_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("D mismatch should fail");
        assert!(err.message().contains("D must have shape 1x1"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    // Negative sample times are rejected with the sample-time identifier.
    #[test]
    fn ss_rejects_invalid_sample_time() {
        let err = run_ss(
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            vec![Value::Num(-0.1)],
        )
        .expect_err("negative Ts should fail");
        assert_eq!(err.identifier(), SS_ERROR_INVALID_SAMPLE_TIME.identifier);
    }

    // Complex matrix data is reported as unsupported input.
    #[test]
    fn ss_rejects_complex_inputs() {
        let err = run_ss(
            Value::Complex(1.0, 1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect_err("complex A should fail");
        assert!(err.message().contains("complex input is unsupported"));
        assert_eq!(err.identifier(), SS_ERROR_UNSUPPORTED_INPUT.identifier);
    }

    // A gpuArray A is gathered, so the stored property is a host tensor.
    #[test]
    fn ss_gpu_matrix_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let sys = run_ss(
                Value::GpuTensor(handle),
                Value::Num(2.0),
                Value::Num(3.0),
                Value::Num(4.0),
                Vec::new(),
            )
            .expect("ss");

            assert_tensor(property(&sys, "A"), &[1, 1], &[1.0]);
        });
    }
}