bp_pp/
circuit.rs

1#![allow(non_snake_case)]
2//! Definition and implementation of the Bulletproofs++ arithmetic circuit protocol.
3
4use std::ops::{Add, Mul, Sub};
5use k256::{AffinePoint, ProjectivePoint, Scalar};
6use k256::elliptic_curve::ops::Invert;
7use k256::elliptic_curve::rand_core::{CryptoRng, RngCore};
8use merlin::Transcript;
9use serde::{Deserialize, Serialize};
10use crate::util::*;
11use crate::{transcript, wnla};
12use crate::wnla::WeightNormLinearArgument;
13
/// Selects which derived vector a partition-function query refers to.
///
/// The circuit's partition function is called as `partition(kind, j)`. It returns `Some(i)`
/// when slot `j` of the selected vector is taken from `w_o[i]`, and `None` when that slot
/// is zero.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PartitionType {
    /// Slots of the `lo` vector (indices `0..dim_nv`).
    LO,
    /// Slots of the `ll` vector (indices `0..dim_nv`).
    LL,
    /// Slots of the `lr` vector (indices `0..dim_nv`).
    LR,
    /// Slots of the `no` vector (indices `0..dim_nm`).
    NO,
}
21
22/// Represents arithmetic circuit zero-knowledge proof.
23#[derive(Clone, Debug)]
24pub struct Proof {
25    pub c_l: ProjectivePoint,
26    pub c_r: ProjectivePoint,
27    pub c_o: ProjectivePoint,
28    pub c_s: ProjectivePoint,
29    pub r: Vec<ProjectivePoint>,
30    pub x: Vec<ProjectivePoint>,
31    pub l: Vec<Scalar>,
32    pub n: Vec<Scalar>,
33}
34
35/// Represent serializable version of arithmetic circuit proof (uses AffinePoint instead of ProjectivePoint).
36#[derive(Serialize, Deserialize, Clone, Debug)]
37pub struct SerializableProof {
38    pub c_l: AffinePoint,
39    pub c_r: AffinePoint,
40    pub c_o: AffinePoint,
41    pub c_s: AffinePoint,
42    pub r: Vec<AffinePoint>,
43    pub x: Vec<AffinePoint>,
44    pub l: Vec<Scalar>,
45    pub n: Vec<Scalar>,
46}
47
48impl From<&SerializableProof> for Proof {
49    fn from(value: &SerializableProof) -> Self {
50        return Proof {
51            c_l: ProjectivePoint::from(&value.c_l),
52            c_r: ProjectivePoint::from(&value.c_r),
53            c_o: ProjectivePoint::from(&value.c_o),
54            c_s: ProjectivePoint::from(&value.c_s),
55            r: value.r.iter().map(ProjectivePoint::from).collect::<Vec<ProjectivePoint>>(),
56            x: value.x.iter().map(ProjectivePoint::from).collect::<Vec<ProjectivePoint>>(),
57            l: value.l.clone(),
58            n: value.n.clone(),
59        };
60    }
61}
62
63impl From<&Proof> for SerializableProof {
64    fn from(value: &Proof) -> Self {
65        return SerializableProof {
66            c_l: value.c_l.to_affine(),
67            c_r: value.c_r.to_affine(),
68            c_o: value.c_o.to_affine(),
69            c_s: value.c_s.to_affine(),
70            r: value.r.iter().map(|r_val| r_val.to_affine()).collect::<Vec<AffinePoint>>(),
71            x: value.x.iter().map(|x_val| x_val.to_affine()).collect::<Vec<AffinePoint>>(),
72            l: value.l.clone(),
73            n: value.n.clone(),
74        };
75    }
76}
77
78/// Represents arithmetic circuit witness.
79#[derive(Clone, Debug)]
80pub struct Witness {
81    /// Dimension: `k*dim_nv`
82    pub v: Vec<Vec<Scalar>>,
83    /// Dimension: `k`
84    pub s_v: Vec<Scalar>,
85    /// Dimension: `dim_nm`
86    pub w_l: Vec<Scalar>,
87    /// Dimension: `dim_nm`
88    pub w_r: Vec<Scalar>,
89    /// Dimension: `dim_no`
90    pub w_o: Vec<Scalar>,
91}
92
93/// Represents arithmetic circuit.
94/// P - partition function.
95pub struct ArithmeticCircuit<P>
96    where
97        P: Fn(PartitionType, usize) -> Option<usize>
98{
99    pub dim_nm: usize,
100    pub dim_no: usize,
101    pub k: usize,
102
103    /// Equals to: `dim_nv * k`
104    pub dim_nl: usize,
105    /// Count of witness vectors v.
106    pub dim_nv: usize,
107    ///  Equals to: `dim_nm + dim_nm + n_o`
108    pub dim_nw: usize,
109
110    pub g: ProjectivePoint,
111
112    /// Dimension: `dim_nm`
113    pub g_vec: Vec<ProjectivePoint>,
114    /// Dimension: `dim_nv+9`
115    pub h_vec: Vec<ProjectivePoint>,
116
117    /// Dimension: `dim_nm * dim_nw`
118    pub W_m: Vec<Vec<Scalar>>,
119    /// Dimension: `dim_nl * dim_nw`
120    pub W_l: Vec<Vec<Scalar>>,
121
122    /// Dimension: `dim_nm`
123    pub a_m: Vec<Scalar>,
124    /// Dimension: `dim_nl`
125    pub a_l: Vec<Scalar>,
126
127    pub f_l: bool,
128    pub f_m: bool,
129
130    /// Vector of points that will be used in WNLA protocol.
131    /// Dimension: `2^n - dim_nm`
132    pub g_vec_: Vec<ProjectivePoint>,
133    /// Vector of points that will be used in WNLA protocol.
134    /// Dimension: `2^n - (dim_nv+9)`
135    pub h_vec_: Vec<ProjectivePoint>,
136
137    /// Partition function to map `w_o` and corresponding parts of `W_m` and `W_l`
138    pub partition: P,
139}
140
141impl<P> ArithmeticCircuit<P>
142    where
143        P: Fn(PartitionType, usize) -> Option<usize>
144{
145    /// Creates commitment to the arithmetic circuit witness.
146    pub fn commit(&self, v: &[Scalar], s: &Scalar) -> ProjectivePoint {
147        self.
148            g.mul(&v[0]).
149            add(&self.h_vec[0].mul(s)).
150            add(&vector_mul(&self.h_vec[9..], &v[1..]))
151    }
152
153    /// Verifies arithmetic circuit proof with respect to the `v` commitments vector.
154    pub fn verify(&self, v: &[ProjectivePoint], t: &mut Transcript, proof: Proof) -> bool {
155        transcript::app_point(b"commitment_cl", &proof.c_l, t);
156        transcript::app_point(b"commitment_cr", &proof.c_r, t);
157        transcript::app_point(b"commitment_co", &proof.c_o, t);
158
159        v.iter().for_each(|v_val| transcript::app_point(b"commitment_v", v_val, t));
160
161        let rho = transcript::get_challenge(b"circuit_rho", t);
162        let lambda = transcript::get_challenge(b"circuit_lambda", t);
163        let beta = transcript::get_challenge(b"circuit_beta", t);
164        let delta = transcript::get_challenge(b"circuit_delta", t);
165
166        let mu = rho.mul(rho);
167
168        let lambda_vec = self.collect_lambda(&lambda, &mu);
169        let mu_vec = vector_mul_on_scalar(&e(&mu, self.dim_nm), &mu);
170
171        let (
172            c_nL,
173            c_nR,
174            c_nO,
175            c_lL,
176            c_lR,
177            c_lO
178        ) = self.collect_c(&lambda_vec, &mu_vec, &mu);
179
180        let two = Scalar::from(2u32);
181
182        let mut v_ = ProjectivePoint::IDENTITY;
183        (0..self.k).
184            for_each(|i|
185                v_ = v_.add(v[i].mul(self.linear_comb_coef(i, &lambda, &mu)))
186            );
187        v_ = v_.mul(&two);
188
189        transcript::app_point(b"commitment_cs", &proof.c_s, t);
190
191        let tau = transcript::get_challenge(b"circuit_tau", t);
192        let tau_inv = tau.invert_vartime().unwrap();
193        let tau2 = tau.mul(&tau);
194        let tau3 = tau2.mul(&tau);
195
196        let delta_inv = delta.invert_vartime().unwrap();
197
198        let mut pn_tau = vector_mul_on_scalar(&c_nO, &tau3.mul(&delta_inv));
199        pn_tau = vector_sub(&pn_tau, &vector_mul_on_scalar(&c_nL, &tau2));
200        pn_tau = vector_add(&pn_tau, &vector_mul_on_scalar(&c_nR, &tau));
201
202        let ps_tau = weight_vector_mul(&pn_tau, &pn_tau, &mu).
203            add(&vector_mul(&lambda_vec, &self.a_l).mul(&tau3).mul(&two)).
204            sub(&vector_mul(&mu_vec, &self.a_m).mul(&tau3).mul(&two));
205
206        let pt = self.g.mul(ps_tau).add(vector_mul(&self.g_vec, &pn_tau));
207
208        let cr_tau = vec![
209            Scalar::ONE,
210            tau_inv.mul(beta),
211            tau.mul(beta),
212            tau2.mul(beta),
213            tau3.mul(beta),
214            tau.mul(tau3).mul(beta),
215            tau2.mul(tau3).mul(beta),
216            tau3.mul(tau3).mul(beta),
217            tau3.mul(tau3).mul(tau).mul(beta),
218        ];
219
220        let c_l0 = self.collect_cl0(&lambda, &mu);
221
222        let mut cl_tau = vector_mul_on_scalar(&c_lO, &tau3.mul(&delta_inv));
223        cl_tau = vector_sub(&cl_tau, &vector_mul_on_scalar(&c_lL, &tau2));
224        cl_tau = vector_add(&cl_tau, &vector_mul_on_scalar(&c_lR, &tau));
225        cl_tau = vector_mul_on_scalar(&cl_tau, &two);
226        cl_tau = vector_sub(&cl_tau, &c_l0);
227
228        let mut c = [&cr_tau[..], &cl_tau[..]].concat();
229
230        let commitment = pt.
231            add(&proof.c_s.mul(&tau_inv)).
232            sub(&proof.c_o.mul(&delta)).
233            add(&proof.c_l.mul(&tau)).
234            sub(&proof.c_r.mul(&tau2)).
235            add(&v_.mul(&tau3));
236
237        while c.len() < self.h_vec.len() + self.h_vec_.len() {
238            c.push(Scalar::ZERO);
239        }
240
241        let wnla = WeightNormLinearArgument {
242            g: self.g,
243            g_vec: [&self.g_vec[..], &self.g_vec_[..]].concat(),
244            h_vec: [&self.h_vec[..], &self.h_vec_[..]].concat(),
245            c,
246            rho,
247            mu,
248        };
249
250        wnla.verify(&commitment, t, wnla::Proof {
251            r: proof.r,
252            x: proof.x,
253            l: proof.l,
254            n: proof.n,
255        })
256    }
257
258    /// Creates arithmetic circuit proof for the corresponding witness. Also, `v` commitments vector
259    /// should correspond input witness in `witness` argument.
260    pub fn prove<R>(&self, v: &[ProjectivePoint], witness: Witness, t: &mut Transcript, rng: &mut R) -> Proof
261        where
262            R: RngCore + CryptoRng
263    {
264        let ro = vec![
265            Scalar::generate_biased(rng),
266            Scalar::generate_biased(rng),
267            Scalar::generate_biased(rng),
268            Scalar::generate_biased(rng),
269            Scalar::ZERO,
270            Scalar::generate_biased(rng),
271            Scalar::generate_biased(rng),
272            Scalar::generate_biased(rng),
273            Scalar::ZERO,
274        ];
275
276        let rl = vec![
277            Scalar::generate_biased(rng),
278            Scalar::generate_biased(rng),
279            Scalar::generate_biased(rng),
280            Scalar::ZERO,
281            Scalar::generate_biased(rng),
282            Scalar::generate_biased(rng),
283            Scalar::generate_biased(rng),
284            Scalar::ZERO,
285            Scalar::ZERO,
286        ];
287
288        let rr = vec![
289            Scalar::generate_biased(rng),
290            Scalar::generate_biased(rng),
291            Scalar::ZERO,
292            Scalar::generate_biased(rng),
293            Scalar::generate_biased(rng),
294            Scalar::generate_biased(rng),
295            Scalar::ZERO,
296            Scalar::ZERO,
297            Scalar::ZERO,
298        ];
299
300        let nl = witness.w_l;
301        let nr = witness.w_r;
302
303        let no = (0..self.dim_nm).map(|j|
304            if let Some(i) = (self.partition)(PartitionType::NO, j) {
305                witness.w_o[i]
306            } else {
307                Scalar::ZERO
308            }
309        ).collect::<Vec<Scalar>>();
310
311        let lo = (0..self.dim_nv).map(|j|
312            if let Some(i) = (self.partition)(PartitionType::LO, j) {
313                witness.w_o[i]
314            } else {
315                Scalar::ZERO
316            }
317        ).collect::<Vec<Scalar>>();
318
319        let ll = (0..self.dim_nv).map(|j|
320            if let Some(i) = (self.partition)(PartitionType::LL, j) {
321                witness.w_o[i]
322            } else {
323                Scalar::ZERO
324            }
325        ).collect::<Vec<Scalar>>();
326
327        let lr = (0..self.dim_nv).map(|j|
328            if let Some(i) = (self.partition)(PartitionType::LR, j) {
329                witness.w_o[i]
330            } else {
331                Scalar::ZERO
332            }
333        ).collect::<Vec<Scalar>>();
334
335        let co =
336            vector_mul(&self.h_vec, &[&ro[..], &lo[..]].concat()).
337                add(vector_mul(&self.g_vec, &no));
338
339        let cl =
340            vector_mul(&self.h_vec, &[&rl[..], &ll[..]].concat()).
341                add(vector_mul(&self.g_vec, &nl));
342
343        let cr =
344            vector_mul(&self.h_vec, &[&rr[..], &lr[..]].concat()).
345                add(vector_mul(&self.g_vec, &nr));
346
347        transcript::app_point(b"commitment_cl", &cl, t);
348        transcript::app_point(b"commitment_cr", &cr, t);
349        transcript::app_point(b"commitment_co", &co, t);
350        v.iter().for_each(|v_val| transcript::app_point(b"commitment_v", v_val, t));
351
352        let rho = transcript::get_challenge(b"circuit_rho", t);
353        let lambda = transcript::get_challenge(b"circuit_lambda", t);
354        let beta = transcript::get_challenge(b"circuit_beta", t);
355        let delta = transcript::get_challenge(b"circuit_delta", t);
356
357        let mu = rho.mul(rho);
358
359        let lambda_vec = self.collect_lambda(&lambda, &mu);
360        let mu_vec = vector_mul_on_scalar(&e(&mu, self.dim_nm), &mu);
361
362        let (
363            c_nL,
364            c_nR,
365            c_nO,
366            c_lL,
367            c_lR,
368            c_lO
369        ) = self.collect_c(&lambda_vec, &mu_vec, &mu);
370
371        let ls = (0..self.dim_nv).map(|_| Scalar::generate_biased(rng)).collect::<Vec<Scalar>>();
372        let ns = (0..self.dim_nm).map(|_| Scalar::generate_biased(rng)).collect::<Vec<Scalar>>();
373
374        let two = Scalar::from(2u32);
375
376        let mut v_0 = Scalar::ZERO;
377        (0..self.k).
378            for_each(|i|
379                v_0 = v_0.add(witness.v[i][0].mul(self.linear_comb_coef(i, &lambda, &mu)))
380            );
381        v_0 = v_0.mul(&two);
382
383        let mut rv = vec![Scalar::ZERO; 9];
384        (0..self.k).
385            for_each(|i|
386                rv[0] = rv[0].add(witness.s_v[i].mul(self.linear_comb_coef(i, &lambda, &mu)))
387            );
388        rv[0] = rv[0].mul(&two);
389
390        let mut v_1 = vec![Scalar::ZERO; self.dim_nv - 1];
391        (0..self.k).
392            for_each(|i|
393                v_1 = vector_add(&v_1, &vector_mul_on_scalar(&witness.v[i][1..], &self.linear_comb_coef(i, &lambda, &mu)))
394            );
395        v_1 = vector_mul_on_scalar(&v_1, &two);
396
397        let c_l0 = self.collect_cl0(&lambda, &mu);
398
399        // [-2 -1 0 1 2 4 5 6] -> f(tau) coefficients vector
400        let mut f_ = vec![Scalar::ZERO; 8];
401
402        let delta2 = delta.mul(&delta);
403        let delta_inv = delta.invert_vartime().unwrap();
404
405        // -2
406        f_[0] = minus(&weight_vector_mul(&ns, &ns, &mu));
407
408        // -1
409        f_[1] = vector_mul(&c_l0, &ls).
410            add(delta.mul(&two).mul(&weight_vector_mul(&ns, &no, &mu)));
411
412        // 0
413        f_[2] = minus(&vector_mul(&c_lR, &ls).mul(&two)).
414            sub(&vector_mul(&c_l0, &lo).mul(&delta)).
415            sub(&weight_vector_mul(&ns, &vector_add(&nl, &c_nR), &mu).mul(&two)).
416            sub(&weight_vector_mul(&no, &no, &mu).mul(&delta2));
417
418        //1
419        f_[3] = vector_mul(&c_lL, &ls).mul(&two).
420            add(&vector_mul(&c_lR, &lo).mul(&delta).mul(&two)).
421            add(&vector_mul(&c_l0, &ll)).
422            add(&weight_vector_mul(&ns, &vector_add(&nr, &c_nL), &mu).mul(&two)).
423            add(&weight_vector_mul(&no, &vector_add(&nl, &c_nR), &mu).mul(&two).mul(&delta));
424
425        // 2
426        f_[4] = weight_vector_mul(&c_nR, &c_nR, &mu).
427            sub(&vector_mul(&c_lO, &ls).mul(&delta_inv).mul(&two)).
428            sub(&vector_mul(&c_lL, &lo).mul(&delta).mul(&two)).
429            sub(&vector_mul(&c_lR, &ll).mul(&two)).
430            sub(&vector_mul(&c_l0, &lr)).
431            sub(&weight_vector_mul(&ns, &c_nO, &mu).mul(&delta_inv).mul(&two)).
432            sub(&weight_vector_mul(&no, &vector_add(&nr, &c_nL), &mu).mul(&delta).mul(&two)).
433            sub(&weight_vector_mul(&vector_add(&nl, &c_nR), &vector_add(&nl, &c_nR), &mu));
434
435        // 3 should be zero
436
437        // 4
438        f_[5] = weight_vector_mul(&c_nO, &c_nR, &mu).mul(&delta_inv).mul(&two).
439            add(&weight_vector_mul(&c_nL, &c_nL, &mu)).
440            sub(&vector_mul(&c_lO, &ll).mul(&delta_inv).mul(&two)).
441            sub(&vector_mul(&c_lL, &lr).mul(&two)).
442            sub(&vector_mul(&c_lR, &v_1).mul(&two)).
443            sub(&weight_vector_mul(&vector_add(&nl, &c_nR), &c_nO, &mu).mul(&delta_inv).mul(&two)).
444            sub(&weight_vector_mul(&vector_add(&nr, &c_nL), &vector_add(&nr, &c_nL), &mu));
445
446        // 5
447        f_[6] = minus(&weight_vector_mul(&c_nO, &c_nL, &mu).mul(&delta_inv).mul(&two)).
448            add(&vector_mul(&c_nO, &lr).mul(&delta_inv).mul(&two)).
449            add(&vector_mul(&c_lL, &v_1).mul(&two)).
450            add(&weight_vector_mul(&vector_add(&nr, &c_nL), &c_nO, &mu).mul(&delta_inv).mul(&two));
451
452        // 6
453        f_[7] = minus(&vector_mul(&c_lO, &v_1).mul(&delta_inv).mul(&two));
454
455        let beta_inv = beta.invert_vartime().unwrap();
456
457        let rs = vec![
458            f_[1].add(ro[1].mul(&delta).mul(&beta)),
459            f_[0].mul(&beta_inv),
460            ro[0].mul(&delta).add(&f_[2]).mul(&beta_inv).sub(&rl[1]),
461            f_[3].sub(&rl[0]).mul(&beta_inv).add(&ro[2].mul(&delta).add(rr[1])),
462            f_[4].add(&rr[0]).mul(&beta_inv).add(&ro[3].mul(&delta).sub(rl[2])),
463            minus(&rv[0].mul(&beta_inv)),
464            f_[5].mul(&beta_inv).add(&ro[5].mul(&delta)).add(&rr[3]).sub(&rl[4]),
465            f_[6].mul(&beta_inv).add(&rr[4]).add(&ro[6].mul(&delta)).sub(&rl[5]),
466            f_[7].mul(&beta_inv).add(&ro[7].mul(&delta)).sub(&rl[6]).add(&rr[5]),
467        ];
468
469        let cs = vector_mul(&self.h_vec, &[&rs[..], &ls[..]].concat()).
470            add(vector_mul(&self.g_vec, &ns));
471
472        transcript::app_point(b"commitment_cs", &cs, t);
473
474        let tau = transcript::get_challenge(b"circuit_tau", t);
475        let tau_inv = tau.invert_vartime().unwrap();
476        let tau2 = tau.mul(&tau);
477        let tau3 = tau2.mul(&tau);
478
479        let mut l = vector_mul_on_scalar(&[&rs[..], &ls[..]].concat(), &tau_inv);
480        l = vector_sub(&l, &vector_mul_on_scalar(&[&ro[..], &lo[..]].concat(), &delta));
481        l = vector_add(&l, &vector_mul_on_scalar(&[&rl[..], &ll[..]].concat(), &tau));
482        l = vector_sub(&l, &vector_mul_on_scalar(&[&rr[..], &lr[..]].concat(), &tau2));
483        l = vector_add(&l, &vector_mul_on_scalar(&[&rv[..], &v_1[..]].concat(), &tau3));
484
485        let mut pn_tau = vector_mul_on_scalar(&c_nO, &tau3.mul(&delta_inv));
486        pn_tau = vector_sub(&pn_tau, &vector_mul_on_scalar(&c_nL, &tau2));
487        pn_tau = vector_add(&pn_tau, &vector_mul_on_scalar(&c_nR, &tau));
488
489        let ps_tau = weight_vector_mul(&pn_tau, &pn_tau, &mu).
490            add(&vector_mul(&lambda_vec, &self.a_l).mul(&tau3).mul(&two)).
491            sub(&vector_mul(&mu_vec, &self.a_m).mul(&tau3).mul(&two));
492
493        let mut n_tau = vector_mul_on_scalar(&ns, &tau_inv);
494        n_tau = vector_sub(&n_tau, &vector_mul_on_scalar(&no, &delta));
495        n_tau = vector_add(&n_tau, &vector_mul_on_scalar(&nl, &tau));
496        n_tau = vector_sub(&n_tau, &vector_mul_on_scalar(&nr, &tau2));
497
498        let mut n = vector_add(&pn_tau, &n_tau);
499
500        let cr_tau = vec![
501            Scalar::ONE,
502            tau_inv.mul(beta),
503            tau.mul(beta),
504            tau2.mul(beta),
505            tau3.mul(beta),
506            tau.mul(tau3).mul(beta),
507            tau2.mul(tau3).mul(beta),
508            tau3.mul(tau3).mul(beta),
509            tau3.mul(tau3).mul(tau).mul(beta),
510        ];
511
512        let mut cl_tau = vector_mul_on_scalar(&c_lO, &tau3.mul(&delta_inv));
513        cl_tau = vector_sub(&cl_tau, &vector_mul_on_scalar(&c_lL, &tau2));
514        cl_tau = vector_add(&cl_tau, &vector_mul_on_scalar(&c_lR, &tau));
515        cl_tau = vector_mul_on_scalar(&cl_tau, &two);
516        cl_tau = vector_sub(&cl_tau, &c_l0);
517
518        let mut c = [&cr_tau[..], &cl_tau[..]].concat();
519
520        let v = ps_tau.add(&tau3.mul(&v_0));
521
522        let commitment = self.g.mul(v).
523            add(&vector_mul(&self.h_vec, &l)).
524            add(&vector_mul(&self.g_vec, &n));
525
526        while l.len() < self.h_vec.len() + self.h_vec_.len() {
527            l.push(Scalar::ZERO);
528            c.push(Scalar::ZERO);
529        }
530
531        while n.len() < self.g_vec.len() + self.g_vec_.len() {
532            n.push(Scalar::ZERO);
533        }
534
535        let wnla = WeightNormLinearArgument {
536            g: self.g,
537            g_vec: [&self.g_vec[..], &self.g_vec_[..]].concat(),
538            h_vec: [&self.h_vec[..], &self.h_vec_[..]].concat(),
539            c,
540            rho,
541            mu,
542        };
543
544        let proof_wnla = wnla.prove(&commitment, t, l, n);
545
546        Proof {
547            c_l: cl,
548            c_r: cr,
549            c_o: co,
550            c_s: cs,
551            r: proof_wnla.r,
552            x: proof_wnla.x,
553            l: proof_wnla.l,
554            n: proof_wnla.n,
555        }
556    }
557
558
559    fn linear_comb_coef(&self, i: usize, lambda: &Scalar, mu: &Scalar) -> Scalar {
560        let mut coef = Scalar::ZERO;
561        if self.f_l {
562            coef = coef.add(pow(lambda, self.dim_nv * i))
563        }
564
565        if self.f_m {
566            coef = coef.add(pow(mu, self.dim_nv * i + 1))
567        }
568
569        coef
570    }
571
572    fn collect_cl0(&self, lambda: &Scalar, mu: &Scalar) -> Vec<Scalar> {
573        let mut c_l0 = vec![Scalar::ZERO; self.dim_nv - 1];
574        if self.f_l {
575            c_l0 = e(lambda, self.dim_nv)[1..].to_vec();
576        }
577        if self.f_m {
578            c_l0 = vector_sub(&c_l0, &vector_mul_on_scalar(&e(mu, self.dim_nv)[1..], mu));
579        }
580
581        c_l0
582    }
583
584    fn collect_c(&self, lambda_vec: &[Scalar], mu_vec: &[Scalar], mu: &Scalar) -> (Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>) {
585        let (M_lnL, M_mnL, M_lnR, M_mnR) = self.collect_m_rl();
586        let (M_lnO, M_mnO, M_llL, M_mlL, M_llR, M_mlR, M_llO, M_mlO) = self.collect_m_o();
587
588        let mu_diag_inv = diag_inv(mu, self.dim_nm);
589
590        let c_nL = vector_mul_on_matrix(&vector_sub(&vector_mul_on_matrix(lambda_vec, &M_lnL), &vector_mul_on_matrix(mu_vec, &M_mnL)), &mu_diag_inv);
591        let c_nR = vector_mul_on_matrix(&vector_sub(&vector_mul_on_matrix(lambda_vec, &M_lnR), &vector_mul_on_matrix(mu_vec, &M_mnR)), &mu_diag_inv);
592        let c_nO = vector_mul_on_matrix(&vector_sub(&vector_mul_on_matrix(lambda_vec, &M_lnO), &vector_mul_on_matrix(mu_vec, &M_mnO)), &mu_diag_inv);
593
594        let c_lL = vector_sub(&vector_mul_on_matrix(lambda_vec, &M_llL), &vector_mul_on_matrix(mu_vec, &M_mlL));
595        let c_lR = vector_sub(&vector_mul_on_matrix(lambda_vec, &M_llR), &vector_mul_on_matrix(mu_vec, &M_mlR));
596        let c_lO = vector_sub(&vector_mul_on_matrix(lambda_vec, &M_llO), &vector_mul_on_matrix(mu_vec, &M_mlO));
597
598        (c_nL, c_nR, c_nO, c_lL, c_lR, c_lO)
599    }
600
601    fn collect_lambda(&self, lambda: &Scalar, mu: &Scalar) -> Vec<Scalar> {
602        let mut lambda_vec = e(lambda, self.dim_nl);
603        if self.f_l && self.f_m {
604            lambda_vec = vector_sub(
605                &lambda_vec,
606                &vector_add(
607                    &vector_tensor_mul(&vector_mul_on_scalar(&e(lambda, self.dim_nv), mu), &e(&pow(mu, self.dim_nv), self.k)),
608                    &vector_tensor_mul(&e(mu, self.dim_nv), &e(&pow(lambda, self.dim_nv), self.k)),
609                ),
610            );
611        }
612
613        lambda_vec
614    }
615
616    fn collect_m_rl(&self) -> (Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>) {
617        let M_lnL = (0..self.dim_nl).map(|i| Vec::from(&self.W_l[i][..self.dim_nm])).collect::<Vec<Vec<Scalar>>>();
618        let M_mnL = (0..self.dim_nm).map(|i| Vec::from(&self.W_m[i][..self.dim_nm])).collect::<Vec<Vec<Scalar>>>();
619        let M_lnR = (0..self.dim_nl).map(|i| Vec::from(&self.W_l[i][self.dim_nm..self.dim_nm * 2])).collect::<Vec<Vec<Scalar>>>();
620        let M_mnR = (0..self.dim_nm).map(|i| Vec::from(&self.W_m[i][self.dim_nm..self.dim_nm * 2])).collect::<Vec<Vec<Scalar>>>();
621        (M_lnL, M_mnL, M_lnR, M_mnR)
622    }
623
624    fn collect_m_o(&self) -> (Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>) {
625        let W_lO = (0..self.dim_nl).map(|i| Vec::from(&self.W_l[i][self.dim_nm * 2..])).collect::<Vec<Vec<Scalar>>>();
626        let W_mO = (0..self.dim_nm).map(|i| Vec::from(&self.W_m[i][self.dim_nm * 2..])).collect::<Vec<Vec<Scalar>>>();
627
628        let map_f = |isz: usize, jsz: usize, typ: PartitionType, W_x: &Vec<Vec<Scalar>>| -> Vec<Vec<Scalar>>{
629            (0..isz).map(|i|
630                (0..jsz).map(|j|
631                    if let Some(j_) = (self.partition)(typ, j) {
632                       W_x[i][j_]
633                    } else {
634                        Scalar::ZERO
635                    }
636                ).collect::<Vec<Scalar>>()
637            ).collect::<Vec<Vec<Scalar>>>()
638        };
639
640        let M_lnO = map_f(self.dim_nl, self.dim_nm, PartitionType::NO, &W_lO);
641        let M_llL = map_f(self.dim_nl, self.dim_nv, PartitionType::LL, &W_lO);
642        let M_llR = map_f(self.dim_nl, self.dim_nv, PartitionType::LR, &W_lO);
643        let M_llO = map_f(self.dim_nl, self.dim_nv, PartitionType::LO, &W_lO);
644
645
646        let M_mnO = map_f(self.dim_nm, self.dim_nm, PartitionType::NO, &W_mO);
647        let M_mlL = map_f(self.dim_nm, self.dim_nv, PartitionType::LL, &W_mO);
648        let M_mlR = map_f(self.dim_nm, self.dim_nv, PartitionType::LR, &W_mO);
649        let M_mlO = map_f(self.dim_nm, self.dim_nv, PartitionType::LO, &W_mO);
650
651
652        (M_lnO, M_mnO, M_llL, M_mlL, M_llR, M_mlR, M_llO, M_mlO)
653    }
654}