Skip to main content

runmat_runtime/builtins/control/
ss.rs

1//! MATLAB-compatible `ss` state-space model constructor for RunMat.
2
3use std::cell::Cell;
4use std::collections::HashMap;
5
6use runmat_builtins::{
7    Access, BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
8    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
9    CellArray, CharArray, ClassDef, MethodDef, ObjectInstance, PropertyDef, Tensor, Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::control::type_resolvers::ss_type;
18use crate::{build_runtime_error, dispatcher, BuiltinResult, RuntimeError};
19
/// Builtin name attached to every runtime error raised by this module.
const BUILTIN_NAME: &str = "ss";
/// Class name used for the registered class definition and for every constructed object.
const SS_CLASS: &str = "ss";

thread_local! {
    /// Set by `ensure_ss_class_registered` once the `ss` class has been registered. The flag
    /// is thread-local, so registration happens at most once per thread.
    static SS_CLASS_REGISTERED: Cell<bool> = const { Cell::new(false) };
}
26
27const SS_OUTPUT_SYS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
28    name: "sys",
29    ty: BuiltinParamType::Any,
30    arity: BuiltinParamArity::Required,
31    default: None,
32    description: "State-space model object.",
33}];
34const SS_PARAM_A: BuiltinParamDescriptor = BuiltinParamDescriptor {
35    name: "A",
36    ty: BuiltinParamType::NumericArray,
37    arity: BuiltinParamArity::Required,
38    default: None,
39    description: "State matrix with shape n-by-n.",
40};
41const SS_PARAM_B: BuiltinParamDescriptor = BuiltinParamDescriptor {
42    name: "B",
43    ty: BuiltinParamType::NumericArray,
44    arity: BuiltinParamArity::Required,
45    default: None,
46    description: "Input matrix with shape n-by-nu.",
47};
48const SS_PARAM_C: BuiltinParamDescriptor = BuiltinParamDescriptor {
49    name: "C",
50    ty: BuiltinParamType::NumericArray,
51    arity: BuiltinParamArity::Required,
52    default: None,
53    description: "Output matrix with shape ny-by-n.",
54};
55const SS_PARAM_D: BuiltinParamDescriptor = BuiltinParamDescriptor {
56    name: "D",
57    ty: BuiltinParamType::NumericArray,
58    arity: BuiltinParamArity::Required,
59    default: None,
60    description: "Feedthrough matrix with shape ny-by-nu.",
61};
62const SS_INPUTS_ABCD: [BuiltinParamDescriptor; 4] =
63    [SS_PARAM_A, SS_PARAM_B, SS_PARAM_C, SS_PARAM_D];
64const SS_INPUTS_ABCD_TS: [BuiltinParamDescriptor; 5] = [
65    SS_PARAM_A,
66    SS_PARAM_B,
67    SS_PARAM_C,
68    SS_PARAM_D,
69    BuiltinParamDescriptor {
70        name: "Ts",
71        ty: BuiltinParamType::NumericScalar,
72        arity: BuiltinParamArity::Optional,
73        default: Some("0.0"),
74        description: "Sample time (0 for continuous-time model).",
75    },
76];
77const SS_INPUTS_ABCD_NAMEVALUE: [BuiltinParamDescriptor; 6] = [
78    SS_PARAM_A,
79    SS_PARAM_B,
80    SS_PARAM_C,
81    SS_PARAM_D,
82    BuiltinParamDescriptor {
83        name: "name",
84        ty: BuiltinParamType::StringScalar,
85        arity: BuiltinParamArity::Variadic,
86        default: None,
87        description: "Option name ('Ts' or 'SampleTime').",
88    },
89    BuiltinParamDescriptor {
90        name: "value",
91        ty: BuiltinParamType::Any,
92        arity: BuiltinParamArity::Variadic,
93        default: None,
94        description: "Option value.",
95    },
96];
97const SS_SIGNATURES: [BuiltinSignatureDescriptor; 4] = [
98    BuiltinSignatureDescriptor {
99        label: "sys = ss(A, B, C, D)",
100        inputs: &SS_INPUTS_ABCD,
101        outputs: &SS_OUTPUT_SYS,
102    },
103    BuiltinSignatureDescriptor {
104        label: "sys = ss(A, B, C, D, Ts)",
105        inputs: &SS_INPUTS_ABCD_TS,
106        outputs: &SS_OUTPUT_SYS,
107    },
108    BuiltinSignatureDescriptor {
109        label: "sys = ss(A, B, C, D, \"Ts\", Ts)",
110        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
111        outputs: &SS_OUTPUT_SYS,
112    },
113    BuiltinSignatureDescriptor {
114        label: "sys = ss(A, B, C, D, name, value, ...)",
115        inputs: &SS_INPUTS_ABCD_NAMEVALUE,
116        outputs: &SS_OUTPUT_SYS,
117    },
118];
/// Raised for malformed trailing arguments, e.g. an option list that cannot be split into
/// name/value pairs, or an option name that is not text.
const SS_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_ARGUMENT",
    identifier: Some("RunMat:ss:InvalidArgument"),
    when: "Arguments do not match supported ss invocation forms.",
    message: "ss: invalid argument",
};
/// Raised for an option name other than `Ts` / `SampleTime`.
const SS_ERROR_INVALID_OPTION: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_OPTION",
    identifier: Some("RunMat:ss:InvalidOption"),
    when: "A name/value option token is unsupported or malformed.",
    message: "ss: invalid option",
};
/// Raised when a sample time is not numeric, is negative, or is not finite.
const SS_ERROR_INVALID_SAMPLE_TIME: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_SAMPLE_TIME",
    identifier: Some("RunMat:ss:InvalidSampleTime"),
    when: "Sample time is not a finite non-negative scalar.",
    message: "ss: sample time must be a finite non-negative scalar",
};
/// Raised when a matrix is not 2-D or the A/B/C/D shapes are mutually inconsistent.
const SS_ERROR_INVALID_DIMENSIONS: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INVALID_DIMENSIONS",
    identifier: Some("RunMat:ss:InvalidDimensions"),
    when: "A, B, C, and D dimensions do not define a consistent state-space model.",
    message: "ss: invalid state-space matrix dimensions",
};
/// Raised for complex, non-numeric, or non-finite matrix data.
const SS_ERROR_UNSUPPORTED_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.UNSUPPORTED_INPUT",
    identifier: Some("RunMat:ss:UnsupportedInput"),
    when: "An input is complex, sparse, logical, or another unsupported model form.",
    message: "ss: unsupported input",
};
/// Raised when building an internal tensor or cell array fails.
const SS_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.SS.INTERNAL",
    identifier: Some("RunMat:ss:Internal"),
    when: "Internal tensor/object construction failed.",
    message: "ss: internal error",
};
/// All error descriptors for `ss`, referenced by `SS_DESCRIPTOR`.
const SS_ERRORS: [BuiltinErrorDescriptor; 6] = [
    SS_ERROR_INVALID_ARGUMENT,
    SS_ERROR_INVALID_OPTION,
    SS_ERROR_INVALID_SAMPLE_TIME,
    SS_ERROR_INVALID_DIMENSIONS,
    SS_ERROR_UNSUPPORTED_INPUT,
    SS_ERROR_INTERNAL,
];
/// Public descriptor for the `ss` builtin. It bundles the call signatures, fixed output mode,
/// public completion policy, and error catalog, and is referenced from the
/// `runtime_builtin` attribute on `ss_builtin`.
pub const SS_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &SS_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &SS_ERRORS,
};
169
// GPU metadata for `ss`. It declares no supported precisions and no provider hooks, and
// requests `GatherImmediately` residency. Matrix inputs are gathered to the host through
// `dispatcher::gather_if_needed_async` before validation.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::ss")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "ss",
    op_kind: GpuOpKind::Custom("state-space-model-constructor"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Object construction runs on the host. gpuArray matrix inputs are gathered before validating and storing the state-space metadata.",
};

