cygv/
semigroup.rs

//! Affine semigroups that generate SCRP cones.
//!
//! This module contains functions to construct (truncated) affine semigroups
//! that generate strongly-convex rational polyhedral (SCRP) cones. In other
//! words, the semigroups are of the form $S_\sigma=\sigma\cap\mathbb{Z}^n$ for
//! some $n\in\mathbb{Z}_{>0}$ and a strongly-convex rational polyhedral cone
//! $\sigma\subset\mathbb{R}^n$.
8
9pub mod error;
10
11use core::cmp::Ordering;
12use error::SemigroupError;
13use itertools::Itertools;
14use nalgebra::{DMatrix, DVector, RowDVector};
15use std::collections::HashSet;
16
17/// A structure for an affine truncated semigroup.
18///
19/// The affine semigroup needs a grading vector that results in a
20/// positive-definite grading, which equivalently means that the semigroup
21/// generates a SCRP cone and the grading vector is in the dual cone.
22///
23/// Functions that use this structure assume that the elements are sorted by
24/// degree and that the data is consistent.
25#[derive(Clone, Debug, PartialEq)]
26pub struct Semigroup {
27    pub elements: DMatrix<i32>,
28    pub grading_vector: RowDVector<i32>,
29    pub degrees: RowDVector<u32>,
30    pub max_degree: u32,
31}
32
33impl Semigroup {
34    /// Constructs a semigroup from given data while only performing essential
35    /// checks.
36    ///
37    /// The matrix of elements must be in column-major format.
38    pub fn from_data(
39        mut elements: DMatrix<i32>,
40        grading_vector: RowDVector<i32>,
41    ) -> Result<Self, SemigroupError> {
42        let degrees = sort_elements(&mut elements, &grading_vector);
43        let degrees = check_final_degrees(degrees)?;
44        if elements.column(0).iter().any(|x| *x != 0) {
45            return Err(SemigroupError::MissingIdentityError);
46        }
47
48        let max_degree = degrees[degrees.len() - 1];
49
50        Ok(Self {
51            elements,
52            grading_vector,
53            degrees,
54            max_degree,
55        })
56    }
57
58    /// Constructs a semigroup given a list of elements containing the
59    /// generators, a grading vector, and a maximum degree.
60    pub fn with_max_degree(
61        elements: DMatrix<i32>,
62        grading_vector: RowDVector<i32>,
63        max_degree: u32,
64    ) -> Result<Self, SemigroupError> {
65        // Make sure that the input elements are valid.
66        check_degrees(&elements, &grading_vector)?;
67
68        // Remove elements that exceed the maximum degree from the start.
69        let elements = trim_by_max_deg(&elements, &grading_vector, max_degree);
70
71        let generators = find_generators(&elements);
72
73        let mut elements_set = HashSet::new();
74        for c in elements.column_iter() {
75            let tmp_vec = DVector::from_column_slice(c.as_slice());
76            elements_set.insert(tmp_vec);
77        }
78
79        let mut starting_elements = HashSet::new();
80        for c in elements.column_iter() {
81            let tmp_vec = DVector::from_column_slice(c.as_slice());
82            starting_elements.insert(tmp_vec);
83        }
84        drop(elements);
85
86        // TODO: This part might be easy to parallelize, so it's worth checking out.
87        loop {
88            let new_elements = find_new_elements_until_max_deg(
89                &generators,
90                &starting_elements,
91                &elements_set,
92                &grading_vector,
93                max_degree,
94            );
95            if new_elements.is_empty() {
96                break;
97            }
98            for c in new_elements.iter() {
99                elements_set.insert(c.clone());
100            }
101            starting_elements = new_elements;
102        }
103        // Make sure that the zero vector is in the set of elements.
104        elements_set.insert(DVector::zeros(grading_vector.len()));
105
106        let mut elements = DMatrix::zeros(grading_vector.len(), elements_set.len());
107        elements
108            .column_iter_mut()
109            .zip(elements_set)
110            .for_each(|(mut d, s)| d.copy_from(&s));
111
112        Self::from_data(elements, grading_vector)
113    }
114
115    /// Constructs a semigroup by increasing the maximum degree until the minimum number of elements is achieved.
116    pub fn with_min_elements(
117        elements: DMatrix<i32>,
118        grading_vector: RowDVector<i32>,
119        min_elements: usize,
120    ) -> Result<Self, SemigroupError> {
121        // Make sure that the input elements are valid.
122        check_degrees(&elements, &grading_vector)?;
123
124        let generators = find_generators(&elements);
125
126        let mut elements_set = HashSet::new();
127        elements_set.insert(DVector::zeros(grading_vector.len()));
128
129        drop(elements);
130
131        let mut max_degree = 0;
132        while elements_set.len() < min_elements {
133            max_degree += 1;
134            loop {
135                let new_elements = find_new_elements_until_max_deg(
136                    &generators,
137                    &elements_set,
138                    &elements_set,
139                    &grading_vector,
140                    max_degree,
141                );
142                if new_elements.is_empty() {
143                    break;
144                }
145                for c in new_elements.iter() {
146                    elements_set.insert(c.clone());
147                }
148            }
149        }
150
151        let mut elements = DMatrix::zeros(grading_vector.len(), elements_set.len());
152        elements
153            .column_iter_mut()
154            .zip(elements_set)
155            .for_each(|(mut d, s)| d.copy_from(&s));
156
157        Self::from_data(elements, grading_vector)
158    }
159}
160
161/// Sort the elements by degree and make sure that only the identity has degree
162/// zero.
163fn sort_elements(elements: &mut DMatrix<i32>, grading_vector: &RowDVector<i32>) -> RowDVector<i32> {
164    let signed_degrees = grading_vector * &*elements;
165
166    // TODO: figure out if there's a way to sort in place.
167
168    let mut degs_vecs: Vec<(i32, DVector<i32>)> = signed_degrees
169        .iter()
170        .cloned()
171        .zip(elements.column_iter())
172        .map(|(d, v)| (d, v.clone_owned()))
173        .collect();
174
175    degs_vecs.sort_unstable_by_key(|k| k.0);
176
177    let mut degrees = RowDVector::<i32>::zeros(elements.ncols());
178    degrees
179        .iter_mut()
180        .zip(elements.column_iter_mut())
181        .zip(degs_vecs)
182        .for_each(|((d, mut v), d_v)| {
183            *d = d_v.0;
184            v.copy_from(&d_v.1);
185        });
186
187    degrees
188}
189
190/// Check that the degrees are positive, except for the first one, which must be
191/// zero.
192fn check_final_degrees(degrees: RowDVector<i32>) -> Result<RowDVector<u32>, SemigroupError> {
193    let mut final_degrees = RowDVector::<u32>::zeros(degrees.len());
194
195    let mut degs_iter = degrees.iter();
196    let Some(zero_deg) = degs_iter.next() else {
197        return Err(SemigroupError::MissingIdentityError);
198    };
199    match zero_deg.cmp(&0) {
200        Ordering::Less => return Err(SemigroupError::NonPositiveDegreeError),
201        Ordering::Greater => return Err(SemigroupError::MissingIdentityError),
202        _ => (),
203    }
204
205    for (d, s) in final_degrees.iter_mut().skip(1).zip(degs_iter) {
206        if *s <= 0 {
207            return Err(SemigroupError::NonPositiveDegreeError);
208        }
209        *d = *s as u32
210    }
211
212    Ok(final_degrees)
213}
214
215fn check_degrees(
216    elements: &DMatrix<i32>,
217    grading_vector: &RowDVector<i32>,
218) -> Result<(), SemigroupError> {
219    let signed_degrees = grading_vector * elements;
220    for (d, c) in signed_degrees.into_iter().zip(elements.column_iter()) {
221        if *d < 0 || (*d == 0 && c.iter().any(|x| *x != 0)) {
222            return Err(SemigroupError::NonPositiveDegreeError);
223        }
224    }
225    Ok(())
226}
227
228/// Returns only the elements with degree up to the specified maximum degree.
229fn trim_by_max_deg(
230    elements: &DMatrix<i32>,
231    grading_vector: &RowDVector<i32>,
232    max_degree: u32,
233) -> DMatrix<i32> {
234    let selected: Vec<_> = elements
235        .column_iter()
236        .enumerate()
237        .filter(|(_, c)| (grading_vector * c)[(0, 0)] as u32 <= max_degree)
238        .map(|(i, _)| i)
239        .collect();
240    let mut trimmed = DMatrix::<i32>::zeros(elements.nrows(), selected.len());
241    trimmed
242        .column_iter_mut()
243        .zip(selected)
244        .for_each(|(mut c, i)| c.copy_from(&elements.column(i)));
245    trimmed
246}
247
248/// Tries to find a smaller subset of elements that generates the semigroup. It
249/// does not necessarily return the minimal set of generators since doing so is
250/// very difficult.
251fn find_generators(elements: &DMatrix<i32>) -> DMatrix<i32> {
252    // TODO: Need to check if it is worth to do this in parallel.
253
254    // TODO: Need to check if this is reasonable. For the original code it was
255    // 5, but that was probably too high.
256    let max_sum_elements = 3;
257
258    let dim = elements.nrows();
259    let zero_vec = DVector::<i32>::zeros(dim);
260    let mut tmp_vec = zero_vec.clone();
261
262    let mut generators: HashSet<_> = elements.column_iter().collect();
263    generators.remove(&zero_vec.as_view());
264
265    let mut to_remove = HashSet::new();
266
267    for n in 2..max_sum_elements {
268        for v in generators.iter().combinations_with_replacement(n) {
269            tmp_vec.copy_from(&zero_vec);
270            for c in v.into_iter() {
271                tmp_vec += *c;
272            }
273            let view = tmp_vec.column(0);
274            if generators.contains(&view) {
275                to_remove.insert(tmp_vec.clone());
276            }
277        }
278    }
279
280    for c in to_remove.iter() {
281        generators.remove(&c.as_view());
282    }
283
284    let mut generators_mat = DMatrix::zeros(dim, generators.len());
285    generators
286        .into_iter()
287        .zip(generators_mat.column_iter_mut())
288        .for_each(|(s, mut d)| d.copy_from(&s));
289    generators_mat
290}
291
292/// Find new elements up to a maximum degree using the generators and the starting elements.
293fn find_new_elements_until_max_deg(
294    generators: &DMatrix<i32>,
295    starting_elements: &HashSet<DVector<i32>>,
296    elements_set: &HashSet<DVector<i32>>,
297    grading_vector: &RowDVector<i32>,
298    max_degree: u32,
299) -> HashSet<DVector<i32>> {
300    let mut new_elements = HashSet::new();
301    let mut tmp_vec = DVector::zeros(generators.nrows());
302    for c1 in generators.column_iter() {
303        for c2 in starting_elements.iter() {
304            tmp_vec.copy_from(c2);
305            tmp_vec += c1;
306            let deg = grading_vector.tr_dot(&tmp_vec) as u32;
307            if deg <= max_degree && !elements_set.contains(&tmp_vec) {
308                new_elements.insert(tmp_vec.clone());
309            }
310        }
311    }
312    new_elements
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// The elements of degree at most 2 in the first quadrant of Z^2, together
    /// with the grading vector (1, 1).
    fn example_elements_and_grading_vector() -> (DMatrix<i32>, RowDVector<i32>) {
        #[rustfmt::skip]
        let elements = DMatrix::from_column_slice(2, 6,
            &[
                0, 0,
                1, 0,
                0, 1,
                2, 0,
                1, 1,
                0, 2
            ]
        );
        (elements, RowDVector::from_row_slice(&[1, 1]))
    }

    #[test]
    fn test_semigroup_from_data() {
        let (elements, grading_vector) = example_elements_and_grading_vector();

        let sg = Semigroup::from_data(elements.clone(), grading_vector)
            .expect("example data should form a valid semigroup");
        assert_eq!(sg.degrees, RowDVector::from_row_slice(&[0, 1, 1, 2, 2, 2]));

        // With this grading, (0, 1) and (0, 2) get negative degrees.
        let bad_grading_vector = RowDVector::from_row_slice(&[1, -1]);
        let err = Semigroup::from_data(elements, bad_grading_vector)
            .expect_err("negative degrees must be rejected");
        assert_eq!(err, SemigroupError::NonPositiveDegreeError);
    }

    #[test]
    fn test_semigroup_with_max_degree() {
        let (elements, grading_vector) = example_elements_and_grading_vector();

        let check = |max_degree: u32, expected_degrees: &[u32]| {
            let sg =
                Semigroup::with_max_degree(elements.clone(), grading_vector.clone(), max_degree)
                    .expect("example data should form a valid semigroup");
            assert_eq!(sg.degrees.len(), expected_degrees.len());
            assert_eq!(sg.degrees, RowDVector::from_row_slice(expected_degrees));
        };

        check(3, &[0, 1, 1, 2, 2, 2, 3, 3, 3, 3]);
        check(1, &[0, 1, 1]);
    }

    #[test]
    fn test_semigroup_with_min_elements() {
        let (elements, grading_vector) = example_elements_and_grading_vector();

        let check = |min_elements: usize, expected_degrees: &[u32]| {
            let sg = Semigroup::with_min_elements(
                elements.clone(),
                grading_vector.clone(),
                min_elements,
            )
            .expect("example data should form a valid semigroup");
            assert_eq!(sg.degrees.len(), expected_degrees.len());
            assert_eq!(sg.degrees, RowDVector::from_row_slice(expected_degrees));
        };

        check(11, &[0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]);
        check(3, &[0, 1, 1]);
    }

    #[test]
    fn test_find_generators() {
        let (elements, _) = example_elements_and_grading_vector();
        // (2, 0), (1, 1) and (0, 2) are sums of (1, 0) and (0, 1), and the zero
        // vector is dropped, so exactly two generators remain.
        assert_eq!(find_generators(&elements).ncols(), 2);
    }
}