1#![allow(non_snake_case)]
2use std::ops::{Add, Mul, Sub};
5use k256::{AffinePoint, ProjectivePoint, Scalar};
6use k256::elliptic_curve::ops::Invert;
7use k256::elliptic_curve::rand_core::{CryptoRng, RngCore};
8use merlin::Transcript;
9use serde::{Deserialize, Serialize};
10use crate::util::*;
11use crate::{transcript, wnla};
12use crate::wnla::WeightNormLinearArgument;
13
/// Selects which part of the output vector `w_o` a position refers to.
///
/// An `ArithmeticCircuit`'s `partition` closure maps `(PartitionType, position)`
/// to an index into `w_o`, or to `None` when that position is zero:
/// - `NO`: positions `0..dim_nm` of the output vector committed on `g_vec` in `C_O`;
/// - `LO`, `LL`, `LR`: positions `0..dim_nv` of the vectors committed on `h_vec`
///   (after the nine blinding generators) in `C_O`, `C_L` and `C_R` respectively.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PartitionType {
    /// `h_vec`-side entries of `C_O`.
    LO,
    /// `h_vec`-side entries of `C_L`.
    LL,
    /// `h_vec`-side entries of `C_R`.
    LR,
    /// `g_vec`-side entries of `C_O`.
    NO,
}
21
22#[derive(Clone, Debug)]
24pub struct Proof {
25 pub c_l: ProjectivePoint,
26 pub c_r: ProjectivePoint,
27 pub c_o: ProjectivePoint,
28 pub c_s: ProjectivePoint,
29 pub r: Vec<ProjectivePoint>,
30 pub x: Vec<ProjectivePoint>,
31 pub l: Vec<Scalar>,
32 pub n: Vec<Scalar>,
33}
34
35#[derive(Serialize, Deserialize, Clone, Debug)]
37pub struct SerializableProof {
38 pub c_l: AffinePoint,
39 pub c_r: AffinePoint,
40 pub c_o: AffinePoint,
41 pub c_s: AffinePoint,
42 pub r: Vec<AffinePoint>,
43 pub x: Vec<AffinePoint>,
44 pub l: Vec<Scalar>,
45 pub n: Vec<Scalar>,
46}
47
48impl From<&SerializableProof> for Proof {
49 fn from(value: &SerializableProof) -> Self {
50 return Proof {
51 c_l: ProjectivePoint::from(&value.c_l),
52 c_r: ProjectivePoint::from(&value.c_r),
53 c_o: ProjectivePoint::from(&value.c_o),
54 c_s: ProjectivePoint::from(&value.c_s),
55 r: value.r.iter().map(ProjectivePoint::from).collect::<Vec<ProjectivePoint>>(),
56 x: value.x.iter().map(ProjectivePoint::from).collect::<Vec<ProjectivePoint>>(),
57 l: value.l.clone(),
58 n: value.n.clone(),
59 };
60 }
61}
62
63impl From<&Proof> for SerializableProof {
64 fn from(value: &Proof) -> Self {
65 return SerializableProof {
66 c_l: value.c_l.to_affine(),
67 c_r: value.c_r.to_affine(),
68 c_o: value.c_o.to_affine(),
69 c_s: value.c_s.to_affine(),
70 r: value.r.iter().map(|r_val| r_val.to_affine()).collect::<Vec<AffinePoint>>(),
71 x: value.x.iter().map(|x_val| x_val.to_affine()).collect::<Vec<AffinePoint>>(),
72 l: value.l.clone(),
73 n: value.n.clone(),
74 };
75 }
76}
77
78#[derive(Clone, Debug)]
80pub struct Witness {
81 pub v: Vec<Vec<Scalar>>,
83 pub s_v: Vec<Scalar>,
85 pub w_l: Vec<Scalar>,
87 pub w_r: Vec<Scalar>,
89 pub w_o: Vec<Scalar>,
91}
92
93pub struct ArithmeticCircuit<P>
96 where
97 P: Fn(PartitionType, usize) -> Option<usize>
98{
99 pub dim_nm: usize,
100 pub dim_no: usize,
101 pub k: usize,
102
103 pub dim_nl: usize,
105 pub dim_nv: usize,
107 pub dim_nw: usize,
109
110 pub g: ProjectivePoint,
111
112 pub g_vec: Vec<ProjectivePoint>,
114 pub h_vec: Vec<ProjectivePoint>,
116
117 pub W_m: Vec<Vec<Scalar>>,
119 pub W_l: Vec<Vec<Scalar>>,
121
122 pub a_m: Vec<Scalar>,
124 pub a_l: Vec<Scalar>,
126
127 pub f_l: bool,
128 pub f_m: bool,
129
130 pub g_vec_: Vec<ProjectivePoint>,
133 pub h_vec_: Vec<ProjectivePoint>,
136
137 pub partition: P,
139}
140
141impl<P> ArithmeticCircuit<P>
142 where
143 P: Fn(PartitionType, usize) -> Option<usize>
144{
145 pub fn commit(&self, v: &[Scalar], s: &Scalar) -> ProjectivePoint {
147 self.
148 g.mul(&v[0]).
149 add(&self.h_vec[0].mul(s)).
150 add(&vector_mul(&self.h_vec[9..], &v[1..]))
151 }
152
153 pub fn verify(&self, v: &[ProjectivePoint], t: &mut Transcript, proof: Proof) -> bool {
155 transcript::app_point(b"commitment_cl", &proof.c_l, t);
156 transcript::app_point(b"commitment_cr", &proof.c_r, t);
157 transcript::app_point(b"commitment_co", &proof.c_o, t);
158
159 v.iter().for_each(|v_val| transcript::app_point(b"commitment_v", v_val, t));
160
161 let rho = transcript::get_challenge(b"circuit_rho", t);
162 let lambda = transcript::get_challenge(b"circuit_lambda", t);
163 let beta = transcript::get_challenge(b"circuit_beta", t);
164 let delta = transcript::get_challenge(b"circuit_delta", t);
165
166 let mu = rho.mul(rho);
167
168 let lambda_vec = self.collect_lambda(&lambda, &mu);
169 let mu_vec = vector_mul_on_scalar(&e(&mu, self.dim_nm), &mu);
170
171 let (
172 c_nL,
173 c_nR,
174 c_nO,
175 c_lL,
176 c_lR,
177 c_lO
178 ) = self.collect_c(&lambda_vec, &mu_vec, &mu);
179
180 let two = Scalar::from(2u32);
181
182 let mut v_ = ProjectivePoint::IDENTITY;
183 (0..self.k).
184 for_each(|i|
185 v_ = v_.add(v[i].mul(self.linear_comb_coef(i, &lambda, &mu)))
186 );
187 v_ = v_.mul(&two);
188
189 transcript::app_point(b"commitment_cs", &proof.c_s, t);
190
191 let tau = transcript::get_challenge(b"circuit_tau", t);
192 let tau_inv = tau.invert_vartime().unwrap();
193 let tau2 = tau.mul(&tau);
194 let tau3 = tau2.mul(&tau);
195
196 let delta_inv = delta.invert_vartime().unwrap();
197
198 let mut pn_tau = vector_mul_on_scalar(&c_nO, &tau3.mul(&delta_inv));
199 pn_tau = vector_sub(&pn_tau, &vector_mul_on_scalar(&c_nL, &tau2));
200 pn_tau = vector_add(&pn_tau, &vector_mul_on_scalar(&c_nR, &tau));
201
202 let ps_tau = weight_vector_mul(&pn_tau, &pn_tau, &mu).
203 add(&vector_mul(&lambda_vec, &self.a_l).mul(&tau3).mul(&two)).
204 sub(&vector_mul(&mu_vec, &self.a_m).mul(&tau3).mul(&two));
205
206 let pt = self.g.mul(ps_tau).add(vector_mul(&self.g_vec, &pn_tau));
207
208 let cr_tau = vec![
209 Scalar::ONE,
210 tau_inv.mul(beta),
211 tau.mul(beta),
212 tau2.mul(beta),
213 tau3.mul(beta),
214 tau.mul(tau3).mul(beta),
215 tau2.mul(tau3).mul(beta),
216 tau3.mul(tau3).mul(beta),
217 tau3.mul(tau3).mul(tau).mul(beta),
218 ];
219
220 let c_l0 = self.collect_cl0(&lambda, &mu);
221
222 let mut cl_tau = vector_mul_on_scalar(&c_lO, &tau3.mul(&delta_inv));
223 cl_tau = vector_sub(&cl_tau, &vector_mul_on_scalar(&c_lL, &tau2));
224 cl_tau = vector_add(&cl_tau, &vector_mul_on_scalar(&c_lR, &tau));
225 cl_tau = vector_mul_on_scalar(&cl_tau, &two);
226 cl_tau = vector_sub(&cl_tau, &c_l0);
227
228 let mut c = [&cr_tau[..], &cl_tau[..]].concat();
229
230 let commitment = pt.
231 add(&proof.c_s.mul(&tau_inv)).
232 sub(&proof.c_o.mul(&delta)).
233 add(&proof.c_l.mul(&tau)).
234 sub(&proof.c_r.mul(&tau2)).
235 add(&v_.mul(&tau3));
236
237 while c.len() < self.h_vec.len() + self.h_vec_.len() {
238 c.push(Scalar::ZERO);
239 }
240
241 let wnla = WeightNormLinearArgument {
242 g: self.g,
243 g_vec: [&self.g_vec[..], &self.g_vec_[..]].concat(),
244 h_vec: [&self.h_vec[..], &self.h_vec_[..]].concat(),
245 c,
246 rho,
247 mu,
248 };
249
250 wnla.verify(&commitment, t, wnla::Proof {
251 r: proof.r,
252 x: proof.x,
253 l: proof.l,
254 n: proof.n,
255 })
256 }
257
258 pub fn prove<R>(&self, v: &[ProjectivePoint], witness: Witness, t: &mut Transcript, rng: &mut R) -> Proof
261 where
262 R: RngCore + CryptoRng
263 {
264 let ro = vec![
265 Scalar::generate_biased(rng),
266 Scalar::generate_biased(rng),
267 Scalar::generate_biased(rng),
268 Scalar::generate_biased(rng),
269 Scalar::ZERO,
270 Scalar::generate_biased(rng),
271 Scalar::generate_biased(rng),
272 Scalar::generate_biased(rng),
273 Scalar::ZERO,
274 ];
275
276 let rl = vec![
277 Scalar::generate_biased(rng),
278 Scalar::generate_biased(rng),
279 Scalar::generate_biased(rng),
280 Scalar::ZERO,
281 Scalar::generate_biased(rng),
282 Scalar::generate_biased(rng),
283 Scalar::generate_biased(rng),
284 Scalar::ZERO,
285 Scalar::ZERO,
286 ];
287
288 let rr = vec![
289 Scalar::generate_biased(rng),
290 Scalar::generate_biased(rng),
291 Scalar::ZERO,
292 Scalar::generate_biased(rng),
293 Scalar::generate_biased(rng),
294 Scalar::generate_biased(rng),
295 Scalar::ZERO,
296 Scalar::ZERO,
297 Scalar::ZERO,
298 ];
299
300 let nl = witness.w_l;
301 let nr = witness.w_r;
302
303 let no = (0..self.dim_nm).map(|j|
304 if let Some(i) = (self.partition)(PartitionType::NO, j) {
305 witness.w_o[i]
306 } else {
307 Scalar::ZERO
308 }
309 ).collect::<Vec<Scalar>>();
310
311 let lo = (0..self.dim_nv).map(|j|
312 if let Some(i) = (self.partition)(PartitionType::LO, j) {
313 witness.w_o[i]
314 } else {
315 Scalar::ZERO
316 }
317 ).collect::<Vec<Scalar>>();
318
319 let ll = (0..self.dim_nv).map(|j|
320 if let Some(i) = (self.partition)(PartitionType::LL, j) {
321 witness.w_o[i]
322 } else {
323 Scalar::ZERO
324 }
325 ).collect::<Vec<Scalar>>();
326
327 let lr = (0..self.dim_nv).map(|j|
328 if let Some(i) = (self.partition)(PartitionType::LR, j) {
329 witness.w_o[i]
330 } else {
331 Scalar::ZERO
332 }
333 ).collect::<Vec<Scalar>>();
334
335 let co =
336 vector_mul(&self.h_vec, &[&ro[..], &lo[..]].concat()).
337 add(vector_mul(&self.g_vec, &no));
338
339 let cl =
340 vector_mul(&self.h_vec, &[&rl[..], &ll[..]].concat()).
341 add(vector_mul(&self.g_vec, &nl));
342
343 let cr =
344 vector_mul(&self.h_vec, &[&rr[..], &lr[..]].concat()).
345 add(vector_mul(&self.g_vec, &nr));
346
347 transcript::app_point(b"commitment_cl", &cl, t);
348 transcript::app_point(b"commitment_cr", &cr, t);
349 transcript::app_point(b"commitment_co", &co, t);
350 v.iter().for_each(|v_val| transcript::app_point(b"commitment_v", v_val, t));
351
352 let rho = transcript::get_challenge(b"circuit_rho", t);
353 let lambda = transcript::get_challenge(b"circuit_lambda", t);
354 let beta = transcript::get_challenge(b"circuit_beta", t);
355 let delta = transcript::get_challenge(b"circuit_delta", t);
356
357 let mu = rho.mul(rho);
358
359 let lambda_vec = self.collect_lambda(&lambda, &mu);
360 let mu_vec = vector_mul_on_scalar(&e(&mu, self.dim_nm), &mu);
361
362 let (
363 c_nL,
364 c_nR,
365 c_nO,
366 c_lL,
367 c_lR,
368 c_lO
369 ) = self.collect_c(&lambda_vec, &mu_vec, &mu);
370
371 let ls = (0..self.dim_nv).map(|_| Scalar::generate_biased(rng)).collect::<Vec<Scalar>>();
372 let ns = (0..self.dim_nm).map(|_| Scalar::generate_biased(rng)).collect::<Vec<Scalar>>();
373
374 let two = Scalar::from(2u32);
375
376 let mut v_0 = Scalar::ZERO;
377 (0..self.k).
378 for_each(|i|
379 v_0 = v_0.add(witness.v[i][0].mul(self.linear_comb_coef(i, &lambda, &mu)))
380 );
381 v_0 = v_0.mul(&two);
382
383 let mut rv = vec![Scalar::ZERO; 9];
384 (0..self.k).
385 for_each(|i|
386 rv[0] = rv[0].add(witness.s_v[i].mul(self.linear_comb_coef(i, &lambda, &mu)))
387 );
388 rv[0] = rv[0].mul(&two);
389
390 let mut v_1 = vec![Scalar::ZERO; self.dim_nv - 1];
391 (0..self.k).
392 for_each(|i|
393 v_1 = vector_add(&v_1, &vector_mul_on_scalar(&witness.v[i][1..], &self.linear_comb_coef(i, &lambda, &mu)))
394 );
395 v_1 = vector_mul_on_scalar(&v_1, &two);
396
397 let c_l0 = self.collect_cl0(&lambda, &mu);
398
399 let mut f_ = vec![Scalar::ZERO; 8];
401
402 let delta2 = delta.mul(&delta);
403 let delta_inv = delta.invert_vartime().unwrap();
404
405 f_[0] = minus(&weight_vector_mul(&ns, &ns, &mu));
407
408 f_[1] = vector_mul(&c_l0, &ls).
410 add(delta.mul(&two).mul(&weight_vector_mul(&ns, &no, &mu)));
411
412 f_[2] = minus(&vector_mul(&c_lR, &ls).mul(&two)).
414 sub(&vector_mul(&c_l0, &lo).mul(&delta)).
415 sub(&weight_vector_mul(&ns, &vector_add(&nl, &c_nR), &mu).mul(&two)).
416 sub(&weight_vector_mul(&no, &no, &mu).mul(&delta2));
417
418 f_[3] = vector_mul(&c_lL, &ls).mul(&two).
420 add(&vector_mul(&c_lR, &lo).mul(&delta).mul(&two)).
421 add(&vector_mul(&c_l0, &ll)).
422 add(&weight_vector_mul(&ns, &vector_add(&nr, &c_nL), &mu).mul(&two)).
423 add(&weight_vector_mul(&no, &vector_add(&nl, &c_nR), &mu).mul(&two).mul(&delta));
424
425 f_[4] = weight_vector_mul(&c_nR, &c_nR, &mu).
427 sub(&vector_mul(&c_lO, &ls).mul(&delta_inv).mul(&two)).
428 sub(&vector_mul(&c_lL, &lo).mul(&delta).mul(&two)).
429 sub(&vector_mul(&c_lR, &ll).mul(&two)).
430 sub(&vector_mul(&c_l0, &lr)).
431 sub(&weight_vector_mul(&ns, &c_nO, &mu).mul(&delta_inv).mul(&two)).
432 sub(&weight_vector_mul(&no, &vector_add(&nr, &c_nL), &mu).mul(&delta).mul(&two)).
433 sub(&weight_vector_mul(&vector_add(&nl, &c_nR), &vector_add(&nl, &c_nR), &mu));
434
435 f_[5] = weight_vector_mul(&c_nO, &c_nR, &mu).mul(&delta_inv).mul(&two).
439 add(&weight_vector_mul(&c_nL, &c_nL, &mu)).
440 sub(&vector_mul(&c_lO, &ll).mul(&delta_inv).mul(&two)).
441 sub(&vector_mul(&c_lL, &lr).mul(&two)).
442 sub(&vector_mul(&c_lR, &v_1).mul(&two)).
443 sub(&weight_vector_mul(&vector_add(&nl, &c_nR), &c_nO, &mu).mul(&delta_inv).mul(&two)).
444 sub(&weight_vector_mul(&vector_add(&nr, &c_nL), &vector_add(&nr, &c_nL), &mu));
445
446 f_[6] = minus(&weight_vector_mul(&c_nO, &c_nL, &mu).mul(&delta_inv).mul(&two)).
448 add(&vector_mul(&c_nO, &lr).mul(&delta_inv).mul(&two)).
449 add(&vector_mul(&c_lL, &v_1).mul(&two)).
450 add(&weight_vector_mul(&vector_add(&nr, &c_nL), &c_nO, &mu).mul(&delta_inv).mul(&two));
451
452 f_[7] = minus(&vector_mul(&c_lO, &v_1).mul(&delta_inv).mul(&two));
454
455 let beta_inv = beta.invert_vartime().unwrap();
456
457 let rs = vec![
458 f_[1].add(ro[1].mul(&delta).mul(&beta)),
459 f_[0].mul(&beta_inv),
460 ro[0].mul(&delta).add(&f_[2]).mul(&beta_inv).sub(&rl[1]),
461 f_[3].sub(&rl[0]).mul(&beta_inv).add(&ro[2].mul(&delta).add(rr[1])),
462 f_[4].add(&rr[0]).mul(&beta_inv).add(&ro[3].mul(&delta).sub(rl[2])),
463 minus(&rv[0].mul(&beta_inv)),
464 f_[5].mul(&beta_inv).add(&ro[5].mul(&delta)).add(&rr[3]).sub(&rl[4]),
465 f_[6].mul(&beta_inv).add(&rr[4]).add(&ro[6].mul(&delta)).sub(&rl[5]),
466 f_[7].mul(&beta_inv).add(&ro[7].mul(&delta)).sub(&rl[6]).add(&rr[5]),
467 ];
468
469 let cs = vector_mul(&self.h_vec, &[&rs[..], &ls[..]].concat()).
470 add(vector_mul(&self.g_vec, &ns));
471
472 transcript::app_point(b"commitment_cs", &cs, t);
473
474 let tau = transcript::get_challenge(b"circuit_tau", t);
475 let tau_inv = tau.invert_vartime().unwrap();
476 let tau2 = tau.mul(&tau);
477 let tau3 = tau2.mul(&tau);
478
479 let mut l = vector_mul_on_scalar(&[&rs[..], &ls[..]].concat(), &tau_inv);
480 l = vector_sub(&l, &vector_mul_on_scalar(&[&ro[..], &lo[..]].concat(), &delta));
481 l = vector_add(&l, &vector_mul_on_scalar(&[&rl[..], &ll[..]].concat(), &tau));
482 l = vector_sub(&l, &vector_mul_on_scalar(&[&rr[..], &lr[..]].concat(), &tau2));
483 l = vector_add(&l, &vector_mul_on_scalar(&[&rv[..], &v_1[..]].concat(), &tau3));
484
485 let mut pn_tau = vector_mul_on_scalar(&c_nO, &tau3.mul(&delta_inv));
486 pn_tau = vector_sub(&pn_tau, &vector_mul_on_scalar(&c_nL, &tau2));
487 pn_tau = vector_add(&pn_tau, &vector_mul_on_scalar(&c_nR, &tau));
488
489 let ps_tau = weight_vector_mul(&pn_tau, &pn_tau, &mu).
490 add(&vector_mul(&lambda_vec, &self.a_l).mul(&tau3).mul(&two)).
491 sub(&vector_mul(&mu_vec, &self.a_m).mul(&tau3).mul(&two));
492
493 let mut n_tau = vector_mul_on_scalar(&ns, &tau_inv);
494 n_tau = vector_sub(&n_tau, &vector_mul_on_scalar(&no, &delta));
495 n_tau = vector_add(&n_tau, &vector_mul_on_scalar(&nl, &tau));
496 n_tau = vector_sub(&n_tau, &vector_mul_on_scalar(&nr, &tau2));
497
498 let mut n = vector_add(&pn_tau, &n_tau);
499
500 let cr_tau = vec![
501 Scalar::ONE,
502 tau_inv.mul(beta),
503 tau.mul(beta),
504 tau2.mul(beta),
505 tau3.mul(beta),
506 tau.mul(tau3).mul(beta),
507 tau2.mul(tau3).mul(beta),
508 tau3.mul(tau3).mul(beta),
509 tau3.mul(tau3).mul(tau).mul(beta),
510 ];
511
512 let mut cl_tau = vector_mul_on_scalar(&c_lO, &tau3.mul(&delta_inv));
513 cl_tau = vector_sub(&cl_tau, &vector_mul_on_scalar(&c_lL, &tau2));
514 cl_tau = vector_add(&cl_tau, &vector_mul_on_scalar(&c_lR, &tau));
515 cl_tau = vector_mul_on_scalar(&cl_tau, &two);
516 cl_tau = vector_sub(&cl_tau, &c_l0);
517
518 let mut c = [&cr_tau[..], &cl_tau[..]].concat();
519
520 let v = ps_tau.add(&tau3.mul(&v_0));
521
522 let commitment = self.g.mul(v).
523 add(&vector_mul(&self.h_vec, &l)).
524 add(&vector_mul(&self.g_vec, &n));
525
526 while l.len() < self.h_vec.len() + self.h_vec_.len() {
527 l.push(Scalar::ZERO);
528 c.push(Scalar::ZERO);
529 }
530
531 while n.len() < self.g_vec.len() + self.g_vec_.len() {
532 n.push(Scalar::ZERO);
533 }
534
535 let wnla = WeightNormLinearArgument {
536 g: self.g,
537 g_vec: [&self.g_vec[..], &self.g_vec_[..]].concat(),
538 h_vec: [&self.h_vec[..], &self.h_vec_[..]].concat(),
539 c,
540 rho,
541 mu,
542 };
543
544 let proof_wnla = wnla.prove(&commitment, t, l, n);
545
546 Proof {
547 c_l: cl,
548 c_r: cr,
549 c_o: co,
550 c_s: cs,
551 r: proof_wnla.r,
552 x: proof_wnla.x,
553 l: proof_wnla.l,
554 n: proof_wnla.n,
555 }
556 }
557
558
559 fn linear_comb_coef(&self, i: usize, lambda: &Scalar, mu: &Scalar) -> Scalar {
560 let mut coef = Scalar::ZERO;
561 if self.f_l {
562 coef = coef.add(pow(lambda, self.dim_nv * i))
563 }
564
565 if self.f_m {
566 coef = coef.add(pow(mu, self.dim_nv * i + 1))
567 }
568
569 coef
570 }
571
572 fn collect_cl0(&self, lambda: &Scalar, mu: &Scalar) -> Vec<Scalar> {
573 let mut c_l0 = vec![Scalar::ZERO; self.dim_nv - 1];
574 if self.f_l {
575 c_l0 = e(lambda, self.dim_nv)[1..].to_vec();
576 }
577 if self.f_m {
578 c_l0 = vector_sub(&c_l0, &vector_mul_on_scalar(&e(mu, self.dim_nv)[1..], mu));
579 }
580
581 c_l0
582 }
583
584 fn collect_c(&self, lambda_vec: &[Scalar], mu_vec: &[Scalar], mu: &Scalar) -> (Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>) {
585 let (M_lnL, M_mnL, M_lnR, M_mnR) = self.collect_m_rl();
586 let (M_lnO, M_mnO, M_llL, M_mlL, M_llR, M_mlR, M_llO, M_mlO) = self.collect_m_o();
587
588 let mu_diag_inv = diag_inv(mu, self.dim_nm);
589
590 let c_nL = vector_mul_on_matrix(&vector_sub(&vector_mul_on_matrix(lambda_vec, &M_lnL), &vector_mul_on_matrix(mu_vec, &M_mnL)), &mu_diag_inv);
591 let c_nR = vector_mul_on_matrix(&vector_sub(&vector_mul_on_matrix(lambda_vec, &M_lnR), &vector_mul_on_matrix(mu_vec, &M_mnR)), &mu_diag_inv);
592 let c_nO = vector_mul_on_matrix(&vector_sub(&vector_mul_on_matrix(lambda_vec, &M_lnO), &vector_mul_on_matrix(mu_vec, &M_mnO)), &mu_diag_inv);
593
594 let c_lL = vector_sub(&vector_mul_on_matrix(lambda_vec, &M_llL), &vector_mul_on_matrix(mu_vec, &M_mlL));
595 let c_lR = vector_sub(&vector_mul_on_matrix(lambda_vec, &M_llR), &vector_mul_on_matrix(mu_vec, &M_mlR));
596 let c_lO = vector_sub(&vector_mul_on_matrix(lambda_vec, &M_llO), &vector_mul_on_matrix(mu_vec, &M_mlO));
597
598 (c_nL, c_nR, c_nO, c_lL, c_lR, c_lO)
599 }
600
601 fn collect_lambda(&self, lambda: &Scalar, mu: &Scalar) -> Vec<Scalar> {
602 let mut lambda_vec = e(lambda, self.dim_nl);
603 if self.f_l && self.f_m {
604 lambda_vec = vector_sub(
605 &lambda_vec,
606 &vector_add(
607 &vector_tensor_mul(&vector_mul_on_scalar(&e(lambda, self.dim_nv), mu), &e(&pow(mu, self.dim_nv), self.k)),
608 &vector_tensor_mul(&e(mu, self.dim_nv), &e(&pow(lambda, self.dim_nv), self.k)),
609 ),
610 );
611 }
612
613 lambda_vec
614 }
615
616 fn collect_m_rl(&self) -> (Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>) {
617 let M_lnL = (0..self.dim_nl).map(|i| Vec::from(&self.W_l[i][..self.dim_nm])).collect::<Vec<Vec<Scalar>>>();
618 let M_mnL = (0..self.dim_nm).map(|i| Vec::from(&self.W_m[i][..self.dim_nm])).collect::<Vec<Vec<Scalar>>>();
619 let M_lnR = (0..self.dim_nl).map(|i| Vec::from(&self.W_l[i][self.dim_nm..self.dim_nm * 2])).collect::<Vec<Vec<Scalar>>>();
620 let M_mnR = (0..self.dim_nm).map(|i| Vec::from(&self.W_m[i][self.dim_nm..self.dim_nm * 2])).collect::<Vec<Vec<Scalar>>>();
621 (M_lnL, M_mnL, M_lnR, M_mnR)
622 }
623
624 fn collect_m_o(&self) -> (Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>, Vec<Vec<Scalar>>) {
625 let W_lO = (0..self.dim_nl).map(|i| Vec::from(&self.W_l[i][self.dim_nm * 2..])).collect::<Vec<Vec<Scalar>>>();
626 let W_mO = (0..self.dim_nm).map(|i| Vec::from(&self.W_m[i][self.dim_nm * 2..])).collect::<Vec<Vec<Scalar>>>();
627
628 let map_f = |isz: usize, jsz: usize, typ: PartitionType, W_x: &Vec<Vec<Scalar>>| -> Vec<Vec<Scalar>>{
629 (0..isz).map(|i|
630 (0..jsz).map(|j|
631 if let Some(j_) = (self.partition)(typ, j) {
632 W_x[i][j_]
633 } else {
634 Scalar::ZERO
635 }
636 ).collect::<Vec<Scalar>>()
637 ).collect::<Vec<Vec<Scalar>>>()
638 };
639
640 let M_lnO = map_f(self.dim_nl, self.dim_nm, PartitionType::NO, &W_lO);
641 let M_llL = map_f(self.dim_nl, self.dim_nv, PartitionType::LL, &W_lO);
642 let M_llR = map_f(self.dim_nl, self.dim_nv, PartitionType::LR, &W_lO);
643 let M_llO = map_f(self.dim_nl, self.dim_nv, PartitionType::LO, &W_lO);
644
645
646 let M_mnO = map_f(self.dim_nm, self.dim_nm, PartitionType::NO, &W_mO);
647 let M_mlL = map_f(self.dim_nm, self.dim_nv, PartitionType::LL, &W_mO);
648 let M_mlR = map_f(self.dim_nm, self.dim_nv, PartitionType::LR, &W_mO);
649 let M_mlO = map_f(self.dim_nm, self.dim_nv, PartitionType::LO, &W_mO);
650
651
652 (M_lnO, M_mnO, M_llL, M_mlL, M_llR, M_mlR, M_llO, M_mlO)
653 }
654}