Skip to main content

runmat_runtime/builtins/control/
tf.rs

1//! MATLAB-compatible `tf` transfer-function constructor for RunMat.
2
3use std::collections::HashMap;
4use std::sync::OnceLock;
5
6use num_complex::Complex64;
7use runmat_builtins::{
8    Access, CharArray, ClassDef, ComplexTensor, MethodDef, ObjectInstance, PropertyDef, Tensor,
9    Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15    ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::control::type_resolvers::tf_type;
19use crate::{build_runtime_error, dispatcher, BuiltinResult, RuntimeError};
20
/// Builtin name attached to every runtime error raised by this module (see `tf_error`).
const BUILTIN_NAME: &str = "tf";
/// Class name of the objects produced by `tf` and registered with the class registry.
const TF_CLASS: &str = "tf";
/// Variable used when no `Variable` option is given and the sample time is zero
/// (continuous time); a positive sample time switches the default to `z`.
const DEFAULT_VARIABLE: &str = "s";
/// Absolute tolerance used by coefficient handling, e.g. to treat tiny imaginary parts as
/// numerical noise when deciding whether to store a real or complex coefficient row.
const EPS: f64 = 1.0e-12;

/// Guards one-time registration of the `tf` class definition (see
/// `ensure_tf_class_registered`).
static TF_CLASS_REGISTERED: OnceLock<()> = OnceLock::new();
27
/// GPU dispatch metadata for `tf`.
///
/// Construction happens entirely on the host: no precisions or provider hooks are
/// advertised, and the residency policy is `GatherImmediately`, so gpuArray coefficient
/// inputs are gathered before the transfer-function metadata is stored.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::tf")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "tf",
    op_kind: GpuOpKind::Custom("transfer-function-constructor"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Object construction runs on the host. gpuArray coefficient inputs are gathered before storing the transfer-function metadata.",
};
43
/// Fusion metadata for `tf`.
///
/// `tf` builds an object rather than computing numeric data. It therefore declares no
/// elementwise or reduction kernel and ends any numeric fusion chain it appears in.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::tf")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "tf",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Transfer-function construction is metadata-only and terminates numeric fusion chains.",
};
54
55fn tf_error(message: impl Into<String>) -> RuntimeError {
56    build_runtime_error(message)
57        .with_builtin(BUILTIN_NAME)
58        .build()
59}
60
61fn ensure_tf_class_registered() {
62    TF_CLASS_REGISTERED.get_or_init(|| {
63        let mut properties = HashMap::new();
64        for name in [
65            "Numerator",
66            "Denominator",
67            "Variable",
68            "Ts",
69            "InputDelay",
70            "OutputDelay",
71        ] {
72            properties.insert(
73                name.to_string(),
74                PropertyDef {
75                    name: name.to_string(),
76                    is_static: false,
77                    is_dependent: false,
78                    get_access: Access::Public,
79                    set_access: Access::Public,
80                    default_value: None,
81                },
82            );
83        }
84
85        let methods: HashMap<String, MethodDef> = HashMap::new();
86        runmat_builtins::register_class(ClassDef {
87            name: TF_CLASS.to_string(),
88            parent: None,
89            properties,
90            methods,
91        });
92    });
93}
94
95#[runtime_builtin(
96    name = "tf",
97    category = "control",
98    summary = "Create a SISO transfer-function object from numerator and denominator coefficient vectors.",
99    keywords = "tf,transfer function,control system,filter,polynomial",
100    type_resolver(tf_type),
101    builtin_path = "crate::builtins::control::tf"
102)]
103async fn tf_builtin(
104    numerator: Value,
105    denominator: Value,
106    rest: Vec<Value>,
107) -> BuiltinResult<Value> {
108    let options = TfOptions::parse(&rest)?;
109    let numerator = Coefficients::parse("numerator", numerator).await?;
110    let denominator = Coefficients::parse("denominator", denominator).await?;
111
112    if denominator.coeffs.is_empty() {
113        return Err(tf_error("tf: denominator coefficients cannot be empty"));
114    }
115    if denominator.is_all_zero() {
116        return Err(tf_error(
117            "tf: denominator coefficients must not all be zero",
118        ));
119    }
120
121    ensure_tf_class_registered();
122    let mut object = ObjectInstance::new(TF_CLASS.to_string());
123    object
124        .properties
125        .insert("Numerator".to_string(), numerator.into_row_value()?);
126    object
127        .properties
128        .insert("Denominator".to_string(), denominator.into_row_value()?);
129    object.properties.insert(
130        "Variable".to_string(),
131        Value::CharArray(CharArray::new_row(&options.variable)),
132    );
133    object
134        .properties
135        .insert("Ts".to_string(), Value::Num(options.sample_time));
136    object
137        .properties
138        .insert("InputDelay".to_string(), Value::Num(0.0));
139    object
140        .properties
141        .insert("OutputDelay".to_string(), Value::Num(0.0));
142    Ok(Value::Object(object))
143}
144
145#[derive(Clone)]
146struct TfOptions {
147    variable: String,
148    sample_time: f64,
149    variable_explicit: bool,
150}
151
152impl TfOptions {
153    fn parse(rest: &[Value]) -> BuiltinResult<Self> {
154        let mut options = Self {
155            variable: DEFAULT_VARIABLE.to_string(),
156            sample_time: 0.0,
157            variable_explicit: false,
158        };
159
160        match rest {
161            [] => {}
162            [sample_time] => {
163                options.sample_time = parse_sample_time(sample_time)?;
164                if options.sample_time > 0.0 {
165                    options.variable = "z".to_string();
166                }
167            }
168            _ => {
169                if !rest.len().is_multiple_of(2) {
170                    return Err(tf_error(
171                        "tf: optional arguments must be name-value pairs or a scalar sample time",
172                    ));
173                }
174                let mut idx = 0;
175                while idx < rest.len() {
176                    let name = scalar_text(&rest[idx], "option name")?;
177                    let lowered = name.trim().to_ascii_lowercase();
178                    let value = &rest[idx + 1];
179                    match lowered.as_str() {
180                        "variable" => {
181                            options.variable = parse_variable(value)?;
182                            options.variable_explicit = true;
183                        }
184                        "ts" | "sampletime" => options.sample_time = parse_sample_time(value)?,
185                        _ => {
186                            return Err(tf_error(format!("tf: unsupported option '{name}'")));
187                        }
188                    }
189                    idx += 2;
190                }
191                if options.sample_time > 0.0 && !options.variable_explicit {
192                    options.variable = "z".to_string();
193                }
194            }
195        }
196
197        Ok(options)
198    }
199}
200
201fn parse_sample_time(value: &Value) -> BuiltinResult<f64> {
202    let sample_time = match value {
203        Value::Num(n) => *n,
204        Value::Int(i) => i.to_f64(),
205        other => {
206            return Err(tf_error(format!(
207                "tf: sample time must be a non-negative scalar, got {other:?}"
208            )))
209        }
210    };
211    if !sample_time.is_finite() || sample_time < 0.0 {
212        return Err(tf_error(
213            "tf: sample time must be a finite non-negative scalar",
214        ));
215    }
216    Ok(sample_time)
217}
218
219fn parse_variable(value: &Value) -> BuiltinResult<String> {
220    let variable = scalar_text(value, "Variable")?;
221    let variable = variable.trim();
222    match variable {
223        "s" | "p" | "z" | "q" | "z^-1" | "q^-1" => Ok(variable.to_string()),
224        _ => Err(tf_error(
225            "tf: Variable must be one of 's', 'p', 'z', 'q', 'z^-1', or 'q^-1'",
226        )),
227    }
228}
229
230fn scalar_text(value: &Value, context: &str) -> BuiltinResult<String> {
231    match value {
232        Value::String(text) => Ok(text.clone()),
233        Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
234        Value::CharArray(array) if array.rows == 1 => Ok(array.data.iter().collect()),
235        other => Err(tf_error(format!(
236            "tf: {context} must be a string scalar or character vector, got {other:?}"
237        ))),
238    }
239}
240
241#[derive(Clone)]
242struct Coefficients {
243    coeffs: Vec<Complex64>,
244}
245
246impl Coefficients {
247    async fn parse(label: &str, value: Value) -> BuiltinResult<Self> {
248        let gathered = dispatcher::gather_if_needed_async(&value).await?;
249        let coeffs = match gathered {
250            Value::Tensor(tensor) => {
251                ensure_vector_shape(label, &tensor.shape)?;
252                tensor
253                    .data
254                    .into_iter()
255                    .map(|re| Complex64::new(re, 0.0))
256                    .collect()
257            }
258            Value::ComplexTensor(tensor) => {
259                ensure_vector_shape(label, &tensor.shape)?;
260                tensor
261                    .data
262                    .into_iter()
263                    .map(|(re, im)| Complex64::new(re, im))
264                    .collect()
265            }
266            Value::LogicalArray(logical) => {
267                let tensor = tensor::logical_to_tensor(&logical).map_err(tf_error)?;
268                ensure_vector_shape(label, &tensor.shape)?;
269                tensor
270                    .data
271                    .into_iter()
272                    .map(|re| Complex64::new(re, 0.0))
273                    .collect()
274            }
275            Value::Num(n) => vec![Complex64::new(n, 0.0)],
276            Value::Int(i) => vec![Complex64::new(i.to_f64(), 0.0)],
277            Value::Bool(b) => vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
278            Value::Complex(re, im) => vec![Complex64::new(re, im)],
279            other => {
280                return Err(tf_error(format!(
281                    "tf: {label} must be a numeric coefficient vector, got {other:?}"
282                )));
283            }
284        };
285
286        if coeffs.is_empty() {
287            return Err(tf_error(format!(
288                "tf: {label} coefficients cannot be empty"
289            )));
290        }
291        for coeff in &coeffs {
292            if !coeff.re.is_finite() || !coeff.im.is_finite() {
293                return Err(tf_error(format!("tf: {label} coefficients must be finite")));
294            }
295        }
296
297        Ok(Self { coeffs })
298    }
299
300    fn is_all_zero(&self) -> bool {
301        self.coeffs.iter().all(|coeff| coeff.norm() <= EPS)
302    }
303
304    fn into_row_value(self) -> BuiltinResult<Value> {
305        let len = self.coeffs.len();
306        if self.coeffs.iter().all(|coeff| coeff.im.abs() <= EPS) {
307            let data = self.coeffs.into_iter().map(|coeff| coeff.re).collect();
308            let tensor =
309                Tensor::new(data, vec![1, len]).map_err(|err| tf_error(format!("tf: {err}")))?;
310            Ok(Value::Tensor(tensor))
311        } else {
312            let data = self
313                .coeffs
314                .into_iter()
315                .map(|coeff| (coeff.re, coeff.im))
316                .collect();
317            let tensor = ComplexTensor::new(data, vec![1, len])
318                .map_err(|err| tf_error(format!("tf: {err}")))?;
319            Ok(Value::ComplexTensor(tensor))
320        }
321    }
322}
323
324fn ensure_vector_shape(label: &str, shape: &[usize]) -> BuiltinResult<()> {
325    let non_unit = shape.iter().copied().filter(|&dim| dim > 1).count();
326    if non_unit <= 1 {
327        Ok(())
328    } else {
329        Err(tf_error(format!(
330            "tf: {label} coefficients must be a vector"
331        )))
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    /// Drives the async builtin to completion on the current thread.
    fn run_tf(numerator: Value, denominator: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(tf_builtin(numerator, denominator, rest))
    }

    /// Returns the named property of a `tf` object.
    ///
    /// Panics if `value` is not an object or the property is missing.
    fn property<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Object(object) = value else {
            panic!("expected object, got {value:?}");
        };
        object
            .properties
            .get(name)
            .unwrap_or_else(|| panic!("missing property {name}"))
    }

    #[test]
    fn tf_constructs_continuous_siso_object() {
        let sys = run_tf(
            Value::Num(20.0),
            Value::Tensor(Tensor::new(vec![1.0, 5.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect("tf");

        let Value::Object(object) = &sys else {
            panic!("expected object");
        };
        assert_eq!(object.class_name, "tf");
        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("s"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
        match property(&sys, "Numerator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 1]);
                assert_eq!(tensor.data, vec![20.0]);
            }
            other => panic!("expected numerator tensor, got {other:?}"),
        }
        match property(&sys, "Denominator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 2]);
                assert_eq!(tensor.data, vec![1.0, 5.0]);
            }
            other => panic!("expected denominator tensor, got {other:?}"),
        }
    }

    #[test]
    fn tf_normalizes_column_coefficients_to_rows() {
        let sys = run_tf(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap()),
            Vec::new(),
        )
        .expect("tf");

        match property(&sys, "Numerator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 2]);
                assert_eq!(tensor.data, vec![1.0, 2.0]);
            }
            other => panic!("expected numerator tensor, got {other:?}"),
        }
        match property(&sys, "Denominator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 3]);
                assert_eq!(tensor.data, vec![1.0, 3.0, 2.0]);
            }
            other => panic!("expected denominator tensor, got {other:?}"),
        }
    }

    #[test]
    fn tf_accepts_discrete_sample_time() {
        let sys = run_tf(
            Value::Int(IntValue::I32(1)),
            Value::Tensor(Tensor::new(vec![1.0, -0.5], vec![1, 2]).unwrap()),
            vec![Value::Num(0.1)],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("z"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.1));
    }

    #[test]
    fn tf_positional_zero_sample_time_remains_continuous() {
        let sys = run_tf(
            Value::Int(IntValue::I32(1)),
            Value::Tensor(Tensor::new(vec![1.0, 5.0], vec![1, 2]).unwrap()),
            vec![Value::Num(0.0)],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("s"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
    }

    #[test]
    fn tf_accepts_variable_name_value_option() {
        let sys = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![Value::from("Variable"), Value::from("p")],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("p"))
        );
    }

    #[test]
    fn tf_explicit_continuous_variable_survives_positive_sample_time() {
        let sys = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![
                Value::from("Variable"),
                Value::from("s"),
                Value::from("Ts"),
                Value::Num(0.5),
            ],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("s"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.5));
    }

    #[test]
    fn tf_accepts_ts_name_value_option() {
        let sys = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, -0.5], vec![1, 2]).unwrap()),
            vec![Value::from("Ts"), Value::Num(0.2)],
        )
        .expect("tf");

        assert_eq!(property(&sys, "Ts"), &Value::Num(0.2));
        // A positive sample time without an explicit Variable switches the default to `z`.
        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("z"))
        );
    }

    #[test]
    fn tf_keeps_complex_coefficients() {
        let sys = run_tf(
            Value::ComplexTensor(
                ComplexTensor::new(vec![(1.0, 2.0)], vec![1, 1]).expect("complex tensor"),
            ),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect("tf");

        // A non-negligible imaginary part must keep the row complex.
        match property(&sys, "Numerator") {
            Value::ComplexTensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 1]);
                assert_eq!(tensor.data, vec![(1.0, 2.0)]);
            }
            other => panic!("expected complex numerator, got {other:?}"),
        }
    }

    #[test]
    fn tf_rejects_zero_denominator() {
        let err = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("zero denominator should fail");
        assert!(err.message().contains("must not all be zero"));
    }

    #[test]
    fn tf_rejects_matrix_coefficients() {
        let err = run_tf(
            Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 5.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("matrix numerator should fail");
        assert!(err
            .message()
            .contains("numerator coefficients must be a vector"));
    }

    #[test]
    fn tf_rejects_negative_sample_time() {
        let err = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![Value::Num(-0.1)],
        )
        .expect_err("negative sample time should fail");
        assert!(err.message().contains("finite non-negative scalar"));
    }

    #[test]
    fn tf_rejects_unsupported_variable() {
        let err = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![Value::from("Variable"), Value::from("x")],
        )
        .expect_err("unknown variable should fail");
        assert!(err.message().contains("Variable must be one of"));
    }

    #[test]
    fn tf_rejects_unknown_option() {
        let err = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![Value::from("Foo"), Value::Num(1.0)],
        )
        .expect_err("unknown option should fail");
        assert!(err.message().contains("unsupported option 'Foo'"));
    }
}