1use std::collections::HashMap;
4use std::sync::OnceLock;
5
6use num_complex::Complex64;
7use runmat_builtins::{
8 Access, CharArray, ClassDef, ComplexTensor, MethodDef, ObjectInstance, PropertyDef, Tensor,
9 Value,
10};
11use runmat_macros::runtime_builtin;
12
13use crate::builtins::common::spec::{
14 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
15 ReductionNaN, ResidencyPolicy, ShapeRequirements,
16};
17use crate::builtins::common::tensor;
18use crate::builtins::control::type_resolvers::tf_type;
19use crate::{build_runtime_error, dispatcher, BuiltinResult, RuntimeError};
20
/// Builtin name attached to every error built by `tf_error`.
const BUILTIN_NAME: &str = "tf";
/// Class name of the objects produced by `tf` and registered with the class registry.
const TF_CLASS: &str = "tf";
/// Variable reported for continuous-time systems unless the caller overrides it.
const DEFAULT_VARIABLE: &str = "s";
/// Absolute tolerance `Coefficients` uses when deciding whether a coefficient
/// component is negligible.
const EPS: f64 = 1.0e-12;

/// Initialised once the `tf` class has been registered, so registration runs at
/// most once per process.
static TF_CLASS_REGISTERED: OnceLock<()> = OnceLock::new();
27
/// GPU dispatch metadata for `tf`.
///
/// No precisions or provider hooks are declared: per the notes below, object
/// construction happens on the host and gpuArray coefficient inputs are gathered
/// before being stored.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::tf")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "tf",
    op_kind: GpuOpKind::Custom("transfer-function-constructor"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Device-resident inputs are not kept on the device for this builtin.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Object construction runs on the host. gpuArray coefficient inputs are gathered before storing the transfer-function metadata.",
};
43
/// Fusion-planner metadata for `tf`.
///
/// Construction yields an object rather than numeric data, so no elementwise or
/// reduction template is declared and, per the notes, it ends any fusion chain.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::tf")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "tf",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Transfer-function construction is metadata-only and terminates numeric fusion chains.",
};
54
55fn tf_error(message: impl Into<String>) -> RuntimeError {
56 build_runtime_error(message)
57 .with_builtin(BUILTIN_NAME)
58 .build()
59}
60
61fn ensure_tf_class_registered() {
62 TF_CLASS_REGISTERED.get_or_init(|| {
63 let mut properties = HashMap::new();
64 for name in [
65 "Numerator",
66 "Denominator",
67 "Variable",
68 "Ts",
69 "InputDelay",
70 "OutputDelay",
71 ] {
72 properties.insert(
73 name.to_string(),
74 PropertyDef {
75 name: name.to_string(),
76 is_static: false,
77 is_dependent: false,
78 get_access: Access::Public,
79 set_access: Access::Public,
80 default_value: None,
81 },
82 );
83 }
84
85 let methods: HashMap<String, MethodDef> = HashMap::new();
86 runmat_builtins::register_class(ClassDef {
87 name: TF_CLASS.to_string(),
88 parent: None,
89 properties,
90 methods,
91 });
92 });
93}
94
95#[runtime_builtin(
96 name = "tf",
97 category = "control",
98 summary = "Create a SISO transfer-function object from numerator and denominator coefficient vectors.",
99 keywords = "tf,transfer function,control system,filter,polynomial",
100 type_resolver(tf_type),
101 builtin_path = "crate::builtins::control::tf"
102)]
103async fn tf_builtin(
104 numerator: Value,
105 denominator: Value,
106 rest: Vec<Value>,
107) -> BuiltinResult<Value> {
108 let options = TfOptions::parse(&rest)?;
109 let numerator = Coefficients::parse("numerator", numerator).await?;
110 let denominator = Coefficients::parse("denominator", denominator).await?;
111
112 if denominator.coeffs.is_empty() {
113 return Err(tf_error("tf: denominator coefficients cannot be empty"));
114 }
115 if denominator.is_all_zero() {
116 return Err(tf_error(
117 "tf: denominator coefficients must not all be zero",
118 ));
119 }
120
121 ensure_tf_class_registered();
122 let mut object = ObjectInstance::new(TF_CLASS.to_string());
123 object
124 .properties
125 .insert("Numerator".to_string(), numerator.into_row_value()?);
126 object
127 .properties
128 .insert("Denominator".to_string(), denominator.into_row_value()?);
129 object.properties.insert(
130 "Variable".to_string(),
131 Value::CharArray(CharArray::new_row(&options.variable)),
132 );
133 object
134 .properties
135 .insert("Ts".to_string(), Value::Num(options.sample_time));
136 object
137 .properties
138 .insert("InputDelay".to_string(), Value::Num(0.0));
139 object
140 .properties
141 .insert("OutputDelay".to_string(), Value::Num(0.0));
142 Ok(Value::Object(object))
143}
144
145#[derive(Clone)]
146struct TfOptions {
147 variable: String,
148 sample_time: f64,
149 variable_explicit: bool,
150}
151
152impl TfOptions {
153 fn parse(rest: &[Value]) -> BuiltinResult<Self> {
154 let mut options = Self {
155 variable: DEFAULT_VARIABLE.to_string(),
156 sample_time: 0.0,
157 variable_explicit: false,
158 };
159
160 match rest {
161 [] => {}
162 [sample_time] => {
163 options.sample_time = parse_sample_time(sample_time)?;
164 if options.sample_time > 0.0 {
165 options.variable = "z".to_string();
166 }
167 }
168 _ => {
169 if !rest.len().is_multiple_of(2) {
170 return Err(tf_error(
171 "tf: optional arguments must be name-value pairs or a scalar sample time",
172 ));
173 }
174 let mut idx = 0;
175 while idx < rest.len() {
176 let name = scalar_text(&rest[idx], "option name")?;
177 let lowered = name.trim().to_ascii_lowercase();
178 let value = &rest[idx + 1];
179 match lowered.as_str() {
180 "variable" => {
181 options.variable = parse_variable(value)?;
182 options.variable_explicit = true;
183 }
184 "ts" | "sampletime" => options.sample_time = parse_sample_time(value)?,
185 _ => {
186 return Err(tf_error(format!("tf: unsupported option '{name}'")));
187 }
188 }
189 idx += 2;
190 }
191 if options.sample_time > 0.0 && !options.variable_explicit {
192 options.variable = "z".to_string();
193 }
194 }
195 }
196
197 Ok(options)
198 }
199}
200
201fn parse_sample_time(value: &Value) -> BuiltinResult<f64> {
202 let sample_time = match value {
203 Value::Num(n) => *n,
204 Value::Int(i) => i.to_f64(),
205 other => {
206 return Err(tf_error(format!(
207 "tf: sample time must be a non-negative scalar, got {other:?}"
208 )))
209 }
210 };
211 if !sample_time.is_finite() || sample_time < 0.0 {
212 return Err(tf_error(
213 "tf: sample time must be a finite non-negative scalar",
214 ));
215 }
216 Ok(sample_time)
217}
218
219fn parse_variable(value: &Value) -> BuiltinResult<String> {
220 let variable = scalar_text(value, "Variable")?;
221 let variable = variable.trim();
222 match variable {
223 "s" | "p" | "z" | "q" | "z^-1" | "q^-1" => Ok(variable.to_string()),
224 _ => Err(tf_error(
225 "tf: Variable must be one of 's', 'p', 'z', 'q', 'z^-1', or 'q^-1'",
226 )),
227 }
228}
229
230fn scalar_text(value: &Value, context: &str) -> BuiltinResult<String> {
231 match value {
232 Value::String(text) => Ok(text.clone()),
233 Value::StringArray(array) if array.data.len() == 1 => Ok(array.data[0].clone()),
234 Value::CharArray(array) if array.rows == 1 => Ok(array.data.iter().collect()),
235 other => Err(tf_error(format!(
236 "tf: {context} must be a string scalar or character vector, got {other:?}"
237 ))),
238 }
239}
240
241#[derive(Clone)]
242struct Coefficients {
243 coeffs: Vec<Complex64>,
244}
245
246impl Coefficients {
247 async fn parse(label: &str, value: Value) -> BuiltinResult<Self> {
248 let gathered = dispatcher::gather_if_needed_async(&value).await?;
249 let coeffs = match gathered {
250 Value::Tensor(tensor) => {
251 ensure_vector_shape(label, &tensor.shape)?;
252 tensor
253 .data
254 .into_iter()
255 .map(|re| Complex64::new(re, 0.0))
256 .collect()
257 }
258 Value::ComplexTensor(tensor) => {
259 ensure_vector_shape(label, &tensor.shape)?;
260 tensor
261 .data
262 .into_iter()
263 .map(|(re, im)| Complex64::new(re, im))
264 .collect()
265 }
266 Value::LogicalArray(logical) => {
267 let tensor = tensor::logical_to_tensor(&logical).map_err(tf_error)?;
268 ensure_vector_shape(label, &tensor.shape)?;
269 tensor
270 .data
271 .into_iter()
272 .map(|re| Complex64::new(re, 0.0))
273 .collect()
274 }
275 Value::Num(n) => vec![Complex64::new(n, 0.0)],
276 Value::Int(i) => vec![Complex64::new(i.to_f64(), 0.0)],
277 Value::Bool(b) => vec![Complex64::new(if b { 1.0 } else { 0.0 }, 0.0)],
278 Value::Complex(re, im) => vec![Complex64::new(re, im)],
279 other => {
280 return Err(tf_error(format!(
281 "tf: {label} must be a numeric coefficient vector, got {other:?}"
282 )));
283 }
284 };
285
286 if coeffs.is_empty() {
287 return Err(tf_error(format!(
288 "tf: {label} coefficients cannot be empty"
289 )));
290 }
291 for coeff in &coeffs {
292 if !coeff.re.is_finite() || !coeff.im.is_finite() {
293 return Err(tf_error(format!("tf: {label} coefficients must be finite")));
294 }
295 }
296
297 Ok(Self { coeffs })
298 }
299
300 fn is_all_zero(&self) -> bool {
301 self.coeffs.iter().all(|coeff| coeff.norm() <= EPS)
302 }
303
304 fn into_row_value(self) -> BuiltinResult<Value> {
305 let len = self.coeffs.len();
306 if self.coeffs.iter().all(|coeff| coeff.im.abs() <= EPS) {
307 let data = self.coeffs.into_iter().map(|coeff| coeff.re).collect();
308 let tensor =
309 Tensor::new(data, vec![1, len]).map_err(|err| tf_error(format!("tf: {err}")))?;
310 Ok(Value::Tensor(tensor))
311 } else {
312 let data = self
313 .coeffs
314 .into_iter()
315 .map(|coeff| (coeff.re, coeff.im))
316 .collect();
317 let tensor = ComplexTensor::new(data, vec![1, len])
318 .map_err(|err| tf_error(format!("tf: {err}")))?;
319 Ok(Value::ComplexTensor(tensor))
320 }
321 }
322}
323
324fn ensure_vector_shape(label: &str, shape: &[usize]) -> BuiltinResult<()> {
325 let non_unit = shape.iter().copied().filter(|&dim| dim > 1).count();
326 if non_unit <= 1 {
327 Ok(())
328 } else {
329 Err(tf_error(format!(
330 "tf: {label} coefficients must be a vector"
331 )))
332 }
333}
334
#[cfg(test)]
mod tests {
    //! Unit tests for the `tf` builtin: object construction, coefficient
    //! normalization, option parsing, and input validation.

    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::IntValue;

