numr/runtime/cpu/polynomial/
polynomial.rs1use super::super::{CpuClient, CpuRuntime};
8use crate::algorithm::polynomial::PolynomialAlgorithms;
9use crate::algorithm::polynomial::core::{self, DTypeSupport};
10use crate::algorithm::polynomial::types::PolynomialRoots;
11use crate::error::Result;
12use crate::tensor::Tensor;
13
14impl PolynomialAlgorithms<CpuRuntime> for CpuClient {
15 fn polyroots(&self, coeffs: &Tensor<CpuRuntime>) -> Result<PolynomialRoots<CpuRuntime>> {
16 core::polyroots_impl(self, coeffs, DTypeSupport::FULL)
17 }
18
19 fn polyval(
20 &self,
21 coeffs: &Tensor<CpuRuntime>,
22 x: &Tensor<CpuRuntime>,
23 ) -> Result<Tensor<CpuRuntime>> {
24 core::polyval_impl(self, coeffs, x, DTypeSupport::FULL)
25 }
26
27 fn polyfromroots(
28 &self,
29 roots_real: &Tensor<CpuRuntime>,
30 roots_imag: &Tensor<CpuRuntime>,
31 ) -> Result<Tensor<CpuRuntime>> {
32 core::polyfromroots_impl(self, roots_real, roots_imag, DTypeSupport::FULL)
33 }
34
35 fn polymul(
36 &self,
37 a: &Tensor<CpuRuntime>,
38 b: &Tensor<CpuRuntime>,
39 ) -> Result<Tensor<CpuRuntime>> {
40 core::polymul_impl(self, a, b, DTypeSupport::FULL)
41 }
42}
43
#[cfg(test)]
mod tests {
    //! Numerical tests for the CPU polynomial backend.
    //!
    //! Throughout, coefficient slices are in ascending degree order:
    //! `[c0, c1, c2]` means `c0 + c1*x + c2*x^2`.

    use super::*;
    use crate::runtime::Runtime;
    use crate::runtime::cpu::CpuDevice;

    /// Builds a default CPU client together with the device it runs on.
    fn create_client() -> (CpuClient, CpuDevice) {
        let device = CpuDevice::new();
        let client = CpuRuntime::default_client(&device);
        (client, device)
    }

    /// Asserts `|actual - expected| < tol`, naming the checked quantity on failure.
    ///
    /// The comparison runs in `f64`; widening an `f32` to `f64` is lossless, so
    /// the helper serves both precisions.
    fn assert_close(actual: impl Into<f64>, expected: f64, tol: f64, what: &str) {
        let actual = actual.into();
        assert!(
            (actual - expected).abs() < tol,
            "{}: expected {}, got {} (tol {})",
            what,
            expected,
            actual,
            tol
        );
    }

    /// Returns `values` sorted ascending, so root checks do not depend on the
    /// order in which the solver emits roots. Panics on NaN.
    fn sorted<T: PartialOrd>(mut values: Vec<T>) -> Vec<T> {
        values.sort_by(|a, b| a.partial_cmp(b).expect("NaN among computed values"));
        values
    }

    #[test]
    fn test_polyroots_quadratic_real() {
        let (client, device) = create_client();

        // x^2 - 3x + 2 = (x - 1)(x - 2)
        let coeffs = Tensor::<CpuRuntime>::from_slice(&[2.0f32, -3.0, 1.0], &[3], &device);

        let roots = client.polyroots(&coeffs).unwrap();
        let real: Vec<f32> = roots.roots_real.to_vec();
        let imag: Vec<f32> = roots.roots_imag.to_vec();

        assert_eq!(real.len(), 2);
        assert_eq!(imag.len(), 2);

        let real = sorted(real);
        assert_close(real[0], 1.0, 1e-4, "smaller root");
        assert_close(real[1], 2.0, 1e-4, "larger root");

        // Both roots are real, so every imaginary part must vanish.
        for (i, &im) in imag.iter().enumerate() {
            assert_close(im, 0.0, 1e-4, &format!("imag part of root {}", i));
        }
    }

