feanor_math/algorithms/linsolve/
extension.rs

1use std::alloc::Allocator;
2
3use crate::matrix::*;
4use crate::rings::extension::{FreeAlgebra, FreeAlgebraStore};
5use crate::ring::*;
6use crate::seq::*;
7
8use super::{LinSolveRing, SolveResult};
9
10#[stability::unstable(feature = "enable")]
11pub fn solve_right_over_extension<R, V1, V2, V3, A>(ring: R, lhs: SubmatrixMut<V1, El<R>>, rhs: SubmatrixMut<V2, El<R>>, mut out: SubmatrixMut<V3, El<R>>, allocator: A) -> SolveResult
12    where R: RingStore,
13        R::Type: FreeAlgebra,
14        <<R::Type as RingExtension>::BaseRing as RingStore>::Type: LinSolveRing,
15        V1: AsPointerToSlice<El<R>>,
16        V2: AsPointerToSlice<El<R>>,
17        V3: AsPointerToSlice<El<R>>,
18        A: Allocator
19{
20    assert_eq!(lhs.row_count(), rhs.row_count());
21    assert_eq!(lhs.col_count(), out.row_count());
22    assert_eq!(rhs.col_count(), out.col_count());
23
24    let mut expanded_lhs = OwnedMatrix::zero_in(lhs.row_count() * ring.rank(), lhs.col_count() * ring.rank(), ring.base_ring(), &allocator);
25    let mut current;
26    let g = ring.canonical_gen();
27    for i in 0..lhs.row_count() {
28        for j in 0..lhs.col_count() {
29            current = ring.clone_el(lhs.at(i, j));
30            for l in 0..ring.rank() {
31                let current_wrt_basis = ring.wrt_canonical_basis(&current);
32                for k in 0..ring.rank() {
33                    *expanded_lhs.at_mut(i * ring.rank() + k, j * ring.rank() + l) = current_wrt_basis.at(k);
34                }
35                drop(current_wrt_basis);
36                ring.mul_assign_ref(&mut current, &g);
37            }
38        }
39    }
40
41    let mut expanded_rhs = OwnedMatrix::zero_in(rhs.row_count() * ring.rank(), rhs.col_count(), ring.base_ring(), &allocator);
42    for i in 0..rhs.row_count() {
43        for j in 0..rhs.col_count() {
44            let value_wrt_basis = ring.wrt_canonical_basis(rhs.at(i, j));
45            for k in 0..ring.rank() {
46                *expanded_rhs.at_mut(i * ring.rank() + k, j) = value_wrt_basis.at(k);
47            }
48        }
49    }
50
51    let mut solution = OwnedMatrix::zero_in(lhs.col_count() * ring.rank(), rhs.col_count(), ring.base_ring(), &allocator);
52    let sol = ring.base_ring().get_ring().solve_right(expanded_lhs.data_mut(), expanded_rhs.data_mut(), solution.data_mut(), &allocator);
53
54    if !sol.is_solved() {
55        return sol;
56    }
57
58    for i in 0..lhs.col_count() {
59        for j in 0..rhs.col_count() {
60            let res_value = ring.from_canonical_basis((0..ring.rank()).map(|k| ring.base_ring().clone_el(solution.at(i * ring.rank() + k, j))));
61            *out.at_mut(i, j) = res_value;
62        }
63    }
64
65    return sol;
66}
67
68#[cfg(test)]
69use std::alloc::Global;
70#[cfg(test)]
71use crate::algorithms::matmul::{MatmulAlgorithm, STANDARD_MATMUL};
72#[cfg(test)]
73use crate::rings::extension::extension_impl::FreeAlgebraImpl;
74#[cfg(test)]
75use crate::rings::zn::zn_static;
76#[cfg(test)]
77use crate::assert_matrix_eq;
78
79#[test]
80fn test_solve() {
81    let base_ring = zn_static::Zn::<15>::RING;
82    // Z_15[X]/(X^3 + X^2 + 1);  X^3 + X^2 + 1 = (X + 2)(X + 2X + 2) mod 3, but it is irreducible mod 5
83    let ring = FreeAlgebraImpl::new(base_ring, 3, [14, 0, 14]);
84    let el = |coeffs: [u64; 3]| ring.from_canonical_basis(coeffs);
85
86    let data_A = [
87        DerefArray::from([ el([1, 0, 0]), el([0, 0, 0]) ]),
88        DerefArray::from([ el([2, 1, 0]), el([0, 0, 0]) ]),
89        DerefArray::from([ el([0, 0, 0]), el([0, 1, 0]) ]),
90    ];
91    let data_B = [
92        DerefArray::from([ el([10, 10, 5]) ]),
93        DerefArray::from([ el([0, 0, 0]) ]),
94        DerefArray::from([ el([1, 0, 0]) ]),
95    ];
96    let mut A = OwnedMatrix::from_fn_in(3, 2, |i, j| ring.clone_el(&data_A[i][j]), Global);
97    let mut B = OwnedMatrix::from_fn_in(3, 1, |i, j| ring.clone_el(&data_B[i][j]), Global);
98    let mut sol: OwnedMatrix<_> = OwnedMatrix::zero(2, 1, &ring);
99
100    solve_right_over_extension(&ring, A.data_mut(), B.data_mut(), sol.data_mut(), Global).assert_solved();
101
102    let A = OwnedMatrix::from_fn_in(3, 2, |i, j| ring.clone_el(&data_A[i][j]), Global);
103    let B = OwnedMatrix::from_fn_in(3, 1, |i, j| ring.clone_el(&data_B[i][j]), Global);
104    let mut prod: OwnedMatrix<_> = OwnedMatrix::zero(3, 1, &ring);
105    STANDARD_MATMUL.matmul(TransposableSubmatrix::from(A.data()), TransposableSubmatrix::from(sol.data()), TransposableSubmatrixMut::from(prod.data_mut()), ring);
106
107    assert_matrix_eq!(&ring, &B, &prod);
108
109    let data_B = [
110        DerefArray::from([ el([8, 8, 3]) ]),
111        DerefArray::from([ el([0, 0, 0]) ]),
112        DerefArray::from([ el([1, 0, 0]) ]),
113    ];
114    let mut A = OwnedMatrix::from_fn_in(3, 2, |i, j| ring.clone_el(&data_A[i][j]), Global);
115    let mut B = OwnedMatrix::from_fn_in(3, 1, |i, j| ring.clone_el(&data_B[i][j]), Global);
116    let mut sol: OwnedMatrix<_> = OwnedMatrix::zero(2, 1, &ring);
117    assert!(!solve_right_over_extension(&ring, A.data_mut(), B.data_mut(), sol.data_mut(), Global).is_solved());
118}
119
120#[test]
121fn test_invert() {
122    let base_ring = zn_static::Zn::<15>::RING;
123    // Z_15[X]/(X^3 + X^2 + 1);  X^3 + X^2 + 1 = (X + 2)(X + 2X + 2) mod 3, but it is irreducible mod 5
124    let ring = FreeAlgebraImpl::new(base_ring, 3, [14, 0, 14]);
125
126    let matrix = OwnedMatrix::from_fn(2, 2, |i, j| if i == 0 || j == 0 {
127        ring.one()
128    } else {
129        ring.sub(ring.canonical_gen(), ring.one())
130    });
131    let mut inverse = OwnedMatrix::zero(2, 2, &ring);
132    solve_right_over_extension(&ring, matrix.clone_matrix(&ring).data_mut(), OwnedMatrix::identity(2, 2, &ring).data_mut(), inverse.data_mut(), Global).assert_solved();
133
134    let mut result = OwnedMatrix::zero(2, 2, &ring);
135    STANDARD_MATMUL.matmul(TransposableSubmatrix::from(matrix.data()), TransposableSubmatrix::from(inverse.data()), TransposableSubmatrixMut::from(result.data_mut()), &ring);
136
137    assert_matrix_eq!(&ring, OwnedMatrix::identity(2, 2, &ring), result);
138}