    /// Drives the async builtin to completion on the current thread.
    fn run_tf(numerator: Value, denominator: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        block_on(tf_builtin(numerator, denominator, rest))
    }

    /// Returns a named property of a `Value::Object`. Panics if `value` is not an
    /// object or the property is missing.
    fn property<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Object(object) = value else {
            panic!("expected object, got {value:?}");
        };
        object
            .properties
            .get(name)
            .unwrap_or_else(|| panic!("missing property {name}"))
    }

    /// With no options the system is continuous (`'s'`, `Ts = 0`) and the scalar
    /// numerator becomes a 1x1 row.
    #[test]
    fn tf_constructs_continuous_siso_object() {
        let sys = run_tf(
            Value::Num(20.0),
            Value::Tensor(Tensor::new(vec![1.0, 5.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect("tf");

        let Value::Object(object) = &sys else {
            panic!("expected object");
        };
        assert_eq!(object.class_name, "tf");
        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("s"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
        match property(&sys, "Numerator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 1]);
                assert_eq!(tensor.data, vec![20.0]);
            }
            other => panic!("expected numerator tensor, got {other:?}"),
        }
        match property(&sys, "Denominator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 2]);
                assert_eq!(tensor.data, vec![1.0, 5.0]);
            }
            other => panic!("expected denominator tensor, got {other:?}"),
        }
    }

    /// Column-vector coefficients are stored as row vectors.
    #[test]
    fn tf_normalizes_column_coefficients_to_rows() {
        let sys = run_tf(
            Value::Tensor(Tensor::new(vec![1.0, 2.0], vec![2, 1]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 3.0, 2.0], vec![3, 1]).unwrap()),
            Vec::new(),
        )
        .expect("tf");

        match property(&sys, "Numerator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 2]);
                assert_eq!(tensor.data, vec![1.0, 2.0]);
            }
            other => panic!("expected numerator tensor, got {other:?}"),
        }
        match property(&sys, "Denominator") {
            Value::Tensor(tensor) => {
                assert_eq!(tensor.shape, vec![1, 3]);
                assert_eq!(tensor.data, vec![1.0, 3.0, 2.0]);
            }
            other => panic!("expected denominator tensor, got {other:?}"),
        }
    }

    /// A positive positional sample time switches the default variable to `'z'`.
    #[test]
    fn tf_accepts_discrete_sample_time() {
        let sys = run_tf(
            Value::Int(IntValue::I32(1)),
            Value::Tensor(Tensor::new(vec![1.0, -0.5], vec![1, 2]).unwrap()),
            vec![Value::Num(0.1)],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("z"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.1));
    }

    /// A positional sample time of zero keeps the continuous default `'s'`.
    #[test]
    fn tf_positional_zero_sample_time_remains_continuous() {
        let sys = run_tf(
            Value::Int(IntValue::I32(1)),
            Value::Tensor(Tensor::new(vec![1.0, 5.0], vec![1, 2]).unwrap()),
            vec![Value::Num(0.0)],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("s"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.0));
    }

    /// The `'Variable'` name-value option sets the stored variable.
    #[test]
    fn tf_accepts_variable_name_value_option() {
        let sys = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![Value::from("Variable"), Value::from("p")],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("p"))
        );
    }

    /// An explicit variable is kept even when a positive `'Ts'` would otherwise
    /// select `'z'`.
    #[test]
    fn tf_explicit_continuous_variable_survives_positive_sample_time() {
        let sys = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![1.0, 1.0], vec![1, 2]).unwrap()),
            vec![
                Value::from("Variable"),
                Value::from("s"),
                Value::from("Ts"),
                Value::Num(0.5),
            ],
        )
        .expect("tf");

        assert_eq!(
            property(&sys, "Variable"),
            &Value::CharArray(CharArray::new_row("s"))
        );
        assert_eq!(property(&sys, "Ts"), &Value::Num(0.5));
    }

    /// An all-zero denominator is rejected.
    #[test]
    fn tf_rejects_zero_denominator() {
        let err = run_tf(
            Value::Num(1.0),
            Value::Tensor(Tensor::new(vec![0.0, 0.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("zero denominator should fail");
        assert!(err.message().contains("must not all be zero"));
    }

    /// A 2x2 numerator is not a vector and is rejected.
    #[test]
    fn tf_rejects_matrix_coefficients() {
        let err = run_tf(
            Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap()),
            Value::Tensor(Tensor::new(vec![1.0, 5.0], vec![1, 2]).unwrap()),
            Vec::new(),
        )
        .expect_err("matrix numerator should fail");
        assert!(err
            .message()
            .contains("numerator coefficients must be a vector"));
    }
}