    #[test]
    fn test_polyroots_quadratic_complex() {
        let (client, device) = create_client();

        // x^2 + 1 has the purely imaginary roots +i and -i.
        let coeffs = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 1.0], &[3], &device);

        let roots = client.polyroots(&coeffs).unwrap();
        let real: Vec<f32> = roots.roots_real.to_vec();
        let imag: Vec<f32> = roots.roots_imag.to_vec();

        assert_eq!(real.len(), 2);
        assert_eq!(imag.len(), 2);

        for (i, &r) in real.iter().enumerate() {
            assert_close(r, 0.0, 1e-4, &format!("real part of root {}", i));
        }

        let imag = sorted(imag);
        assert_close(imag[0], -1.0, 1e-4, "imag part of -i");
        assert_close(imag[1], 1.0, 1e-4, "imag part of +i");
    }

    #[test]
    fn test_polyval_constant() {
        let (client, device) = create_client();

        // p(x) = 5 regardless of x.
        let coeffs = Tensor::<CpuRuntime>::from_slice(&[5.0f32], &[1], &device);
        let x = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);

        let data: Vec<f32> = client.polyval(&coeffs, &x).unwrap().to_vec();

        assert_eq!(data.len(), 3);
        for (i, &v) in data.iter().enumerate() {
            assert_close(v, 5.0, 1e-6, &format!("p(x[{}])", i));
        }
    }

    #[test]
    fn test_polyval_linear() {
        let (client, device) = create_client();

        // p(x) = 2 + 3x
        let coeffs = Tensor::<CpuRuntime>::from_slice(&[2.0f32, 3.0], &[2], &device);
        let x = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 1.0, 2.0], &[3], &device);

        let data: Vec<f32> = client.polyval(&coeffs, &x).unwrap().to_vec();

        assert_eq!(data.len(), 3);
        assert_close(data[0], 2.0, 1e-6, "p(0)");
        assert_close(data[1], 5.0, 1e-6, "p(1)");
        assert_close(data[2], 8.0, 1e-6, "p(2)");
    }

    #[test]
    fn test_polyval_quadratic() {
        let (client, device) = create_client();

        // p(x) = 1 + 2x + 3x^2, so p(2) = 1 + 4 + 12 = 17.
        let coeffs = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
        let x = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1], &device);

        let data: Vec<f32> = client.polyval(&coeffs, &x).unwrap().to_vec();

        assert_eq!(data.len(), 1);
        assert_close(data[0], 17.0, 1e-5, "p(2)");
    }

    #[test]
    fn test_polyfromroots_real() {
        let (client, device) = create_client();

        // Roots 1 and 2 -> (x - 1)(x - 2) = 2 - 3x + x^2.
        let roots_real = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device);
        let roots_imag = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2], &device);

        let data: Vec<f32> = client
            .polyfromroots(&roots_real, &roots_imag)
            .unwrap()
            .to_vec();

        assert_eq!(data.len(), 3);
        assert_close(data[0], 2.0, 1e-5, "c0");
        assert_close(data[1], -3.0, 1e-5, "c1");
        assert_close(data[2], 1.0, 1e-5, "c2");
    }

    #[test]
    fn test_polyfromroots_complex() {
        let (client, device) = create_client();

        // Roots +i and -i -> (x - i)(x + i) = 1 + x^2.
        let roots_real = Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0], &[2], &device);
        let roots_imag = Tensor::<CpuRuntime>::from_slice(&[1.0f32, -1.0], &[2], &device);

        let data: Vec<f32> = client
            .polyfromroots(&roots_real, &roots_imag)
            .unwrap()
            .to_vec();

        assert_eq!(data.len(), 3);
        assert_close(data[0], 1.0, 1e-5, "c0");
        assert_close(data[1], 0.0, 1e-5, "c1");
        assert_close(data[2], 1.0, 1e-5, "c2");
    }

    #[test]
    fn test_polymul_linear() {
        let (client, device) = create_client();

        // (1 + x)^2 = 1 + 2x + x^2
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);

        let data: Vec<f32> = client.polymul(&a, &b).unwrap().to_vec();

        assert_eq!(data.len(), 3);
        assert_close(data[0], 1.0, 1e-6, "c0");
        assert_close(data[1], 2.0, 1e-6, "c1");
        assert_close(data[2], 1.0, 1e-6, "c2");
    }

    #[test]
    fn test_polymul_difference_of_squares() {
        let (client, device) = create_client();

        // (1 - x)(1 + x) = 1 - x^2
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, -1.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);

        let data: Vec<f32> = client.polymul(&a, &b).unwrap().to_vec();

        assert_eq!(data.len(), 3);
        assert_close(data[0], 1.0, 1e-6, "c0");
        assert_close(data[1], 0.0, 1e-6, "c1");
        assert_close(data[2], -1.0, 1e-6, "c2");
    }

    #[test]
    fn test_polymul_mixed_degree() {
        let (client, device) = create_client();

        // Inputs of different lengths: (1 + x)(1 + x + x^2) = 1 + 2x + 2x^2 + x^3.
        let a = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device);
        let b = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device);

        let data: Vec<f32> = client.polymul(&a, &b).unwrap().to_vec();

        assert_eq!(data.len(), 4);
        assert_close(data[0], 1.0, 1e-5, "c0");
        assert_close(data[1], 2.0, 1e-5, "c1");
        assert_close(data[2], 2.0, 1e-5, "c2");
        assert_close(data[3], 1.0, 1e-5, "c3");
    }

    #[test]
    fn test_roundtrip_roots_coeffs() {
        let (client, device) = create_client();

        // x^2 - 5x + 6 = (x - 2)(x - 3); roots -> coefficients must reproduce it.
        let original = Tensor::<CpuRuntime>::from_slice(&[6.0f32, -5.0, 1.0], &[3], &device);

        let roots = client.polyroots(&original).unwrap();
        let data: Vec<f32> = client
            .polyfromroots(&roots.roots_real, &roots.roots_imag)
            .unwrap()
            .to_vec();

        assert_eq!(data.len(), 3);
        assert_close(data[0], 6.0, 1e-4, "c0");
        assert_close(data[1], -5.0, 1e-4, "c1");
        assert_close(data[2], 1.0, 1e-4, "c2");
    }

    #[test]
    fn test_polyroots_f64() {
        let (client, device) = create_client();

        // Same polynomial as the f32 case, checked to double-precision tolerance.
        let coeffs = Tensor::<CpuRuntime>::from_slice(&[2.0f64, -3.0, 1.0], &[3], &device);

        let roots = client.polyroots(&coeffs).unwrap();
        let real: Vec<f64> = roots.roots_real.to_vec();
        let imag: Vec<f64> = roots.roots_imag.to_vec();

        assert_eq!(real.len(), 2);
        assert_eq!(imag.len(), 2);

        let real = sorted(real);
        assert_close(real[0], 1.0, 1e-10, "smaller root");
        assert_close(real[1], 2.0, 1e-10, "larger root");

        for (i, &im) in imag.iter().enumerate() {
            assert_close(im, 0.0, 1e-10, &format!("imag part of root {}", i));
        }
    }
}