Skip to main content

oxilean_std/fin/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4use super::functions::*;
5
6/// A function from `Fin n` to some type `T`.
7#[derive(Debug, Clone)]
8pub struct FinFun<T> {
9    values: Vec<T>,
10}
11impl<T: Clone> FinFun<T> {
12    /// Create a `FinFun` from a vector of length `n`.
13    pub fn from_vec(values: Vec<T>) -> Self {
14        FinFun { values }
15    }
16    /// Create a constant function always returning `val`.
17    pub fn constant(bound: usize, val: T) -> Self {
18        FinFun {
19            values: vec![val; bound],
20        }
21    }
22    /// Create a function from a closure.
23    pub fn from_fn(bound: usize, f: impl Fn(Fin) -> T) -> Self {
24        let values = (0..bound).map(|i| f(Fin { val: i, bound })).collect();
25        FinFun { values }
26    }
27    /// Apply the function to a `Fin` element.
28    pub fn apply(&self, i: Fin) -> Option<&T> {
29        self.values.get(i.val)
30    }
31    /// The domain size.
32    pub fn bound(&self) -> usize {
33        self.values.len()
34    }
35    /// Iterate over (Fin i, &T) pairs.
36    pub fn iter(&self) -> impl Iterator<Item = (Fin, &T)> {
37        let bound = self.values.len();
38        self.values
39            .iter()
40            .enumerate()
41            .map(move |(i, v)| (Fin { val: i, bound }, v))
42    }
43}
44/// A permutation of `{0, ..., n-1}` stored as a vector.
45#[allow(dead_code)]
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub struct FinExtPerm {
48    /// The permutation as a map: index i → perm[i].
49    pub perm: Vec<usize>,
50}
51#[allow(dead_code)]
52impl FinExtPerm {
53    /// Create the identity permutation on n elements.
54    pub fn identity(n: usize) -> Self {
55        FinExtPerm {
56            perm: (0..n).collect(),
57        }
58    }
59    /// Create a permutation from a vector. Returns None if not a valid permutation.
60    pub fn from_vec(v: Vec<usize>) -> Option<Self> {
61        let n = v.len();
62        let mut seen = vec![false; n];
63        for &x in &v {
64            if x >= n || seen[x] {
65                return None;
66            }
67            seen[x] = true;
68        }
69        Some(FinExtPerm { perm: v })
70    }
71    /// The size of the permutation.
72    pub fn len(&self) -> usize {
73        self.perm.len()
74    }
75    /// Returns true if the permutation is empty.
76    pub fn is_empty(&self) -> bool {
77        self.perm.is_empty()
78    }
79    /// Apply the permutation to index i.
80    pub fn apply(&self, i: usize) -> Option<usize> {
81        self.perm.get(i).copied()
82    }
83    /// Compose self ∘ other (apply other first, then self).
84    pub fn compose(&self, other: &Self) -> Option<Self> {
85        if self.len() != other.len() {
86            return None;
87        }
88        let perm = (0..self.len()).map(|i| self.perm[other.perm[i]]).collect();
89        Some(FinExtPerm { perm })
90    }
91    /// Compute the inverse permutation.
92    pub fn inverse(&self) -> Self {
93        let mut inv = vec![0usize; self.len()];
94        for (i, &j) in self.perm.iter().enumerate() {
95            inv[j] = i;
96        }
97        FinExtPerm { perm: inv }
98    }
99    /// Compute the sign (parity) of the permutation: +1 for even, -1 for odd.
100    pub fn sign(&self) -> i32 {
101        let n = self.len();
102        let mut visited = vec![false; n];
103        let mut sign = 1i32;
104        for i in 0..n {
105            if !visited[i] {
106                let mut cycle_len = 0;
107                let mut j = i;
108                while !visited[j] {
109                    visited[j] = true;
110                    j = self.perm[j];
111                    cycle_len += 1;
112                }
113                if cycle_len % 2 == 0 {
114                    sign = -sign;
115                }
116            }
117        }
118        sign
119    }
120    /// Count the number of cycles.
121    pub fn cycle_count(&self) -> usize {
122        let n = self.len();
123        let mut visited = vec![false; n];
124        let mut count = 0;
125        for i in 0..n {
126            if !visited[i] {
127                count += 1;
128                let mut j = i;
129                while !visited[j] {
130                    visited[j] = true;
131                    j = self.perm[j];
132                }
133            }
134        }
135        count
136    }
137    /// Check if this permutation is a derangement (no fixed points).
138    pub fn is_derangement(&self) -> bool {
139        self.perm.iter().enumerate().all(|(i, &p)| p != i)
140    }
141    /// Return the order of the permutation (smallest k > 0 with σ^k = id).
142    pub fn order(&self) -> usize {
143        let n = self.len();
144        if n == 0 {
145            return 1;
146        }
147        let mut visited = vec![false; n];
148        let mut lcm = 1usize;
149        for i in 0..n {
150            if !visited[i] {
151                let mut cycle_len = 0;
152                let mut j = i;
153                while !visited[j] {
154                    visited[j] = true;
155                    j = self.perm[j];
156                    cycle_len += 1;
157                }
158                lcm = lcm_usize(lcm, cycle_len);
159            }
160        }
161        lcm
162    }
163}
/// Row-major bijection between pairs in `Fin m × Fin n` and indices of `Fin (m * n)`.
///
/// The pair `(i, j)` corresponds to the flat index `i * n + j`.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub struct FinExtProduct {
    /// The left bound m (number of rows).
    pub m: usize,
    /// The right bound n (number of columns).
    pub n: usize,
}
#[allow(dead_code)]
impl FinExtProduct {
    /// Build the product of `Fin m` and `Fin n`.
    pub fn new(m: usize, n: usize) -> Self {
        Self { m, n }
    }
    /// Number of pairs, `m * n`.
    pub fn size(&self) -> usize {
        let FinExtProduct { m, n } = *self;
        m * n
    }
    /// Map `(i, j)` to its flat index `i * n + j`.
    ///
    /// Returns `None` when `i >= m` or `j >= n`.
    pub fn encode(&self, i: usize, j: usize) -> Option<usize> {
        let in_range = i < self.m && j < self.n;
        in_range.then(|| i * self.n + j)
    }
    /// Recover `(i, j)` from a flat index.
    ///
    /// Returns `None` when `k >= m * n`, and always when `n == 0`.
    pub fn decode(&self, k: usize) -> Option<(usize, usize)> {
        match self.n {
            0 => None,
            cols if k < self.size() => Some((k / cols, k % cols)),
            _ => None,
        }
    }
    /// Sum `f(i, j)` over every pair, visiting pairs in row-major order.
    pub fn sum_over<F: Fn(usize, usize) -> u64>(&self, f: F) -> u64 {
        let cols = self.n;
        (0..self.m)
            .flat_map(|i| (0..cols).map(move |j| (i, j)))
            .map(|(i, j)| f(i, j))
            .sum()
    }
    /// Sum `f(i, j)` along row `i`, i.e. over every column `j`.
    pub fn row_sum<F: Fn(usize, usize) -> u64>(&self, i: usize, f: F) -> u64 {
        (0..self.n).fold(0u64, |acc, j| acc + f(i, j))
    }
    /// Sum `f(i, j)` down column `j`, i.e. over every row `i`.
    pub fn col_sum<F: Fn(usize, usize) -> u64>(&self, j: usize, f: F) -> u64 {
        (0..self.m).fold(0u64, |acc, i| acc + f(i, j))
    }
}
/// A bounded integer value in `[0, n)`. Host-side representation of `Fin n`.
///
/// The constructors here only produce values with `val < bound`. Because the
/// fields are public, a hand-built value may violate that invariant. The
/// wrapping and modular methods assume it holds and that `bound > 0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Fin {
    /// The numeric value.
    pub val: usize,
    /// The (exclusive) upper bound.
    pub bound: usize,
}
impl Fin {
    /// Create a new `Fin` if `val < bound`, otherwise `None`.
    pub fn new(val: usize, bound: usize) -> Option<Self> {
        if val < bound {
            Some(Fin { val, bound })
        } else {
            None
        }
    }
    /// Create a `Fin` representing zero for bound `n > 0`; `None` when `bound == 0`.
    pub fn zero(bound: usize) -> Option<Self> {
        if bound > 0 {
            Some(Fin { val: 0, bound })
        } else {
            None
        }
    }
    /// Create the last element (`bound - 1`) of `Fin bound`; `None` when `bound == 0`.
    pub fn last(bound: usize) -> Option<Self> {
        if bound > 0 {
            Some(Fin {
                val: bound - 1,
                bound,
            })
        } else {
            None
        }
    }
    /// Return the successor, wrapping modulo `bound`.
    pub fn succ_wrap(self) -> Self {
        Fin {
            val: (self.val + 1) % self.bound,
            bound: self.bound,
        }
    }
    /// Return the predecessor, wrapping modulo `bound`.
    pub fn pred_wrap(self) -> Self {
        Fin {
            val: if self.val == 0 {
                self.bound - 1
            } else {
                self.val - 1
            },
            bound: self.bound,
        }
    }
    /// Return the order-reversing complement `bound - 1 - val`.
    ///
    /// This maps `0 ↔ bound - 1`. It is *not* the additive inverse; that would be
    /// `(bound - val) % bound`, which is `Fin::zero(bound)` minus `self` via [`Fin::sub`].
    pub fn complement(self) -> Self {
        Fin {
            val: self.bound - 1 - self.val,
            bound: self.bound,
        }
    }
    /// Reduce a value computed in `u128` modulo `bound`.
    ///
    /// The result is `< bound <= usize::MAX`, so narrowing back to `usize` is lossless.
    fn reduce_wide(wide: u128, bound: usize) -> usize {
        (wide % bound as u128) as usize
    }
    /// Add two `Fin` values with the same bound (modular).
    ///
    /// Returns `None` if the bounds differ.
    #[allow(clippy::should_implement_trait)]
    pub fn add(self, other: Self) -> Option<Self> {
        if self.bound != other.bound {
            return None;
        }
        // Widen so `val + val` cannot overflow when `bound` is close to `usize::MAX`.
        let sum = self.val as u128 + other.val as u128;
        Some(Fin {
            val: Self::reduce_wide(sum, self.bound),
            bound: self.bound,
        })
    }
    /// Multiply two `Fin` values with the same bound (modular).
    ///
    /// Returns `None` if the bounds differ.
    #[allow(clippy::should_implement_trait)]
    pub fn mul(self, other: Self) -> Option<Self> {
        if self.bound != other.bound {
            return None;
        }
        // The product of two values below `bound` can need twice the bits of
        // `usize`. `u128` holds it for every usize width up to 64 bits.
        let product = self.val as u128 * other.val as u128;
        Some(Fin {
            val: Self::reduce_wide(product, self.bound),
            bound: self.bound,
        })
    }
    /// Subtract (modular). Returns `None` if bounds differ.
    #[allow(clippy::should_implement_trait)]
    pub fn sub(self, other: Self) -> Option<Self> {
        if self.bound != other.bound {
            return None;
        }
        // Adding `bound` first keeps the difference non-negative. Widening keeps
        // `val + bound` from overflowing `usize`.
        let shifted = self.val as u128 + self.bound as u128 - other.val as u128;
        Some(Fin {
            val: Self::reduce_wide(shifted, self.bound),
            bound: self.bound,
        })
    }
    /// Re-bound this value to `new_bound`, keeping `val` unchanged.
    ///
    /// Works for both widening and narrowing casts. Returns `None` when `val` does
    /// not fit, i.e. `val >= new_bound`.
    pub fn cast(self, new_bound: usize) -> Option<Self> {
        if self.val < new_bound {
            Some(Fin {
                val: self.val,
                bound: new_bound,
            })
        } else {
            None
        }
    }
    /// Return all elements of `Fin n` in order.
    pub fn all(bound: usize) -> Vec<Self> {
        (0..bound).map(|v| Fin { val: v, bound }).collect()
    }
    /// Return true if this is the zero element.
    pub fn is_zero(&self) -> bool {
        self.val == 0
    }
    /// Return true if this is the last element.
    pub fn is_last(&self) -> bool {
        self.val + 1 == self.bound
    }
    /// Embed into `usize`.
    pub fn as_usize(self) -> usize {
        self.val
    }
    /// Embed into a `u64`.
    pub fn as_u64(self) -> u64 {
        self.val as u64
    }
}
/// Young tableau shape: a partition of n stored as a non-increasing sequence.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FinExtYoungShape {
    /// The partition: parts[i] is the length of the i-th row.
    pub parts: Vec<usize>,
}
#[allow(dead_code)]
impl FinExtYoungShape {
    /// Create a Young shape from a partition.
    ///
    /// Returns `None` unless `parts` is non-increasing. Zero-length parts, which
    /// can only appear at the tail, are dropped.
    pub fn new(parts: Vec<usize>) -> Option<Self> {
        if parts.windows(2).all(|w| w[0] >= w[1]) {
            let parts: Vec<usize> = parts.into_iter().filter(|&x| x > 0).collect();
            Some(FinExtYoungShape { parts })
        } else {
            None
        }
    }
    /// The total number of cells (size of the partition).
    pub fn size(&self) -> usize {
        self.parts.iter().sum()
    }
    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.parts.len()
    }
    /// Length of row i (0-indexed); 0 for rows past the end.
    pub fn row_len(&self, i: usize) -> usize {
        self.parts.get(i).copied().unwrap_or(0)
    }
    /// The conjugate (transpose) partition: column lengths of this shape.
    pub fn conjugate(&self) -> Self {
        if self.parts.is_empty() {
            return FinExtYoungShape { parts: vec![] };
        }
        let max_col = self.parts[0];
        let conj_parts: Vec<usize> = (0..max_col)
            .map(|j| self.parts.iter().filter(|&&r| r > j).count())
            .collect();
        FinExtYoungShape { parts: conj_parts }
    }
    /// Check if this shape is a valid Young diagram (non-increasing rows).
    pub fn is_valid(&self) -> bool {
        self.parts.windows(2).all(|w| w[0] >= w[1])
    }
    /// Count standard Young tableaux using the hook length formula `n! / ∏ hook(c)`.
    ///
    /// `n!` itself overflows `u64` once `n > 20`, so it is never formed. Instead,
    /// each hook length is cancelled against the factors `1..=n` via gcd. Only the
    /// final count has to fit in `u64`. The result is meaningful only for a valid
    /// (non-increasing) shape. The empty shape has exactly one tableau.
    ///
    /// # Panics
    ///
    /// Panics if the number of tableaux does not fit in a `u64`.
    pub fn hook_length_count(&self) -> u64 {
        fn gcd(mut a: u64, mut b: u64) -> u64 {
            while b != 0 {
                let r = a % b;
                a = b;
                b = r;
            }
            a
        }
        let n = self.size();
        if n == 0 {
            return 1;
        }
        // The factors of n!; every hook length is divided out of these in place.
        let mut factors: Vec<u64> = (1..=n as u64).collect();
        for (i, &ri) in self.parts.iter().enumerate() {
            for j in 0..ri {
                let arm = ri - j - 1;
                let leg = self.parts[i + 1..].iter().filter(|&&r| r > j).count();
                let mut hook = (arm + leg + 1) as u64;
                // For a valid shape n!/∏hooks is an integer. So for every prime, the
                // remaining factors still carry enough multiplicity to absorb `hook`.
                for f in factors.iter_mut() {
                    if hook == 1 {
                        break;
                    }
                    let g = gcd(hook, *f);
                    *f /= g;
                    hook /= g;
                }
            }
        }
        factors
            .iter()
            .try_fold(1u64, |acc, &f| acc.checked_mul(f))
            .expect("number of standard Young tableaux overflows u64")
    }
}
407/// An iterator over all `Fin n` elements.
408pub struct FinIter {
409    pub(super) current: usize,
410    pub(super) bound: usize,
411}
412impl FinIter {
413    /// Create an iterator over `Fin n`.
414    pub fn new(bound: usize) -> Self {
415        FinIter { current: 0, bound }
416    }
417}