use crate::natural::arithmetic::add::{
limbs_add_same_length_to_out, limbs_slice_add_same_length_in_place_left,
};
use crate::natural::arithmetic::mul::{
limbs_mul_greater_to_out, limbs_mul_to_out, limbs_mul_to_out_scratch_len,
};
use crate::natural::arithmetic::sub::{
limbs_sub_same_length_in_place_left, limbs_sub_same_length_in_place_right,
limbs_sub_same_length_to_out,
};
use crate::natural::comparison::cmp::limbs_cmp_same_length;
use crate::platform::{Limb, MATRIX22_STRASSEN_THRESHOLD};
use core::cmp::Ordering::*;
// Writes the difference of `xs` and `ys` to `out`, always handing the subtraction helper the
// operand that does not compare as `Less` first (assuming `limbs_sub_same_length_to_out(out, a, b)`
// stores `a - b`, as its name suggests, this yields the absolute difference).
//
// Returns `true` when `xs` compares as `Less` than `ys`, i.e. when the operands were swapped, and
// `false` otherwise.
//
// # Panics
// Panics if `xs` and `ys` have different lengths.
fn limbs_sub_abs_same_length_to_out(out: &mut [Limb], xs: &[Limb], ys: &[Limb]) -> bool {
    assert_eq!(ys.len(), xs.len());
    let swapped = limbs_cmp_same_length(xs, ys) == Less;
    // Choose the operand order once, then perform a single subtraction.
    let (minuend, subtrahend) = if swapped { (ys, xs) } else { (xs, ys) };
    limbs_sub_same_length_to_out(out, minuend, subtrahend);
    swapped
}
// In-place variant of the absolute-difference helper: the result replaces the contents of `xs`.
//
// When `xs` compares as `Less` than `ys`, the "in place right" subtraction is called with the
// operands swapped so that its output still lands in `xs`; otherwise the "in place left"
// subtraction is used directly.
//
// Returns `true` when `xs` compared as `Less` than `ys`, and `false` otherwise.
//
// # Panics
// Panics if `xs` and `ys` have different lengths.
fn limbs_sub_abs_same_length_in_place_left(xs: &mut [Limb], ys: &[Limb]) -> bool {
    assert_eq!(ys.len(), xs.len());
    match limbs_cmp_same_length(xs, ys) {
        Less => {
            limbs_sub_same_length_in_place_right(ys, xs);
            true
        }
        _ => {
            limbs_sub_same_length_in_place_left(xs, ys);
            false
        }
    }
}
// In-place variant of the absolute-difference helper: the result replaces the contents of `ys`.
//
// When `xs` compares as `Less` than `ys`, the "in place left" subtraction is called with `ys` as
// the mutable operand; otherwise the "in place right" subtraction is used, which also writes to
// `ys`.
//
// Returns `true` when `xs` compared as `Less` than `ys`, and `false` otherwise.
//
// # Panics
// Panics if `xs` and `ys` have different lengths.
fn limbs_sub_abs_same_length_in_place_right(xs: &[Limb], ys: &mut [Limb]) -> bool {
    assert_eq!(ys.len(), xs.len());
    let swapped = limbs_cmp_same_length(xs, ys) == Less;
    if swapped {
        limbs_sub_same_length_in_place_left(ys, xs);
    } else {
        limbs_sub_same_length_in_place_right(xs, ys);
    }
    swapped
}
// Adds two sign-and-magnitude values, `(xs, x_sign)` and `(ys, y_sign)`, writing the magnitude of
// the result to `out` and returning its sign flag.
//
// - Equal sign flags: the magnitudes are added and the shared flag is returned.
// - Different sign flags: the magnitudes are subtracted via `limbs_sub_abs_same_length_to_out`,
//   and `x_sign` is flipped exactly when that helper reports that `xs` compared as `Less`.
//
// # Panics
// Panics if the sign flags are equal and the addition reports a carry. The subtraction path
// panics if `xs` and `ys` have different lengths.
fn limbs_add_signed_same_length_to_out(
    out: &mut [Limb],
    xs: &[Limb],
    x_sign: bool,
    ys: &[Limb],
    y_sign: bool,
) -> bool {
    if x_sign != y_sign {
        // Opposite signs: the result takes the sign of whichever magnitude did not compare Less.
        return x_sign != limbs_sub_abs_same_length_to_out(out, xs, ys);
    }
    assert!(!limbs_add_same_length_to_out(out, xs, ys));
    x_sign
}
// In-place signed addition: combines the sign-and-magnitude values `(xs, x_sign)` and
// `(ys, y_sign)`, storing the resulting magnitude back into `xs` and returning its sign flag.
//
// - Equal sign flags: `ys` is added into `xs` and the shared flag is returned.
// - Different sign flags: `limbs_sub_abs_same_length_in_place_left` is applied, and `x_sign` is
//   flipped exactly when that helper reports that `xs` compared as `Less` than `ys`.
//
// # Panics
// Panics if the sign flags are equal and the addition reports a carry. The subtraction path
// panics if `xs` and `ys` have different lengths.
fn limbs_add_signed_same_length_in_place_left(
    xs: &mut [Limb],
    x_sign: bool,
    ys: &[Limb],
    y_sign: bool,
) -> bool {
    if x_sign == y_sign {
        assert!(!limbs_slice_add_same_length_in_place_left(xs, ys));
        return x_sign;
    }
    // XOR on booleans is the same as `!=`: flip the sign when the operands had to be swapped.
    x_sign ^ limbs_sub_abs_same_length_in_place_left(xs, ys)
}
// Returns the number of scratch limbs needed for a 2x2 matrix multiplication whose left entries
// use `xs_len` limbs and whose right entries use `ys_len` limbs.
//
// The threshold test is the same one `limbs_matrix_2_2_mul` uses to pick an implementation:
// - both lengths at or above `MATRIX22_STRASSEN_THRESHOLD`: `3 * (xs_len + ys_len) + 5`, the total
//   of the four regions carved out by `limbs_matrix_2_2_mul_strassen`;
// - otherwise: `3 * xs_len + 2 * ys_len`, i.e. one `xs_len`-limb copy plus two
//   `(xs_len + ys_len)`-limb products as used by `limbs_matrix_2_2_mul_small`.
pub_const_test! {limbs_matrix_mul_2_2_scratch_len(xs_len: usize, ys_len: usize) -> usize {
    let use_strassen =
        xs_len >= MATRIX22_STRASSEN_THRESHOLD && ys_len >= MATRIX22_STRASSEN_THRESHOLD;
    if use_strassen {
        3 * (xs_len + ys_len) + 5
    } else {
        3 * xs_len + 2 * ys_len
    }
}}
// Multiplies the 2x2 matrix stored in `xs00`, `xs01`, `xs10`, `xs11` by the matrix stored in
// `ys00`, `ys01`, `ys10`, `ys11`, overwriting the `xs` entries, using eight plain products (four
// per row).
//
// For each row `(left, right)` of the `xs` matrix, assuming the imported `limbs_*` helpers do what
// their names indicate:
//   left'  = left * ys00 + right * ys10
//   right' = left * ys01 + right * ys11
//
// Only the low `xs_len` limbs of each `xs` entry are read as input. Each entry must have room for
// `xs_len + ys_len + 1` limbs: the product plus one limb that receives the carry of the final
// addition.
//
// `scratch` is split into an `xs_len`-limb copy of the row's left entry followed by two product
// buffers `p0` and `p1` (presumably `xs_len + ys_len` limbs each, depending on the exact behavior
// of `split_into_chunks_mut!`). This matches the `3 * xs_len + 2 * ys_len` branch of
// `limbs_matrix_mul_2_2_scratch_len`.
//
// A temporary `Vec` is allocated for the scratch space of the limb multiplications.
pub_test! {limbs_matrix_2_2_mul_small(
    xs00: &mut [Limb],
    xs01: &mut [Limb],
    xs10: &mut [Limb],
    xs11: &mut [Limb],
    xs_len: usize,
    ys00: &[Limb],
    ys01: &[Limb],
    ys10: &[Limb],
    ys11: &[Limb],
    scratch: &mut [Limb],
) {
    let ys_len = ys00.len();
    let prod_len = xs_len + ys_len;
    let (saved_left, rest) = scratch.split_at_mut(xs_len);
    split_into_chunks_mut!(rest, prod_len, [p0, p1], _unused);
    let mut mul_scratch = vec![0; limbs_mul_to_out_scratch_len(xs_len, ys_len)];
    // Process the two rows of the left matrix one after the other.
    for (left, right) in [(xs00, xs01), (xs10, xs11)] {
        let left_lo = &left[..xs_len];
        // `left` is overwritten before its old value is needed for `right`, so keep a copy.
        saved_left.copy_from_slice(left_lo);
        // The longer operand always goes first in `limbs_mul_greater_to_out`.
        if xs_len >= ys_len {
            limbs_mul_greater_to_out(p0, left_lo, ys00, &mut mul_scratch);
            let right_lo = &right[..xs_len];
            limbs_mul_greater_to_out(p1, right_lo, ys11, &mut mul_scratch);
            limbs_mul_greater_to_out(left, right_lo, ys10, &mut mul_scratch);
            limbs_mul_greater_to_out(right, saved_left, ys01, &mut mul_scratch);
        } else {
            limbs_mul_greater_to_out(p0, ys00, left_lo, &mut mul_scratch);
            let right_lo = &right[..xs_len];
            limbs_mul_greater_to_out(p1, ys11, right_lo, &mut mul_scratch);
            limbs_mul_greater_to_out(left, ys10, right_lo, &mut mul_scratch);
            limbs_mul_greater_to_out(right, ys01, saved_left, &mut mul_scratch);
        }
        // Fold the saved products into the row; the carry goes into the limb just past the sum.
        let (left_carry, left_body) = left[..=prod_len].split_last_mut().unwrap();
        *left_carry = Limb::from(limbs_slice_add_same_length_in_place_left(left_body, p0));
        let (right_carry, right_body) = right[..=prod_len].split_last_mut().unwrap();
        *right_carry = Limb::from(limbs_slice_add_same_length_in_place_left(right_body, p1));
    }
}}
// Multiplies the 2x2 matrix stored in `xs00`, `xs01`, `xs10`, `xs11` by the matrix stored in
// `ys00`, `ys01`, `ys10`, `ys11`, overwriting the `xs` entries. It uses seven limb products (the
// `limbs_mul_to_out` calls below) combined with signed additions, instead of the eight products
// used by `limbs_matrix_2_2_mul_small`.
//
// NOTE(review): the structure looks like a port of GMP's `mpn_matrix22_mul_strassen`; confirm
// against the GMP source before relying on the step descriptions below. All step comments assume
// the imported `limbs_*` helpers do what their names indicate.
//
// Inputs and sizes:
// - Only the low `xs_len` limbs of each `xs` entry are read as input; each `xs` slice must have at
//   least `xs_len + ys_len + 1` limbs, because index `xs_len + ys_len` is written/read below.
// - `ys_len` is taken from `ys00.len()`.
// - `scratch` must hold at least `3 * (xs_len + ys_len) + 5` limbs (the four regions carved out
//   below), matching the Strassen branch of `limbs_matrix_mul_2_2_scratch_len`.
//
// Sign convention: the `*_sign` booleans accompany a magnitude stored in a slice. They come from
// the `limbs_sub_abs_*` helpers, which return `true` when their first operand compared as `Less`,
// so `true` marks a negative value.
//
// A temporary `Vec` is allocated for the scratch space of the limb multiplications.
//
// # Panics
// Panics if any of the range assertions below fails, or if a slice is too short for the indexing
// performed.
pub_test! {limbs_matrix_2_2_mul_strassen(
xs00: &mut [Limb],
xs01: &mut [Limb],
xs10: &mut [Limb],
xs11: &mut [Limb],
xs_len: usize,
ys00: &[Limb],
ys01: &[Limb],
ys10: &[Limb],
ys11: &[Limb],
scratch: &mut [Limb],
) {
let ys_len = ys00.len();
let sum_len = xs_len + ys_len;
// Carve the caller's scratch into s0 (xs_len + 1 limbs), t0 (ys_len + 1 limbs),
// u0 (sum_len + 1 limbs) and u1 (sum_len + 2 limbs).
let (s0, remainder) = scratch.split_at_mut(xs_len + 1);
let (s0_last, s0_init) = s0.split_last_mut().unwrap();
let (t0, remainder) = remainder.split_at_mut(ys_len + 1);
let (t0_last, t0_init) = t0.split_last_mut().unwrap();
let (u0, u1) = remainder.split_at_mut(sum_len + 1);
let u1 = &mut u1[..sum_len + 2];
// Views of the xs entries: `*_lo` covers the low xs_len input limbs (xs01_lo has one extra limb
// for carries); xs10 and xs11 are narrowed to the sum_len + 1 limbs that receive results.
let xs00_lo = &xs00[..xs_len];
let xs01_lo = &mut xs01[..=xs_len];
let (xs01_lo_last, xs01_lo_init) = xs01_lo.split_last_mut().unwrap();
let xs10 = &mut xs10[..=sum_len];
let xs10_lo = &xs10[..xs_len];
let xs11 = &mut xs11[..=sum_len];
let xs11_lo = &mut xs11[..xs_len];
// One multiplication scratch buffer, sized for the largest of the operand-size combinations
// passed to limbs_mul_to_out below (each operand may carry one extra limb).
let mut mul_scratch = vec![
0;
max!(
limbs_mul_to_out_scratch_len(xs_len, ys_len),
limbs_mul_to_out_scratch_len(xs_len, ys_len + 1),
limbs_mul_to_out_scratch_len(xs_len + 1, ys_len),
limbs_mul_to_out_scratch_len(xs_len + 1, ys_len + 1)
)
];
assert!(xs01_lo_init.len() <= sum_len + 1);
assert!(ys10.len() <= ys_len + 1);
// u0 = xs01 * ys10
limbs_mul_to_out(u0, xs01_lo_init, ys10, &mut mul_scratch);
// xs11 = |xs11 - xs10|, with x11_sign set when xs11 < xs10.
let mut x11_sign = limbs_sub_abs_same_length_in_place_left(xs11_lo, xs10_lo);
// xs01 = xs01 + (xs11 - xs10) as a signed value: subtract magnitudes when the difference is
// negative, otherwise add and keep the carry in the extra limb of xs01.
let x01_sign = if x11_sign {
*xs01_lo_last = 0;
limbs_sub_abs_same_length_in_place_left(xs01_lo_init, xs11_lo)
} else {
*xs01_lo_last = Limb::from(limbs_slice_add_same_length_in_place_left(
xs01_lo_init,
xs11_lo,
));
false
};
// (s0, s0_sign) = xs00 - xs01 as a signed value, with three cases:
// - xs01 negative: add the magnitudes; the result is non-negative.
// - xs01 non-negative with a nonzero extra limb: xs01 exceeds xs00, so store xs01 - xs00 and
//   mark it negative; a borrow from the low limbs is taken out of the extra limb.
// - otherwise: plain absolute difference of equal-length magnitudes.
let s0_sign = if x01_sign {
*s0_last = Limb::from(limbs_add_same_length_to_out(s0_init, xs01_lo_init, xs00_lo));
false
} else if *xs01_lo_last != 0 {
*s0_last = *xs01_lo_last;
if limbs_sub_same_length_to_out(s0_init, xs01_lo_init, xs00_lo) {
s0[xs_len] -= 1;
}
true
} else {
*s0_last = 0;
limbs_sub_abs_same_length_to_out(s0_init, xs00_lo, xs01_lo_init)
};
// u1 = xs00 * ys00
limbs_mul_to_out(u1, xs00_lo, ys00, &mut mul_scratch);
let (u0_last, u0_init) = u0.split_last_mut().unwrap();
// xs00 = u0 + u1 (xs01 * ys10 + xs00 * ys00 with the original inputs); the carry goes into limb
// sum_len.
xs00[sum_len] = Limb::from(limbs_add_same_length_to_out(xs00, u0_init, &u1[..sum_len]));
assert!(xs00[sum_len] < 2);
// t0 = |ys11 - ys10|, with t0_sign set when ys11 < ys10.
let mut t0_sign = limbs_sub_abs_same_length_to_out(t0_init, ys11, ys10);
// u1 is about to hold the product of the two differences. Its flag is set when the factors have
// the SAME sign, so (u1, u1_sign) stands for the negated product.
let u1_sign = x11_sign == t0_sign;
limbs_mul_to_out(u1, xs11_lo, t0_init, &mut mul_scratch);
u1[sum_len] = 0;
// t0 = ys01 + (ys11 - ys10) as a signed value; the extra limb of t0 holds the carry, if any.
*t0_last = if t0_sign {
t0_sign = limbs_sub_abs_same_length_in_place_right(ys01, t0_init);
0
} else {
Limb::from(limbs_slice_add_same_length_in_place_left(t0_init, ys01))
};
// xs11 = |xs01| * |t0|. Each factor may have one extra limb. If t0's extra limb is nonzero,
// multiply by the full t0 and add t0 shifted by xs_len limbs when xs01's extra limb is 1.
// Otherwise multiply the full xs01 (xs_len + 1 limbs) by the low ys_len limbs of t0.
if *t0_last != 0 {
limbs_mul_to_out(xs11, xs01_lo_init, t0, &mut mul_scratch);
assert!(*xs01_lo_last < 2);
if *xs01_lo_last != 0 {
limbs_slice_add_same_length_in_place_left(&mut xs11[xs_len..], t0);
}
} else {
limbs_mul_to_out(xs11, xs01_lo, t0_init, &mut mul_scratch);
}
assert!(xs11[sum_len] < 4);
// Extend u0 to sum_len + 1 limbs so it can be combined with xs11.
*u0_last = 0;
// xs11 = u0 + (signed product above): the product is non-negative when both factors have the
// same sign, in which case the magnitudes are added; otherwise take |u0 - xs11| and record the
// sign.
x11_sign = if x01_sign == t0_sign {
assert!(!limbs_slice_add_same_length_in_place_left(xs11, u0));
false
} else {
limbs_sub_abs_same_length_in_place_right(u0, xs11)
};
let (t0_last, t0_init) = t0.split_last_mut().unwrap();
// t0 = t0 - ys00 as a signed value (same three-case pattern as for s0 above).
if t0_sign {
*t0_last = Limb::from(limbs_slice_add_same_length_in_place_left(t0_init, ys00));
} else if *t0_last != 0 {
if limbs_sub_same_length_in_place_left(t0_init, ys00) {
*t0_last -= 1;
}
} else {
t0_sign = limbs_sub_abs_same_length_in_place_left(t0_init, ys00);
}
// u0 = xs10 * |t0|; its sign is t0_sign.
limbs_mul_to_out(u0, xs10_lo, t0, &mut mul_scratch);
assert!(u0[sum_len] < 2);
let (xs01_lo_last, xs01_lo_init) = xs01_lo.split_last_mut().unwrap();
// xs01 = xs01 + xs10 as a signed value. When xs01 is negative this is xs10 - |xs01|, which is
// asserted not to borrow.
if x01_sign {
assert!(!limbs_sub_same_length_in_place_right(xs10_lo, xs01_lo_init));
} else if limbs_slice_add_same_length_in_place_left(xs01_lo_init, xs10_lo) {
*xs01_lo_last += 1;
}
// xs10 = xs11 + u0 (signed); from here on t0_sign is the sign of xs10.
t0_sign = limbs_add_signed_same_length_to_out(xs10, xs11, x11_sign, u0, t0_sign);
assert!(xs10[sum_len] < 4);
// xs11 = xs11 + u1 (signed).
x11_sign =
limbs_add_signed_same_length_in_place_left(xs11, x11_sign, &u1[..=sum_len], u1_sign);
assert!(xs11[sum_len] < 3);
// u0 = |s0| * ys01; its sign is s0_sign.
limbs_mul_to_out(u0, s0, ys01, &mut mul_scratch);
assert!(u0[sum_len] < 2);
// t0 = ys11 + ys01, with the carry in the extra limb.
t0[ys_len] = Limb::from(limbs_add_same_length_to_out(t0, ys11, ys01));
// u1 = xs01 * t0 (both operands with their extra limb), giving sum_len + 2 limbs.
limbs_mul_to_out(u1, xs01_lo, t0, &mut mul_scratch);
assert!(u1[sum_len] < 4);
let (u1_last, u1_init) = u1.split_last_mut().unwrap();
assert_eq!(*u1_last, 0);
// xs01 = xs11 + u0 (signed). The returned sign flag is discarded.
limbs_add_signed_same_length_to_out(xs01, xs11, x11_sign, u0, s0_sign);
assert!(xs01[sum_len] < 2);
// xs11 = u1 - (signed xs11): add magnitudes when xs11 is negative, otherwise subtract.
if x11_sign {
assert!(!limbs_slice_add_same_length_in_place_left(xs11, u1_init));
} else {
assert!(!limbs_sub_same_length_in_place_right(u1_init, xs11));
}
assert!(xs11[sum_len] < 2);
// xs10 = u1 - (signed xs10), using t0_sign as the sign of xs10.
if t0_sign {
assert!(!limbs_slice_add_same_length_in_place_left(xs10, u1_init));
} else {
assert!(!limbs_sub_same_length_in_place_right(u1_init, xs10));
}
assert!(xs10[sum_len] < 2);
}}
// Entry point for the in-place 2x2 limb-matrix multiplication: checks that all four `ys` entries
// have the same length, then forwards every argument unchanged to one of the two implementations.
//
// `limbs_matrix_2_2_mul_strassen` is used when both `xs_len` and the `ys` length are at least
// `MATRIX22_STRASSEN_THRESHOLD`; otherwise `limbs_matrix_2_2_mul_small` is used. This is the same
// threshold test that `limbs_matrix_mul_2_2_scratch_len` applies when sizing `scratch`.
//
// # Panics
// Panics if `ys01`, `ys10` or `ys11` differs in length from `ys00`.
pub_crate_test! {limbs_matrix_2_2_mul(
    xs00: &mut [Limb],
    xs01: &mut [Limb],
    xs10: &mut [Limb],
    xs11: &mut [Limb],
    xs_len: usize,
    ys00: &[Limb],
    ys01: &[Limb],
    ys10: &[Limb],
    ys11: &[Limb],
    scratch: &mut [Limb],
) {
    let ys_len = ys00.len();
    for ys in [ys01, ys10, ys11] {
        assert_eq!(ys.len(), ys_len);
    }
    let use_strassen =
        xs_len >= MATRIX22_STRASSEN_THRESHOLD && ys_len >= MATRIX22_STRASSEN_THRESHOLD;
    if use_strassen {
        limbs_matrix_2_2_mul_strassen(
            xs00, xs01, xs10, xs11, xs_len, ys00, ys01, ys10, ys11, scratch,
        );
    } else {
        limbs_matrix_2_2_mul_small(
            xs00, xs01, xs10, xs11, xs_len, ys00, ys01, ys10, ys11, scratch,
        );
    }
}}