Skip to main content

runmat_runtime/builtins/control/
pole.rs

1//! Pole extraction for transfer-function and state-space control models.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor, Value,
6};
7use runmat_macros::runtime_builtin;
8
9use crate::builtins::common::spec::{
10    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
11    ReductionNaN, ResidencyPolicy, ShapeRequirements,
12};
13use crate::builtins::control::tf_model::{
14    control_error, output_complex_column, ss_poles_from_object, TfModel, SS_CLASS, TF_CLASS,
15};
16use crate::builtins::control::type_resolvers::pole_type;
17use crate::{dispatcher, BuiltinResult};
18
19const POLE_OUTPUT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
20    name: "p",
21    ty: BuiltinParamType::Any,
22    arity: BuiltinParamArity::Required,
23    default: None,
24    description: "Poles of the SISO tf or ss model as a column vector.",
25}];
26const POLE_INPUTS: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
27    name: "sys",
28    ty: BuiltinParamType::Any,
29    arity: BuiltinParamArity::Required,
30    default: None,
31    description: "SISO tf model or ss state-space model.",
32}];
33const POLE_SIGNATURES: [BuiltinSignatureDescriptor; 1] = [BuiltinSignatureDescriptor {
34    label: "p = pole(sys)",
35    inputs: &POLE_INPUTS,
36    outputs: &POLE_OUTPUT,
37}];
38const POLE_ERRORS: [BuiltinErrorDescriptor; 4] = [
39    BuiltinErrorDescriptor {
40        code: "RM.POLE.INVALID_MODEL",
41        identifier: Some("RunMat:pole:InvalidModel"),
42        when: "Input system is not a valid SISO tf or ss object.",
43        message: "pole: invalid model",
44    },
45    BuiltinErrorDescriptor {
46        code: "RM.POLE.UNSUPPORTED_MODEL",
47        identifier: Some("RunMat:pole:UnsupportedModel"),
48        when: "Model form is unsupported.",
49        message: "pole: unsupported model",
50    },
51    BuiltinErrorDescriptor {
52        code: "RM.POLE.INVALID_ARGUMENT",
53        identifier: Some("RunMat:pole:InvalidArgument"),
54        when: "Model metadata or arguments are malformed.",
55        message: "pole: invalid argument",
56    },
57    BuiltinErrorDescriptor {
58        code: "RM.POLE.INTERNAL",
59        identifier: Some("RunMat:pole:Internal"),
60        when: "Root calculation or output construction failed.",
61        message: "pole: internal error",
62    },
63];
64pub const POLE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
65    signatures: &POLE_SIGNATURES,
66    output_mode: BuiltinOutputMode::Fixed,
67    completion_policy: BuiltinCompletionPolicy::Public,
68    errors: &POLE_ERRORS,
69};
70
71#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::control::pole")]
72pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
73    name: "pole",
74    op_kind: GpuOpKind::Custom("control-poles"),
75    supported_precisions: &[],
76    broadcast: BroadcastSemantics::None,
77    provider_hooks: &[],
78    constant_strategy: ConstantStrategy::InlineLiteral,
79    residency: ResidencyPolicy::GatherImmediately,
80    nan_mode: ReductionNaN::Include,
81    two_pass_threshold: None,
82    workgroup_size: None,
83    accepts_nan_mode: false,
84    notes: "pole computes roots or state-matrix eigenvalues from host-side model metadata.",
85};
86
87#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::control::pole")]
88pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
89    name: "pole",
90    shape: ShapeRequirements::Any,
91    constant_strategy: ConstantStrategy::InlineLiteral,
92    elementwise: None,
93    reduction: None,
94    emits_nan: false,
95    notes: "pole is model analysis and is not fused.",
96};
97
98#[runtime_builtin(
99    name = "pole",
100    category = "control",
101    summary = "Return poles of transfer-function and state-space control models.",
102    keywords = "pole,poles,control system,stability,transfer function,state space,tf,ss",
103    type_resolver(pole_type),
104    descriptor(crate::builtins::control::pole::POLE_DESCRIPTOR),
105    builtin_path = "crate::builtins::control::pole"
106)]
107async fn pole_builtin(sys: Value) -> BuiltinResult<Value> {
108    let gathered = dispatcher::gather_if_needed_async(&sys).await?;
109    let poles = match gathered {
110        Value::Object(object) if object.is_class(TF_CLASS) => {
111            TfModel::from_value(Value::Object(object), "pole")?.poles()?
112        }
113        Value::Object(object) if object.is_class(SS_CLASS) => {
114            ss_poles_from_object(&object, "pole")?.0
115        }
116        Value::Object(object) => {
117            return Err(control_error(
118                "pole",
119                "RunMat:pole:UnsupportedModel",
120                format!(
121                    "pole: unsupported model class '{}'; supported classes are tf and ss",
122                    object.class_name
123                ),
124            ));
125        }
126        other => {
127            return Err(control_error(
128                "pole",
129                "RunMat:pole:InvalidModel",
130                format!("pole: expected a tf or ss object, got {other:?}"),
131            ));
132        }
133    };
134    output_complex_column(poles, "pole")
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Tensor;

    /// Builds `tf(1, den)`, with `den` given as a row vector of coefficients in
    /// descending powers of `s` (e.g. `[1, 3, 2]` is `s^2 + 3s + 2`).
    fn unit_numerator_tf(den: &[f64]) -> Value {
        let den = Tensor::new(den.to_vec(), vec![1, den.len()]).expect("denominator tensor");
        block_on(crate::call_builtin_async(
            "tf",
            &[Value::Num(1.0), Value::Tensor(den)],
        ))
        .expect("tf")
    }

    #[test]
    fn pole_returns_roots_of_denominator() {
        // s^2 + 3s + 2 = (s + 1)(s + 2)
        let sys = unit_numerator_tf(&[1.0, 3.0, 2.0]);
        let Value::Tensor(poles) = block_on(pole_builtin(sys)).expect("pole") else {
            panic!("expected real poles");
        };
        assert_eq!(poles.shape, vec![2, 1]);
        assert!(poles.data.iter().any(|p| (*p + 1.0).abs() < 1.0e-8));
        assert!(poles.data.iter().any(|p| (*p + 2.0).abs() < 1.0e-8));
    }

    #[test]
    fn pole_returns_repeated_roots_of_denominator() {
        // s^2 + 2s + 1 = (s + 1)^2
        let sys = unit_numerator_tf(&[1.0, 2.0, 1.0]);
        let Value::Tensor(poles) = block_on(pole_builtin(sys)).expect("pole") else {
            panic!("expected real poles");
        };
        assert_eq!(poles.shape, vec![2, 1]);
        assert!(poles.data.iter().all(|p| (*p + 1.0).abs() < 1.0e-8));
    }

    #[test]
    fn pole_returns_complex_conjugate_roots() {
        // s^2 + 1 has roots ±j
        let sys = unit_numerator_tf(&[1.0, 0.0, 1.0]);
        let Value::ComplexTensor(poles) = block_on(pole_builtin(sys)).expect("pole") else {
            panic!("expected complex poles");
        };
        assert_eq!(poles.shape, vec![2, 1]);
        assert!(poles.data.iter().all(|(re, _)| re.abs() < 1.0e-8));
        assert!(poles.data.iter().any(|(_, im)| (*im - 1.0).abs() < 1.0e-8));
        assert!(poles.data.iter().any(|(_, im)| (*im + 1.0).abs() < 1.0e-8));
    }

    #[test]
    fn pole_uses_state_matrix_eigenvalues_for_ss() {
        // A = [0 1; -4 -0.5] (column-major data), characteristic polynomial
        // s^2 + 0.5s + 4, eigenvalues -0.25 ± j*sqrt(3.9375).
        let sys = block_on(crate::call_builtin_async(
            "ss",
            &[
                Value::Tensor(Tensor::new(vec![0.0, -4.0, 1.0, -0.5], vec![2, 2]).unwrap()),
                Value::Tensor(Tensor::new(vec![0.0, 1.0], vec![2, 1]).unwrap()),
                Value::Tensor(Tensor::new(vec![1.0, 0.0], vec![1, 2]).unwrap()),
                Value::Num(0.0),
            ],
        ))
        .expect("ss");
        let Value::ComplexTensor(poles) = block_on(pole_builtin(sys)).expect("pole") else {
            panic!("expected complex poles");
        };
        assert_eq!(poles.shape, vec![2, 1]);
        assert!(poles.data.iter().all(|(re, _)| (*re + 0.25).abs() < 1.0e-8));
        assert!(poles
            .data
            .iter()
            .any(|(_, im)| (*im - 1.984313483298443).abs() < 1.0e-8));
        assert!(poles
            .data
            .iter()
            .any(|(_, im)| (*im + 1.984313483298443).abs() < 1.0e-8));
    }

    #[test]
    fn pole_rejects_numeric_scalar_input() {
        assert!(block_on(pole_builtin(Value::Num(1.0))).is_err());
    }

    #[test]
    fn pole_rejects_plain_tensor_input() {
        // A bare coefficient vector is not a model; it must be wrapped with tf/ss first.
        let den = Tensor::new(vec![1.0, 3.0, 2.0], vec![1, 3]).unwrap();
        assert!(block_on(pole_builtin(Value::Tensor(den))).is_err());
    }
}