// Fusion metadata for `ss`. No elementwise or reduction template is provided; per `notes`,
// the builtin terminates numeric fusion chains.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::ss")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "ss",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "State-space construction is metadata-only and terminates numeric fusion chains.",
};
196
197fn ss_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
198    ss_error_with_message(error.message, error)
199}
200
201fn ss_error_with_detail(
202    error: &'static BuiltinErrorDescriptor,
203    detail: impl AsRef<str>,
204) -> RuntimeError {
205    ss_error_with_message(format!("{}: {}", error.message, detail.as_ref()), error)
206}
207
208fn ss_error_with_message(
209    message: impl Into<String>,
210    error: &'static BuiltinErrorDescriptor,
211) -> RuntimeError {
212    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
213    if let Some(identifier) = error.identifier {
214        builder = builder.with_identifier(identifier);
215    }
216    builder.build()
217}
218
219fn ensure_ss_class_registered() {
220    SS_CLASS_REGISTERED.with(|registered| {
221        if registered.get() {
222            return;
223        }
224        let mut properties = HashMap::new();
225        for name in [
226            "A",
227            "B",
228            "C",
229            "D",
230            "Ts",
231            "InputDelay",
232            "OutputDelay",
233            "StateName",
234            "InputName",
235            "OutputName",
236        ] {
237            properties.insert(
238                name.to_string(),
239                PropertyDef {
240                    name: name.to_string(),
241                    is_static: false,
242                    is_constant: false,
243                    is_dependent: false,
244                    get_access: Access::Public,
245                    set_access: Access::Public,
246                    default_value: None,
247                },
248            );
249        }
250
251        let methods: HashMap<String, MethodDef> = HashMap::new();
252        runmat_builtins::register_class(ClassDef {
253            name: SS_CLASS.to_string(),
254            parent: None,
255            properties,
256            methods,
257        });
258        registered.set(true);
259    });
260}
261
262#[runtime_builtin(
263    name = "ss",
264    category = "control",
265    summary = "Create state-space model objects from A, B, C, and D matrices.",
266    keywords = "ss,state space,control system,model,matrices",
267    type_resolver(ss_type),
268    descriptor(crate::builtins::control::ss::SS_DESCRIPTOR),
269    builtin_path = "crate::builtins::control::ss"
270)]
271async fn ss_builtin(
272    a: Value,
273    b: Value,
274    c: Value,
275    d: Value,
276    rest: Vec<Value>,
277) -> BuiltinResult<Value> {
278    let options = SsOptions::parse(&rest)?;
279    let a = RealMatrix::parse("A", a).await?;
280    let b = RealMatrix::parse("B", b).await?;
281    let c = RealMatrix::parse("C", c).await?;
282    let d = RealMatrix::parse("D", d).await?;
283
284    validate_state_space_dimensions(&a, &b, &c, &d)?;
285
286    let state_count = a.rows;
287    let input_count = b.cols;
288    let output_count = c.rows;
289
290    ensure_ss_class_registered();
291    let mut object = ObjectInstance::new(SS_CLASS.to_string());
292    object.properties.insert("A".to_string(), a.into_value());
293    object.properties.insert("B".to_string(), b.into_value());
294    object.properties.insert("C".to_string(), c.into_value());
295    object.properties.insert("D".to_string(), d.into_value());
296    object
297        .properties
298        .insert("Ts".to_string(), Value::Num(options.sample_time));
299    object.properties.insert(
300        "InputDelay".to_string(),
301        zero_tensor_value(vec![input_count, 1])?,
302    );
303    object.properties.insert(
304        "OutputDelay".to_string(),
305        zero_tensor_value(vec![output_count, 1])?,
306    );
307    object.properties.insert(
308        "StateName".to_string(),
309        empty_name_cell_value(state_count, 1)?,
310    );
311    object.properties.insert(
312        "InputName".to_string(),
313        empty_name_cell_value(input_count, 1)?,
314    );
315    object.properties.insert(
316        "OutputName".to_string(),
317        empty_name_cell_value(output_count, 1)?,
318    );
319    Ok(Value::Object(object))
320}
321
322#[derive(Clone)]
323struct SsOptions {
324    sample_time: f64,
325}
326
327impl SsOptions {
328    fn parse(rest: &[Value]) -> BuiltinResult<Self> {
329        let mut options = Self { sample_time: 0.0 };
330
331        match rest {
332            [] => {}
333            [sample_time] => options.sample_time = parse_sample_time(sample_time)?,
334            _ => {
335                if !rest.len().is_multiple_of(2) {
336                    return Err(ss_error_with_detail(
337                        &SS_ERROR_INVALID_ARGUMENT,
338                        "optional arguments must be name-value pairs or a scalar sample time",
339                    ));
340                }
341                let mut idx = 0;
342                while idx < rest.len() {
343                    let name = scalar_text(&rest[idx], "option name")?;
344                    let lowered = name.trim().to_ascii_lowercase();
345                    let value = &rest[idx + 1];
346                    match lowered.as_str() {
347                        "ts" | "sampletime" => options.sample_time = parse_sample_time(value)?,
348                        _ => {
349                            return Err(ss_error_with_detail(
350                                &SS_ERROR_INVALID_OPTION,
351                                format!("unsupported option '{name}'"),
352                            ));
353                        }
354                    }
355                    idx += 2;
356                }
357            }
358        }
359
360        Ok(options)
361    }
362}
363
364fn parse_sample_time(value: &Value) -> BuiltinResult<f64> {
365    let sample_time = match value {
366        Value::Num(n) => *n,
367        Value::Int(i) => i.to_f64(),
368        other => {
369            return Err(ss_error_with_detail(
370                &SS_ERROR_INVALID_SAMPLE_TIME,
371                format!("expected non-negative scalar, got {other:?}"),
372            ))
373        }
374    };
375    if !sample_time.is_finite() || sample_time < 0.0 {
376        return Err(ss_error(&SS_ERROR_INVALID_SAMPLE_TIME));
377    }
378    Ok(sample_time)
379}
380
381fn scalar_text(value: &Value, context: &str) -> BuiltinResult<String> {
382    match value {
383        Value::String(text) => Ok(text.clone()),
384        Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
385        Value::CharArray(array) if array.rows == 1 => Ok(array.data.iter().collect()),
386        other => Err(ss_error_with_detail(
387            &SS_ERROR_INVALID_ARGUMENT,
388            format!("{context} must be a string scalar or character vector, got {other:?}"),
389        )),
390    }
391}
392
393#[derive(Clone)]
394struct RealMatrix {
395    tensor: Tensor,
396    rows: usize,
397    cols: usize,
398}
399
400impl RealMatrix {
401    async fn parse(label: &str, value: Value) -> BuiltinResult<Self> {
402        let gathered = dispatcher::gather_if_needed_async(&value).await?;
403        let tensor = match gathered {
404            Value::Tensor(tensor) => tensor,
405            Value::Num(n) => Tensor::new(vec![n], vec![1, 1]).map_err(|err| {
406                ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
407            })?,
408            Value::Int(i) => Tensor::new(vec![i.to_f64()], vec![1, 1]).map_err(|err| {
409                ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
410            })?,
411            Value::Complex(_, _) | Value::ComplexTensor(_) => {
412                return Err(ss_error_with_detail(
413                    &SS_ERROR_UNSUPPORTED_INPUT,
414                    format!(
415                        "{label} must be finite real numeric data; complex input is unsupported"
416                    ),
417                ));
418            }
419            other => {
420                return Err(ss_error_with_detail(
421                    &SS_ERROR_UNSUPPORTED_INPUT,
422                    format!("{label} must be a finite real numeric matrix, got {other:?}"),
423                ));
424            }
425        };
426
427        if tensor.shape.len() > 2 {
428            return Err(ss_error_with_detail(
429                &SS_ERROR_INVALID_DIMENSIONS,
430                format!("{label} must be a 2-D matrix, got shape {:?}", tensor.shape),
431            ));
432        }
433        if tensor.data.iter().any(|value| !value.is_finite()) {
434            return Err(ss_error_with_detail(
435                &SS_ERROR_UNSUPPORTED_INPUT,
436                format!("{label} must contain only finite real values"),
437            ));
438        }
439
440        Ok(Self {
441            rows: tensor.rows,
442            cols: tensor.cols,
443            tensor,
444        })
445    }
446
447    fn into_value(self) -> Value {
448        Value::Tensor(self.tensor)
449    }
450}
451
452fn validate_state_space_dimensions(
453    a: &RealMatrix,
454    b: &RealMatrix,
455    c: &RealMatrix,
456    d: &RealMatrix,
457) -> BuiltinResult<()> {
458    if a.rows != a.cols {
459        return Err(ss_error_with_detail(
460            &SS_ERROR_INVALID_DIMENSIONS,
461            format!("A must be square, got {}x{}", a.rows, a.cols),
462        ));
463    }
464
465    let state_count = a.rows;
466    if b.rows != state_count {
467        return Err(ss_error_with_detail(
468            &SS_ERROR_INVALID_DIMENSIONS,
469            format!(
470                "B must have {} rows to match A, got {}x{}",
471                state_count, b.rows, b.cols
472            ),
473        ));
474    }
475    if c.cols != state_count {
476        return Err(ss_error_with_detail(
477            &SS_ERROR_INVALID_DIMENSIONS,
478            format!(
479                "C must have {} columns to match A, got {}x{}",
480                state_count, c.rows, c.cols
481            ),
482        ));
483    }
484    if d.rows != c.rows || d.cols != b.cols {
485        return Err(ss_error_with_detail(
486            &SS_ERROR_INVALID_DIMENSIONS,
487            format!(
488                "D must have shape {}x{} to match C outputs and B inputs, got {}x{}",
489                c.rows, b.cols, d.rows, d.cols
490            ),
491        ));
492    }
493
494    Ok(())
495}
496
497fn zero_tensor_value(shape: Vec<usize>) -> BuiltinResult<Value> {
498    let len = shape.iter().product();
499    Tensor::new(vec![0.0; len], shape)
500        .map(Value::Tensor)
501        .map_err(|err| {
502            ss_error_with_detail(&SS_ERROR_INTERNAL, format!("failed to build tensor: {err}"))
503        })
504}
505
506fn empty_name_cell_value(rows: usize, cols: usize) -> BuiltinResult<Value> {
507    let len = rows * cols;
508    let values = (0..len)
509        .map(|_| Value::CharArray(CharArray::new_row("")))
510        .collect();
511    CellArray::new(values, rows, cols)
512        .map(Value::Cell)
513        .map_err(|err| {
514            ss_error_with_detail(
515                &SS_ERROR_INTERNAL,
516                format!("failed to build cell array: {err}"),
517            )
518        })
519}
520
// Unit tests for the `ss` builtin. They cover:
// - descriptor signature labels;
// - construction of SISO, MIMO, and discrete-time models, including a name/value sample time;
// - rejection of bad dimensions, sample times, and complex inputs;
// - gathering of a GPU matrix input through the test provider.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builtins::common::test_support;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    // Drives the async builtin to completion on the current thread.
    fn run_ss(a: Value, b: Value, c: Value, d: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(ss_builtin(a, b, c, d, rest))
    }

    // Returns the named property of an object value. Panics if `value` is not an object or
    // the property is missing.
    fn property<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Object(object) = value else {
            panic!("expected object, got {value:?}");
        };
        object
            .properties
            .get(name)
            .unwrap_or_else(|| panic!("missing property {name}"))
    }

    // Asserts that `value` is a tensor with exactly the given shape and data.
    fn assert_tensor(value: &Value, shape: &[usize], data: &[f64]) {
        match value {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, shape);
                assert_eq!(tensor.data, data);
            }
            other => panic!("expected tensor, got {other:?}"),
        }
    }

    // Every advertised call form must appear among the descriptor labels.
    #[test]
    fn ss_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = SS_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"sys = ss(A, B, C, D)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, \"Ts\", Ts)"));
        assert!(labels.contains(&"sys = ss(A, B, C, D, name, value, ...)"));
    }

    // Two-state SISO model with no options: Ts defaults to 0 and the delays are 1x1 zeros.
    #[test]
    fn ss_constructs_continuous_state_space_object() {
        let sys = run_ss(
            Value::Tensor(Tensor::new(vec![0.0, -2.0, 1.0, -3.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect("ss");

        let Value::Object(object) = &sys else {
            panic!("expected object");
        };
        assert_eq!(object.class_name, "ss");
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
        assert_tensor(property(&sys, "A"), &[2, 2], &[0.0, -2.0, 1.0, -3.0]);
        assert_tensor(property(&sys, "B"), &[2, 1], &[0.0, 1.0]);
        assert_tensor(property(&sys, "C"), &[1, 2], &[1.0, 0.0]);
        assert_tensor(property(&sys, "D"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "InputDelay"), &[1, 1], &[0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[1, 1], &[0.0]);
    }

    // One state, two inputs, two outputs: the stored matrices keep the input shapes, and the
    // delay vectors are sized by the input/output counts.
    #[test]
    fn ss_preserves_matrix_orientation_for_mimo_systems() {
        let sys = run_ss(
            Value::Num(-1.0),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![3.0, 4.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.1, 0.2, 0.3], vec![2, 2]).unwrap()),
            Vec::new(),
        )
        .expect("ss");

        assert_tensor(property(&sys, "A"), &[1, 1], &[-1.0]);
        assert_tensor(property(&sys, "B"), &[1, 2], &[1.0, 2.0]);
        assert_tensor(property(&sys, "C"), &[2, 1], &[3.0, 4.0]);
        assert_tensor(property(&sys, "D"), &[2, 2], &[0.0, 0.1, 0.2, 0.3]);
        assert_tensor(property(&sys, "InputDelay"), &[2, 1], &[0.0, 0.0]);
        assert_tensor(property(&sys, "OutputDelay"), &[2, 1], &[0.0, 0.0]);
    }

    // Integer scalar matrices are accepted, and a lone positional argument is the sample time.
    #[test]
    fn ss_accepts_discrete_sample_time() {
        let sys = run_ss(
            Value::Int(IntValue::I32(1)),
            Value::Int(IntValue::I32(2)),
            Value::Int(IntValue::I32(3)),
            Value::Int(IntValue::I32(4)),
            vec![Value::Num(0.25)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.25));
    }

    // The "SampleTime" option name sets Ts.
    #[test]
    fn ss_accepts_sample_time_name_value_options() {
        let sys = run_ss(
            Value::Num(1.0),
            Value::Num(2.0),
            Value::Num(3.0),
            Value::Num(4.0),
            vec![Value::from("SampleTime"), Value::Num(0.5)],
        )
        .expect("ss");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.5));
    }

    // A 1x2 state matrix is rejected as non-square.
    #[test]
    fn ss_rejects_nonsquare_a_matrix() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("nonsquare A should fail");
        assert!(err.message().contains("A must be square"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    // A 2x2 A with a 1x1 B: B needs two rows.
    #[test]
    fn ss_rejects_b_row_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0], vec![1, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0], vec![1, 1]).unwrap()),
            Vec::new(),
        )
        .expect_err("B mismatch should fail");
        assert!(err.message().contains("B must have 2 rows"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    // B is 2x1 (one input) and C is 1x2 (one output), so D must be 1x1; a 1x2 D fails.
    #[test]
    fn ss_rejects_d_shape_mismatch() {
        let err = run_ss(
            Value::Tensor(Tensor::new(vec![1.0, 0.0, 0.0, 1.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("D mismatch should fail");
        assert!(err.message().contains("D must have shape 1x1"));
        assert_eq!(err.identifier(), SS_ERROR_INVALID_DIMENSIONS.identifier);
    }

    // A negative positional sample time is rejected.
    #[test]
    fn ss_rejects_invalid_sample_time() {
        let err = run_ss(
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            vec![Value::Num(-0.1)],
        )
        .expect_err("negative Ts should fail");
        assert_eq!(err.identifier(), SS_ERROR_INVALID_SAMPLE_TIME.identifier);
    }

    // A complex scalar state matrix is reported as unsupported input.
    #[test]
    fn ss_rejects_complex_inputs() {
        let err = run_ss(
            Value::Complex(1.0, 1.0),
            Value::Num(1.0),
            Value::Num(1.0),
            Value::Num(0.0),
            Vec::new(),
        )
        .expect_err("complex A should fail");
        assert!(err.message().contains("complex input is unsupported"));
        assert_eq!(err.identifier(), SS_ERROR_UNSUPPORTED_INPUT.identifier);
    }

    // A GPU-resident A is gathered, and the stored property is a host tensor.
    #[test]
    fn ss_gpu_matrix_input_gathers_to_host() {
        test_support::with_test_provider(|provider| {
            let tensor = Tensor::new(vec![1.0], vec![1, 1]).unwrap();
            let view = runmat_accelerate_api::HostTensorView {
                data: &tensor.data,
                shape: &tensor.shape,
            };
            let handle = provider.upload(&view).expect("upload");
            let sys = run_ss(
                Value::GpuTensor(handle),
                Value::Num(2.0),
                Value::Num(3.0),
                Value::Num(4.0),
                Vec::new(),
            )
            .expect("ss");

            assert_tensor(property(&sys, "A"), &[1, 1], &[1.0]);
        });
    }
}