1use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
2use crate::integer::IntegerRingStore;
3use crate::iters::multi_cartesian_product;
4use crate::primitive_int::StaticRing;
5use crate::ring::*;
6use crate::rings::multivariate::*;
7use crate::rings::poly::*;
8use crate::seq::*;
9use crate::homomorphism::Homomorphism;
10use crate::rings::poly::dense_poly::DensePolyRing;
11
12use std::alloc::Allocator;
13use std::cmp::min;
14use std::ops::Range;
15
16#[allow(unused)]
17#[stability::unstable(feature = "enable")]
18fn invert_many<R>(ring: R, values: &[El<R>], out: &mut [El<R>]) -> Result<(), ()>
19 where R: RingStore,
20 R::Type: DivisibilityRing
21{
22 assert_eq!(out.len(), values.len());
23 out[0] = ring.clone_el(&values[0]);
24 for i in 1..out.len() {
25 out[i] = ring.mul_ref(&values[i], &out[i - 1]);
26 }
27 let joint_inv = ring.invert(&out[out.len() - 1]).ok_or(())?;
28 out[out.len() - 1] = joint_inv;
29 for i in (1..out.len()).rev() {
30 let (fst, snd) = out.split_at_mut(i);
31 ring.mul_assign_ref(&mut fst[i - 1], &snd[0]);
32 ring.mul_assign_ref(&mut snd[0], &values[i]);
33 std::mem::swap(&mut fst[i - 1], &mut snd[0]);
34 }
35 return Ok(());
36}
37
38#[stability::unstable(feature = "enable")]
64pub fn product_except_one<V, R>(ring: R, values: V, out: &mut [El<R>])
65 where R: RingStore,
66 V: VectorFn<El<R>>
67{
68 assert_eq!(values.len(), out.len());
69 let n = values.len();
70 let log2_n = StaticRing::<i64>::RING.abs_log2_ceil(&(n as i64)).unwrap();
71 assert!(n <= (1 << log2_n));
72 if n % 2 == 0 {
73 for i in 0..n {
74 out[i] = values.at(i ^ 1);
75 }
76 } else {
77 for i in 0..(n - 1) {
78 out[i] = values.at(i ^ 1);
79 }
80 out[n - 1] = ring.one();
81 }
82 for s in 1..log2_n {
83 for j in 0..(1 << (log2_n - s - 1)) {
84 let block_index = j << (s + 1);
85 if block_index + (1 << s) < n {
86 let (fst, snd) = (&mut out[block_index..min(n, block_index + (1 << (s + 1)))]).split_at_mut(1 << s);
87 let snd_block_prod = ring.mul_ref_fst(&snd[0], values.at(block_index + (1 << s)));
88 let fst_block_prod = ring.mul_ref_fst(&fst[0], values.at(block_index));
89 for i in 0..(1 << s) {
90 ring.mul_assign_ref(&mut fst[i], &snd_block_prod);
91 }
92 for i in 0..snd.len() {
93 ring.mul_assign_ref(&mut snd[i], &fst_block_prod);
94 }
95 }
96 }
97 }
98}
99
100#[stability::unstable(feature = "enable")]
101#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
102pub enum InterpolationError {
103 NotInvertible
104}
105
106#[stability::unstable(feature = "enable")]
151pub fn interpolate<P, V1, V2, A: Allocator>(poly_ring: P, x: V1, y: V2, allocator: A) -> Result<El<P>, InterpolationError>
152 where P: RingStore,
153 P::Type: PolyRing,
154 <<P::Type as RingExtension>::BaseRing as RingStore>::Type: DivisibilityRing,
155 V1: VectorFn<El<<P::Type as RingExtension>::BaseRing>>,
156 V2: VectorFn<El<<P::Type as RingExtension>::BaseRing>>
157{
158 assert_eq!(x.len(), y.len());
159 let mut nums = Vec::with_capacity_in(x.len(), &allocator);
160 nums.resize_with(x.len(), || poly_ring.zero());
161 let R = poly_ring.base_ring();
162 product_except_one(&poly_ring, (0..x.len()).map_fn(|i| poly_ring.from_terms([(R.negate(x.at(i)), 0), (R.one(), 1)].into_iter())), &mut nums[..]);
163
164 let mut denoms = Vec::with_capacity_in(x.len(), &allocator);
165 denoms.extend((0..x.len()).map(|i| poly_ring.evaluate(&nums[i], &x.at(i), &R.identity())));
166 let mut factors = Vec::with_capacity_in(x.len(), &allocator);
167 factors.resize_with(x.len(), || R.zero());
168 product_except_one(R, (&denoms[..]).into_clone_ring_els(R), &mut factors);
169 let denominator = R.mul_ref(&factors[0], &denoms[0]);
170 for i in 0..x.len() {
171 R.mul_assign(&mut factors[i], y.at(i));
172 }
173
174 if let Some(inv) = R.invert(&denominator) {
175 return Ok(poly_ring.inclusion().mul_map(<_ as RingStore>::sum(&poly_ring, nums.into_iter().zip(factors.into_iter()).map(|(num, c)| poly_ring.inclusion().mul_map(num, c))), inv));
176 } else {
177 let scaled_result = <_ as RingStore>::sum(&poly_ring, nums.into_iter().zip(factors.into_iter()).map(|(num, c)| poly_ring.inclusion().mul_map(num, c)));
178 let mut failed_division = false;
179 let result = poly_ring.from_terms(poly_ring.terms(&scaled_result).map_while(|(c, i)| match R.checked_div(&c, &denominator) {
180 Some(c) => Some((c, i)),
181 None => {
182 failed_division = true;
183 None
184 }
185 }));
186 if failed_division {
187 return Err(InterpolationError::NotInvertible);
188 } else {
189 return Ok(result);
190 }
191 }
192}
193
194#[stability::unstable(feature = "enable")]
195pub fn interpolate_multivariate<P, V1, V2, A, A2>(poly_ring: P, interpolation_points: V1, mut values: Vec<El<<P::Type as RingExtension>::BaseRing>, A2>, allocator: A) -> Result<El<P>, InterpolationError>
196 where P: RingStore,
197 P::Type: MultivariatePolyRing,
198 <<P::Type as RingExtension>::BaseRing as RingStore>::Type: DivisibilityRing,
199 V1: VectorFn<V2>,
200 V2: VectorFn<El<<P::Type as RingExtension>::BaseRing>>,
201 A: Allocator,
202 A2: Allocator
203{
204 let dim_prod = |range: Range<usize>| <_ as RingStore>::prod(&StaticRing::<i64>::RING, range.map(|i| interpolation_points.at(i).len() as i64)) as usize;
205 assert_eq!(interpolation_points.len(), poly_ring.indeterminate_count());
206 let n = poly_ring.indeterminate_count();
207 assert_eq!(values.len(), dim_prod(0..n));
208
209 let uni_poly_ring = DensePolyRing::new_with(poly_ring.base_ring(), "X", &allocator, STANDARD_CONVOLUTION);
210
211 for i in (0..n).rev() {
212 let leading_dim = dim_prod((i + 1)..n);
213 let outer_block_count = dim_prod(0..i);
214 let len = interpolation_points.at(i).len();
215 let outer_block_size = leading_dim * len;
216 for outer_block_index in 0..outer_block_count {
217 for inner_block_index in 0..leading_dim {
218 let block_start = inner_block_index + outer_block_index * outer_block_size;
219 let poly = interpolate(&uni_poly_ring, interpolation_points.at(i), (&values[..]).into_clone_ring_els(poly_ring.base_ring()).restrict(block_start..(block_start + outer_block_size + 1 - leading_dim)).step_by_fn(leading_dim), &allocator)?;
220 for j in 0..len {
221 values[block_start + leading_dim * j] = poly_ring.base_ring().clone_el(uni_poly_ring.coefficient_at(&poly, j));
222 }
223 }
224 }
225 }
226 return Ok(poly_ring.from_terms(
227 multi_cartesian_product((0..n).map(|i| 0..interpolation_points.at(i).len()), |idxs| poly_ring.get_ring().create_monomial(idxs.iter().map(|e| *e)), |_, x| *x)
228 .zip(values.into_iter())
229 .map(|(m, c)| (c, m))
230 ));
231}
232
233#[cfg(test)]
234use crate::rings::finite::FiniteRingStore;
235#[cfg(test)]
236use crate::rings::zn::zn_64::Zn;
237#[cfg(test)]
238use std::alloc::Global;
239#[cfg(test)]
240use multivariate_impl::MultivariatePolyRingImpl;
241
242use super::convolution::STANDARD_CONVOLUTION;
243
// Checks `product_except_one()` over the integers against naively computed products,
// for inputs of length 8 (power of two), 7 (odd) and 6 (even, not a power of two).
#[test]
fn test_product_except_one() {
    let ring = StaticRing::<i64>::RING;
    let primes: [i64; 8] = [2, 3, 5, 7, 11, 13, 17, 19];

    for len in (6..=8).rev() {
        let data = &primes[..len];

        // expected[i] is the product of all `data[j]` with `j != i`
        let mut expected = vec![1i64; len];
        for i in 0..len {
            for j in (0..len).filter(|j| *j != i) {
                expected[i] *= data[j];
            }
        }

        let mut actual = vec![0i64; len];
        product_except_one(&ring, data.clone_els_by(|x| *x), &mut actual[..]);
        assert_eq!(expected, actual);
    }
}
289
// Inverts all 16 nonzero elements of Z/17Z at once and compares each result with
// the inverse computed individually.
#[test]
fn test_invert_many() {
    let field = Zn::new(17);
    let units = field.elements().skip(1).collect::<Vec<_>>();

    let mut inverses = Vec::new();
    inverses.resize_with(16, || field.zero());
    invert_many(&field, &units, &mut inverses).unwrap();

    for i in 0..16 {
        let expected = field.invert(&units[i]).unwrap();
        assert_el_eq!(&field, &expected, &inverses[i]);
    }
}
301
302#[test]
303fn test_interpolate() {
304 let ring = StaticRing::<i64>::RING;
305 let poly_ring = DensePolyRing::new(ring, "X");
306 let poly = poly_ring.from_terms([(3, 0), (1, 1), (-1, 3), (2, 4), (1, 5)].into_iter());
307 let actual = interpolate(&poly_ring, (0..6).map_fn(|x| x as i64), (0..6).map_fn(|x| poly_ring.evaluate(&poly, &(x as i64), &ring.identity())), Global).unwrap();
308 assert_el_eq!(&poly_ring, &poly, &actual);
309
310 let ring = Zn::new(25);
311 let poly_ring = DensePolyRing::new(ring, "X");
312 let poly = interpolate(&poly_ring, (0..5).map_fn(|x| ring.int_hom().map(x as i32)), (0..5).map_fn(|x| if x == 3 { ring.int_hom().map(6) } else { ring.zero() }), Global).unwrap();
313 for x in 0..5 {
314 if x == 3 {
315 assert_el_eq!(ring, ring.int_hom().map(6), poly_ring.evaluate(&poly, &ring.int_hom().map(x), &ring.identity()));
316 } else {
317 assert_el_eq!(ring, ring.zero(), poly_ring.evaluate(&poly, &ring.int_hom().map(x), &ring.identity()));
318 }
319 }
320}
321
322#[test]
323fn test_interpolate_multivariate() {
324 let ring = Zn::new(25);
325 let poly_ring: MultivariatePolyRingImpl<_> = MultivariatePolyRingImpl::new(ring, 2);
326
327 let interpolation_points = (0..2).map_fn(|_| (0..5).map_fn(|x| ring.int_hom().map(x as i32)));
328 let values = (0..25).map(|x| ring.int_hom().map(x & 1)).collect::<Vec<_>>();
329 let poly = interpolate_multivariate(&poly_ring, &interpolation_points, values, Global).unwrap();
330
331 for x in 0..5 {
332 for y in 0..5 {
333 let expected = (x * 5 + y) & 1;
334 assert_el_eq!(ring, ring.int_hom().map(expected), poly_ring.evaluate(&poly, [ring.int_hom().map(x), ring.int_hom().map(y)].into_clone_ring_els(&ring), &ring.identity()));
335 }
336 }
337
338 let poly_ring: MultivariatePolyRingImpl<_> = MultivariatePolyRingImpl::new(ring, 3);
339
340 let interpolation_points = (0..3).map_fn(|i| (0..(i + 2)).map_fn(|x| ring.int_hom().map(x as i32)));
341 let values = (0..24).map(|x| ring.int_hom().map(x / 2)).collect::<Vec<_>>();
342 let poly = interpolate_multivariate(&poly_ring, &interpolation_points, values, Global).unwrap();
343
344 for x in 0..2 {
345 for y in 0..3 {
346 for z in 0..4 {
347 let expected = (x * 12 + y * 4 + z) / 2;
348 assert_el_eq!(ring, ring.int_hom().map(expected), poly_ring.evaluate(&poly, [ring.int_hom().map(x), ring.int_hom().map(y), ring.int_hom().map(z)].into_clone_ring_els(&ring), &ring.identity()));
349 }
350 }
351 }
352}