feanor_math/algorithms/linsolve/
extension.rs1use std::alloc::Allocator;
2
3use crate::matrix::*;
4use crate::rings::extension::{FreeAlgebra, FreeAlgebraStore};
5use crate::ring::*;
6use crate::seq::*;
7
8use super::{LinSolveRing, SolveResult};
9
10#[stability::unstable(feature = "enable")]
11pub fn solve_right_over_extension<R, V1, V2, V3, A>(ring: R, lhs: SubmatrixMut<V1, El<R>>, rhs: SubmatrixMut<V2, El<R>>, mut out: SubmatrixMut<V3, El<R>>, allocator: A) -> SolveResult
12 where R: RingStore,
13 R::Type: FreeAlgebra,
14 <<R::Type as RingExtension>::BaseRing as RingStore>::Type: LinSolveRing,
15 V1: AsPointerToSlice<El<R>>,
16 V2: AsPointerToSlice<El<R>>,
17 V3: AsPointerToSlice<El<R>>,
18 A: Allocator
19{
20 assert_eq!(lhs.row_count(), rhs.row_count());
21 assert_eq!(lhs.col_count(), out.row_count());
22 assert_eq!(rhs.col_count(), out.col_count());
23
24 let mut expanded_lhs = OwnedMatrix::zero_in(lhs.row_count() * ring.rank(), lhs.col_count() * ring.rank(), ring.base_ring(), &allocator);
25 let mut current;
26 let g = ring.canonical_gen();
27 for i in 0..lhs.row_count() {
28 for j in 0..lhs.col_count() {
29 current = ring.clone_el(lhs.at(i, j));
30 for l in 0..ring.rank() {
31 let current_wrt_basis = ring.wrt_canonical_basis(¤t);
32 for k in 0..ring.rank() {
33 *expanded_lhs.at_mut(i * ring.rank() + k, j * ring.rank() + l) = current_wrt_basis.at(k);
34 }
35 drop(current_wrt_basis);
36 ring.mul_assign_ref(&mut current, &g);
37 }
38 }
39 }
40
41 let mut expanded_rhs = OwnedMatrix::zero_in(rhs.row_count() * ring.rank(), rhs.col_count(), ring.base_ring(), &allocator);
42 for i in 0..rhs.row_count() {
43 for j in 0..rhs.col_count() {
44 let value_wrt_basis = ring.wrt_canonical_basis(rhs.at(i, j));
45 for k in 0..ring.rank() {
46 *expanded_rhs.at_mut(i * ring.rank() + k, j) = value_wrt_basis.at(k);
47 }
48 }
49 }
50
51 let mut solution = OwnedMatrix::zero_in(lhs.col_count() * ring.rank(), rhs.col_count(), ring.base_ring(), &allocator);
52 let sol = ring.base_ring().get_ring().solve_right(expanded_lhs.data_mut(), expanded_rhs.data_mut(), solution.data_mut(), &allocator);
53
54 if !sol.is_solved() {
55 return sol;
56 }
57
58 for i in 0..lhs.col_count() {
59 for j in 0..rhs.col_count() {
60 let res_value = ring.from_canonical_basis((0..ring.rank()).map(|k| ring.base_ring().clone_el(solution.at(i * ring.rank() + k, j))));
61 *out.at_mut(i, j) = res_value;
62 }
63 }
64
65 return sol;
66}
67
68#[cfg(test)]
69use std::alloc::Global;
70#[cfg(test)]
71use crate::algorithms::matmul::{MatmulAlgorithm, STANDARD_MATMUL};
72#[cfg(test)]
73use crate::rings::extension::extension_impl::FreeAlgebraImpl;
74#[cfg(test)]
75use crate::rings::zn::zn_static;
76#[cfg(test)]
77use crate::assert_matrix_eq;
78
79#[test]
80fn test_solve() {
81 let base_ring = zn_static::Zn::<15>::RING;
82 let ring = FreeAlgebraImpl::new(base_ring, 3, [14, 0, 14]);
84 let el = |coeffs: [u64; 3]| ring.from_canonical_basis(coeffs);
85
86 let data_A = [
87 DerefArray::from([ el([1, 0, 0]), el([0, 0, 0]) ]),
88 DerefArray::from([ el([2, 1, 0]), el([0, 0, 0]) ]),
89 DerefArray::from([ el([0, 0, 0]), el([0, 1, 0]) ]),
90 ];
91 let data_B = [
92 DerefArray::from([ el([10, 10, 5]) ]),
93 DerefArray::from([ el([0, 0, 0]) ]),
94 DerefArray::from([ el([1, 0, 0]) ]),
95 ];
96 let mut A = OwnedMatrix::from_fn_in(3, 2, |i, j| ring.clone_el(&data_A[i][j]), Global);
97 let mut B = OwnedMatrix::from_fn_in(3, 1, |i, j| ring.clone_el(&data_B[i][j]), Global);
98 let mut sol: OwnedMatrix<_> = OwnedMatrix::zero(2, 1, &ring);
99
100 solve_right_over_extension(&ring, A.data_mut(), B.data_mut(), sol.data_mut(), Global).assert_solved();
101
102 let A = OwnedMatrix::from_fn_in(3, 2, |i, j| ring.clone_el(&data_A[i][j]), Global);
103 let B = OwnedMatrix::from_fn_in(3, 1, |i, j| ring.clone_el(&data_B[i][j]), Global);
104 let mut prod: OwnedMatrix<_> = OwnedMatrix::zero(3, 1, &ring);
105 STANDARD_MATMUL.matmul(TransposableSubmatrix::from(A.data()), TransposableSubmatrix::from(sol.data()), TransposableSubmatrixMut::from(prod.data_mut()), ring);
106
107 assert_matrix_eq!(&ring, &B, &prod);
108
109 let data_B = [
110 DerefArray::from([ el([8, 8, 3]) ]),
111 DerefArray::from([ el([0, 0, 0]) ]),
112 DerefArray::from([ el([1, 0, 0]) ]),
113 ];
114 let mut A = OwnedMatrix::from_fn_in(3, 2, |i, j| ring.clone_el(&data_A[i][j]), Global);
115 let mut B = OwnedMatrix::from_fn_in(3, 1, |i, j| ring.clone_el(&data_B[i][j]), Global);
116 let mut sol: OwnedMatrix<_> = OwnedMatrix::zero(2, 1, &ring);
117 assert!(!solve_right_over_extension(&ring, A.data_mut(), B.data_mut(), sol.data_mut(), Global).is_solved());
118}
119
120#[test]
121fn test_invert() {
122 let base_ring = zn_static::Zn::<15>::RING;
123 let ring = FreeAlgebraImpl::new(base_ring, 3, [14, 0, 14]);
125
126 let matrix = OwnedMatrix::from_fn(2, 2, |i, j| if i == 0 || j == 0 {
127 ring.one()
128 } else {
129 ring.sub(ring.canonical_gen(), ring.one())
130 });
131 let mut inverse = OwnedMatrix::zero(2, 2, &ring);
132 solve_right_over_extension(&ring, matrix.clone_matrix(&ring).data_mut(), OwnedMatrix::identity(2, 2, &ring).data_mut(), inverse.data_mut(), Global).assert_solved();
133
134 let mut result = OwnedMatrix::zero(2, 2, &ring);
135 STANDARD_MATMUL.matmul(TransposableSubmatrix::from(matrix.data()), TransposableSubmatrix::from(inverse.data()), TransposableSubmatrixMut::from(result.data_mut()), &ring);
136
137 assert_matrix_eq!(&ring, OwnedMatrix::identity(2, 2, &ring), result);
138}