// qudit_core/function.rs

1use std::ops::Range;
2
3use bit_set::BitSet;
4use std::hash::{Hash, Hasher};
5
6use crate::RealScalar;
7
/// A data structure representing indices of parameters (e.g. for a function).
/// There are optimized methods for consecutive and disjoint parameters.
///
/// Two values compare equal when they yield the same index sequence, even across
/// variants (see the manual `PartialEq`/`Hash` impls for this type).
#[derive(Clone, Debug)]
pub enum ParamIndices {
    /// A consecutive run of parameters stored as `(start, length)`; it covers the
    /// indices `start..start + length` in ascending order.
    Joint(usize, usize), // start, length

    /// An explicit list of parameter indices (disjoint parameters), kept in stored order.
    Disjoint(Vec<usize>),
}
18
19impl PartialEq for ParamIndices {
20    fn eq(&self, other: &Self) -> bool {
21        if self.len() != other.len() {
22            return false;
23        }
24
25        match (self, other) {
26            // Same variants - delegate to inner equality
27            (ParamIndices::Joint(start1, len1), ParamIndices::Joint(start2, len2)) => {
28                start1 == start2 && len1 == len2
29            }
30            (ParamIndices::Disjoint(v1), ParamIndices::Disjoint(v2)) => v1 == v2,
31            // Cross variants - compare iterators element by element
32            _ => self.iter().eq(other.iter()),
33        }
34    }
35}
36
// Marker impl: promotes the manual `PartialEq` to `Eq` so `ParamIndices` (and types deriving
// `Eq` over it, such as `ParamInfo`) can be used as keys in hash-based collections.
impl Eq for ParamIndices {}
38
39impl Hash for ParamIndices {
40    fn hash<H: Hasher>(&self, state: &mut H) {
41        self.len().hash(state);
42        for index in self.iter() {
43            index.hash(state);
44        }
45    }
46}
47
48impl ParamIndices {
49    /// Converts the `ParamIndices` to a `BitSet`.
50    ///
51    /// This function creates a `BitSet` where each bit corresponds to a parameter index.
52    /// If a parameter index is present in the `ParamIndices`, the corresponding bit in the
53    /// `BitSet` is set to 1.
54    ///
55    /// # Returns
56    ///
57    /// A `BitSet` representing the parameter indices.
58    pub fn to_bitset(&self) -> BitSet {
59        let mut bitset = BitSet::new();
60        match self {
61            ParamIndices::Joint(start, length) => {
62                for i in 0..*length {
63                    bitset.insert(*start + i);
64                }
65            }
66            ParamIndices::Disjoint(indices) => {
67                for index in indices {
68                    bitset.insert(*index);
69                }
70            }
71        }
72        bitset
73    }
74
75    /// Creates a `ParamIndices` representing a constant value (no parameters).
76    ///
77    /// # Returns
78    ///
79    /// A `ParamIndices` representing an empty set of parameter indices.
80    pub fn empty() -> ParamIndices {
81        ParamIndices::Joint(0, 0)
82    }
83
84    /// Checks if the `ParamIndices` is empty (contains no parameters).
85    ///
86    /// # Returns
87    ///
88    /// `true` if the `ParamIndices` is empty, `false` otherwise.
89    pub fn is_empty(&self) -> bool {
90        match self {
91            ParamIndices::Joint(_, length) => *length == 0,
92            ParamIndices::Disjoint(v) => v.is_empty(),
93        }
94    }
95
96    /// Checks if the `ParamIndices` represents a consecutive range of parameters.
97    ///
98    /// # Returns
99    ///
100    /// `true` if the `ParamIndices` is consecutive, `false` otherwise.
101    pub fn is_consecutive(&self) -> bool {
102        match self {
103            ParamIndices::Joint(_, _) => true,
104            ParamIndices::Disjoint(v) => {
105                if v.len() <= 1 {
106                    return true;
107                }
108                let mut clone_v = v.clone();
109                clone_v.sort();
110                for i in 1..clone_v.len() {
111                    if clone_v[i] != clone_v[i - 1] + 1 {
112                        return false;
113                    }
114                }
115                true
116            }
117        }
118    }
119
120    /// Returns the number of parameters represented by the `ParamIndices`.
121    ///
122    /// # Returns
123    ///
124    /// The number of parameters.
125    ///
126    pub fn num_params(&self) -> usize {
127        match self {
128            ParamIndices::Joint(_, length) => *length,
129            ParamIndices::Disjoint(v) => v.len(),
130        }
131    }
132
133    /// Returns the starting index of the parameters.
134    ///
135    /// For `ParamIndices::Joint`, this is the index of the first parameter in the consecutive range.
136    /// For `ParamIndices::Disjoint`, this is the index of the first parameter in the vector, or 0 if the vector is empty.
137    ///
138    /// # Returns
139    ///
140    /// The starting index of the parameters.
141    pub fn start(&self) -> usize {
142        match self {
143            ParamIndices::Joint(start, _) => *start,
144            ParamIndices::Disjoint(v) => *v.first().unwrap_or(&0),
145        }
146    }
147
148    /// Concatenates two `ParamIndices` into a single `ParamIndices`.
149    ///
150    /// This function combines the parameter indices from two `ParamIndices` instances.
151    /// If the two `ParamIndices` represent overlapping ranges, they are merged into a single `ParamIndices::Joint` if possible.
152    /// Otherwise, the parameter indices are combined into a `ParamIndices::Disjoint`.
153    ///
154    /// # Arguments
155    ///
156    /// * `other` - The other `ParamIndices` to concatenate with.
157    ///
158    /// # Returns
159    ///
160    /// A new `ParamIndices` representing the concatenation of the two input `ParamIndices`.
161    pub fn union(&self, other: &ParamIndices) -> ParamIndices {
162        match (self, other) {
163            (ParamIndices::Joint(start1, length1), ParamIndices::Joint(start2, length2)) => {
164                if *start2 > *start1 && *start2 < *start1 + *length1 {
165                    ParamIndices::Joint(*start1, *start2 + *length2 - *start1)
166                } else if *start1 > *start2 && *start1 < *start2 + *length2 {
167                    ParamIndices::Joint(*start2, *start1 + *length1 - *start2)
168                } else {
169                    let mut indices = Vec::new();
170                    for i in *start1..*start1 + *length1 {
171                        indices.push(i);
172                    }
173                    for i in *start2..*start2 + *length2 {
174                        indices.push(i);
175                    }
176                    ParamIndices::Disjoint(indices)
177                }
178            }
179            (ParamIndices::Joint(start, length), ParamIndices::Disjoint(v)) => {
180                let mut indices = Vec::new();
181                for i in *start..*start + *length {
182                    if v.contains(&i) {
183                        continue;
184                    }
185                    indices.push(i);
186                }
187                indices.extend(v.iter());
188                ParamIndices::Disjoint(indices)
189            }
190            (ParamIndices::Disjoint(v), ParamIndices::Joint(start, length)) => {
191                let mut indices = Vec::new();
192                for i in *start..*start + *length {
193                    if v.contains(&i) {
194                        continue;
195                    }
196                    indices.push(i);
197                }
198                indices.extend(v.iter());
199                ParamIndices::Disjoint(indices)
200            }
201            (ParamIndices::Disjoint(v1), ParamIndices::Disjoint(v2)) => {
202                let mut indices = Vec::new();
203                for i in v1 {
204                    if v2.contains(i) {
205                        continue;
206                    }
207                    indices.push(*i);
208                }
209                indices.extend(v2.iter());
210                ParamIndices::Disjoint(indices)
211            }
212        }
213    }
214
215    /// Computes the intersection of two `ParamIndices`.
216    ///
217    /// This function returns a new `ParamIndices` containing only the parameter indices that are present in both input `ParamIndices`.
218    /// The result is always a `ParamIndices::Disjoint`.
219    ///
220    /// # Arguments
221    ///
222    /// * `other` - The other `ParamIndices` to intersect with.
223    ///
224    /// # Returns
225    ///
226    /// A new `ParamIndices` representing the intersection of the two input `ParamIndices`.
227    pub fn intersect(&self, other: &ParamIndices) -> ParamIndices {
228        match (self, other) {
229            (ParamIndices::Joint(start1, length1), ParamIndices::Joint(start2, length2)) => {
230                let mut indices = Vec::new();
231                for i in *start1..*start1 + *length1 {
232                    if i >= *start2 && i < *start2 + *length2 {
233                        indices.push(i);
234                    }
235                }
236                ParamIndices::Disjoint(indices)
237            }
238            (ParamIndices::Joint(start, length), ParamIndices::Disjoint(v)) => {
239                let mut indices = Vec::new();
240                for i in *start..*start + *length {
241                    if v.contains(&i) {
242                        indices.push(i);
243                    }
244                }
245                ParamIndices::Disjoint(indices)
246            }
247            (ParamIndices::Disjoint(v), ParamIndices::Joint(start, length)) => {
248                let mut indices = Vec::new();
249                for i in *start..*start + *length {
250                    if v.contains(&i) {
251                        indices.push(i);
252                    }
253                }
254                ParamIndices::Disjoint(indices)
255            }
256            (ParamIndices::Disjoint(v1), ParamIndices::Disjoint(v2)) => {
257                let mut indices = Vec::new();
258                for i in v1 {
259                    if v2.contains(i) {
260                        indices.push(*i);
261                    }
262                }
263                ParamIndices::Disjoint(indices)
264            }
265        }
266    }
267
268    /// Sorts the indices.
269    ///
270    /// If the `ParamIndices` is `Joint`, this method does nothing.
271    ///
272    /// # Returns
273    ///
274    /// A mutable reference to `self` for chaining.
275    pub fn sort(&mut self) -> &mut Self {
276        match self {
277            ParamIndices::Joint(_, _) => {}
278            ParamIndices::Disjoint(indices) => indices.sort(),
279        }
280        self
281    }
282
283    /// Returns a new `ParamIndices` with the indices sorted.
284    pub fn sorted(&self) -> Self {
285        match self {
286            ParamIndices::Joint(s, l) => ParamIndices::Joint(*s, *l),
287            ParamIndices::Disjoint(indices) => {
288                let mut indices_out = indices.clone();
289                indices_out.sort();
290                ParamIndices::Disjoint(indices_out)
291            }
292        }
293    }
294
295    /// Creates an iterator over the parameter indices.
296    ///
297    /// # Returns
298    ///
299    /// A `ParamIndicesIter` that yields the parameter indices.
300    pub fn iter<'a>(&'a self) -> ParamIndicesIter<'a> {
301        match self {
302            ParamIndices::Joint(start, length) => ParamIndicesIter::Joint {
303                start: *start,
304                length: *length,
305                current: 0,
306            },
307            ParamIndices::Disjoint(indices) => ParamIndicesIter::Disjoint {
308                indices,
309                current: 0,
310            },
311        }
312    }
313
314    /// Checks if the `ParamIndices` contains the given index.
315    ///
316    /// # Arguments
317    ///
318    /// * `index` - The index to check.
319    ///
320    /// # Returns
321    ///
322    /// `true` if the `ParamIndices` contains the index, `false` otherwise.
323    pub fn contains(&self, index: usize) -> bool {
324        match self {
325            ParamIndices::Joint(start, length) => index >= *start && index < *start + *length,
326            ParamIndices::Disjoint(indices) => indices.contains(&index),
327        }
328    }
329
330    /// Convert the indices to a vector
331    pub fn to_vec(self) -> Vec<usize> {
332        match self {
333            ParamIndices::Joint(_, _) => self.iter().collect(),
334            ParamIndices::Disjoint(vec) => vec,
335        }
336    }
337
338    /// Convert the indices to a vector without consuming the indices object.
339    pub fn as_vec(&self) -> Vec<usize> {
340        self.iter().collect()
341    }
342
343    /// Returns the number of indices tracked by this object; alias for num_params()
344    pub fn len(&self) -> usize {
345        match self {
346            ParamIndices::Joint(_, length) => *length,
347            ParamIndices::Disjoint(indices) => indices.len(),
348        }
349    }
350}
351
/// An iterator over the parameter indices in a `ParamIndices`.
///
/// Indices are yielded in stored order: ascending for `Joint`, list order for `Disjoint`.
/// The iterator reports an exact length and keeps returning `None` once exhausted.
pub enum ParamIndicesIter<'a> {
    /// Iterator for `ParamIndices::Joint`.
    Joint {
        /// The starting index of the consecutive range.
        start: usize,
        /// The length of the consecutive range.
        length: usize,
        /// Number of indices already yielded (offset of the next index from `start`).
        current: usize,
    },
    /// Iterator for `ParamIndices::Disjoint`.
    Disjoint {
        /// A slice of the indices.
        indices: &'a [usize],
        /// Position in `indices` of the next index to yield.
        current: usize,
    },
}

impl<'a> Iterator for ParamIndicesIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            ParamIndicesIter::Joint {
                start,
                length,
                current,
            } => {
                if *current < *length {
                    let result = *start + *current;
                    *current += 1;
                    Some(result)
                } else {
                    None
                }
            }
            ParamIndicesIter::Disjoint { indices, current } => {
                let result = indices.get(*current).copied()?;
                *current += 1;
                Some(result)
            }
        }
    }

    /// Exact count of the indices still to be yielded. This lets `collect` (used by
    /// `ParamIndices::as_vec` / `to_vec`) allocate once.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = match self {
            ParamIndicesIter::Joint {
                length, current, ..
            } => length.saturating_sub(*current),
            ParamIndicesIter::Disjoint { indices, current } => {
                indices.len().saturating_sub(*current)
            }
        };
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact for both variants.
impl ExactSizeIterator for ParamIndicesIter<'_> {}

// `next` never advances past the end, so it keeps returning `None` after exhaustion.
impl std::iter::FusedIterator for ParamIndicesIter<'_> {}
402
403impl<T: AsRef<[usize]>> From<T> for ParamIndices {
404    fn from(indices: T) -> Self {
405        ParamIndices::Disjoint(indices.as_ref().to_vec())
406    }
407}
408
409impl<'a> IntoIterator for &'a ParamIndices {
410    type Item = usize;
411    type IntoIter = ParamIndicesIter<'a>;
412
413    fn into_iter(self) -> Self::IntoIter {
414        self.iter()
415    }
416}
417
/// Information about function parameters, including their indices and whether they are constant.
///
/// # Fields
///
/// * `indices` - The parameter indices, stored as either consecutive or disjoint ranges
/// * `constant` - A vector indicating whether each parameter is constant (true) or variable (false)
///
/// `constant` is paired positionally with `indices`. `constant[i]` is the flag for the `i`-th
/// index yielded by `indices.iter()`, so the two are expected to have the same length.
/// Both fields are private; instances are built through the constructors on `ParamInfo`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ParamInfo {
    // Which parameters this info describes (a consecutive range or an explicit list).
    indices: ParamIndices,
    // One flag per entry of `indices`, in the same order; `true` means constant.
    constant: Vec<bool>, // TODO: change to bitmap
}
429
430impl ParamInfo {
431    /// Creates a new `ParamInfo` with the given parameter indices and constant flags.
432    ///
433    /// # Arguments
434    ///
435    /// * `indices` - The parameter indices (can be any type that converts to `ParamIndices`)
436    /// * `constant` - A vector indicating whether each parameter is constant
437    ///
438    /// # Returns
439    ///
440    /// A new `ParamInfo` instance.
441    ///
442    /// # Panics
443    ///
444    /// The length of the `constant` vector should match the number of parameters in `indices`.
445    pub fn new(indices: impl Into<ParamIndices>, constant: Vec<bool>) -> ParamInfo {
446        ParamInfo {
447            indices: indices.into(),
448            constant,
449        }
450    }
451
452    /// Creates a new `ParamInfo` where all parameters are variable (not constant).
453    ///
454    /// This is a convenience method for creating parameter information where all
455    /// parameters can vary during optimization or analysis.
456    ///
457    /// # Arguments
458    ///
459    /// * `indices` - The parameter indices (can be any type that converts to `ParamIndices`)
460    ///
461    /// # Returns
462    ///
463    /// A new `ParamInfo` instance with all parameters marked as variable.
464    pub fn parameterized(indices: impl Into<ParamIndices>) -> ParamInfo {
465        let indices = indices.into();
466        let num_params = indices.num_params();
467        ParamInfo {
468            indices,
469            constant: vec![false; num_params],
470        }
471    }
472
473    /// Returns a new `ParamInfo` with its parameter indices sorted.
474    ///
475    /// If the `ParamIndices` are `Disjoint`, the internal vector is sorted.
476    /// If the `ParamIndices` are `Joint`, they remain `Joint` as their order is inherently sorted.
477    /// The `constant` vector is reordered to match the new order of indices.
478    pub fn to_sorted(&self) -> ParamInfo {
479        // Create a map from each parameter index to its 'constant' status.
480        let index_to_constant_map: std::collections::HashMap<usize, bool> = self
481            .indices
482            .iter()
483            .zip(self.constant.iter().cloned())
484            .collect();
485
486        // Get the sorted parameter indices. This will preserve the `Joint` variant if applicable.
487        let new_indices = self.indices.sorted();
488
489        // Build the new `constant` vector by iterating through the `new_indices`
490        // and looking up their constant status in the map.
491        let new_constant: Vec<bool> = new_indices
492            .iter()
493            .map(|index| *index_to_constant_map.get(&index).unwrap_or(&false))
494            .collect();
495
496        ParamInfo {
497            indices: new_indices,
498            constant: new_constant,
499        }
500    }
501
502    /// Unions this `ParamInfo` with another, combining their parameter indices and constant flags.
503    ///
504    /// This operation combines parameter indices from both `ParamInfo` instances. If a parameter
505    /// index appears in both instances, their constant flags must match or an assertion will fail.
506    ///
507    /// # Arguments
508    ///
509    /// * `other` - The other `ParamInfo` to concatenate with
510    ///
511    /// # Returns
512    ///
513    /// A new `ParamInfo` containing the combined parameters and their constant status.
514    pub fn union(&self, other: &ParamInfo) -> ParamInfo {
515        let combined_indices = self.indices.union(&other.indices);
516
517        let self_index_to_constant: std::collections::HashMap<usize, bool> = self
518            .indices
519            .iter()
520            .zip(self.constant.iter().cloned())
521            .collect();
522
523        let other_index_to_constant: std::collections::HashMap<usize, bool> = other
524            .indices
525            .iter()
526            .zip(other.constant.iter().cloned())
527            .collect();
528
529        let mut new_constant = Vec::new();
530        for index in combined_indices.iter() {
531            if let Some(&c) = self_index_to_constant.get(&index) {
532                if let Some(&c_other) = other_index_to_constant.get(&index) {
533                    assert_eq!(c, c_other);
534                }
535                new_constant.push(c);
536            } else if let Some(&c) = other_index_to_constant.get(&index) {
537                new_constant.push(c);
538            } else {
539                unreachable!();
540            };
541        }
542
543        ParamInfo {
544            indices: combined_indices,
545            constant: new_constant,
546        }
547    }
548
549    /// Computes the intersection of this `ParamInfo` with another.
550    ///
551    /// Returns a new `ParamInfo` containing only the parameters that are present in both
552    /// instances. The constant flags must match for common parameters.
553    ///
554    /// # Arguments
555    ///
556    /// * `other` - The other `ParamInfo` to intersect with
557    ///
558    /// # Returns
559    ///
560    /// A new `ParamInfo` containing only the common parameters.
561    pub fn intersect(&self, other: &ParamInfo) -> ParamInfo {
562        let combined_indices = self.indices.intersect(&other.indices);
563
564        let self_index_to_constant: std::collections::HashMap<usize, bool> = self
565            .indices
566            .iter()
567            .zip(self.constant.iter().cloned())
568            .collect();
569
570        let other_index_to_constant: std::collections::HashMap<usize, bool> = other
571            .indices
572            .iter()
573            .zip(other.constant.iter().cloned())
574            .collect();
575
576        let mut new_constant = Vec::new();
577        for index in combined_indices.iter() {
578            let c_self = self_index_to_constant.get(&index);
579            let c_other = other_index_to_constant.get(&index);
580
581            // An index in combined_indices must exist in both self and other ParamInfo
582            match (c_self, c_other) {
583                (Some(&s_const), Some(&o_const)) => {
584                    // Their constant status must be the same for the common index
585                    assert_eq!(
586                        s_const, o_const,
587                        "Constant status mismatch for common index {}",
588                        index
589                    );
590                    new_constant.push(s_const);
591                }
592                _ => unreachable!(
593                    "Intersected index {} not found in both original ParamInfos",
594                    index
595                ),
596            }
597        }
598
599        ParamInfo {
600            indices: combined_indices,
601            constant: new_constant,
602        }
603    }
604
605    /// Returns the total number of parameters.
606    ///
607    /// # Returns
608    ///
609    /// The number of parameters tracked by this `ParamInfo`.
610    pub fn num_params(&self) -> usize {
611        self.indices.num_params()
612    }
613
614    /// Returns the total number of variable parameters.
615    ///
616    /// # Returns
617    ///
618    /// The number of non-constant parameters tracked by this `ParamInfo`.
619    pub fn num_var_params(&self) -> usize {
620        self.indices.num_params() - self.constant.iter().filter(|&x| *x).count()
621    }
622
623    /// Returns the parameter indices as a vector of u64 values.
624    ///
625    /// This is useful for interfacing with external libraries that expect
626    /// parameter indices as 64-bit unsigned integers.
627    ///
628    /// # Returns
629    ///
630    /// A vector containing all parameter indices converted to u64.
631    pub fn get_param_map(&self) -> Vec<u64> {
632        self.indices.iter().map(|x| x as u64).collect()
633    }
634
635    /// Returns a copy of the constant flags for all parameters.
636    ///
637    /// The returned vector has the same length as the number of parameters,
638    /// where each boolean indicates whether the corresponding parameter is constant.
639    ///
640    /// # Returns
641    ///
642    /// A vector of boolean values indicating parameter constancy.
643    pub fn get_const_map(&self) -> Vec<bool> {
644        self.constant.clone()
645    }
646
647    /// Creates an empty `ParamInfo` with no parameters.
648    ///
649    /// This is useful as a default or starting point for parameter information.
650    ///
651    /// # Returns
652    ///
653    /// An empty `ParamInfo` instance.
654    pub fn empty() -> ParamInfo {
655        ParamInfo {
656            indices: vec![].into(),
657            constant: vec![],
658        }
659    }
660
661    /// Checks if this `ParamInfo` contains no parameters.
662    ///
663    /// # Returns
664    ///
665    /// `true` if there are no parameters, `false` otherwise.
666    pub fn is_empty(&self) -> bool {
667        self.len() == 0
668    }
669
670    /// Returns the number of parameters (alias for `num_params`).
671    ///
672    /// # Returns
673    ///
674    /// The total number of parameters.
675    pub fn len(&self) -> usize {
676        self.constant.len()
677    }
678
679    /// Returns a sorted vector of indices for parameters that are not constant.
680    ///
681    /// This is useful for optimization routines that only need to vary non-constant
682    /// parameters.
683    ///
684    /// # Returns
685    ///
686    /// A sorted vector containing the indices of all non-constant parameters.
687    pub fn sorted_non_constant(&self) -> Vec<usize> {
688        let mut non_constant_indices = Vec::new();
689        for (index, &is_constant) in self.indices.iter().zip(self.constant.iter()) {
690            if !is_constant {
691                non_constant_indices.push(index);
692            }
693        }
694        non_constant_indices.sort();
695        non_constant_indices
696    }
697}
698
/// A parameterized object.
///
/// Implementors report how many parameters they take.
pub trait HasParams {
    /// The number of parameters this object requires.
    fn num_params(&self) -> usize;
}
704
/// A bounded, parameterized object.
///
/// Extends [`HasParams`] with per-variable bounds, each expressed as a `Range<R>` (`start..end`).
pub trait HasBounds<R: RealScalar>: HasParams {
    /// The bounds for each variable of the function.
    ///
    /// NOTE(review): presumably one range per parameter, i.e. `num_params()` entries, and
    /// whether `end` is meant as inclusive is not visible here — confirm against implementors.
    fn bounds(&self) -> Vec<Range<R>>;
}
710
/// A periodic, parameterized object
///
/// Extends [`HasParams`] with a core period for each variable, expressed as a `Range<R>`.
pub trait HasPeriods<R: RealScalar>: HasParams {
    /// The core period for each variable of the function
    ///
    /// NOTE(review): presumably one range per parameter (`num_params()` entries) — confirm
    /// against implementors.
    fn periods(&self) -> Vec<Range<R>>;
}
716
717// #[cfg(test)]
718// pub mod strategies {
719//     use std::ops::Range;
720
721//     use proptest::prelude::*;
722
723//     use super::BoundedFn;
724
725//     pub fn params(num_params: usize) -> impl Strategy<Value = Vec<f64>> {
726//         prop::collection::vec(
727//             prop::num::f64::POSITIVE
728//                 | prop::num::f64::NEGATIVE
729//                 | prop::num::f64::NORMAL
730//                 | prop::num::f64::SUBNORMAL
731//                 | prop::num::f64::ZERO,
732//             num_params,
733//         )
734//     }
735
736//     pub fn params_with_bounds(
737//         bounds: Vec<Range<f64>>,
738//     ) -> impl Strategy<Value = Vec<f64>> {
739//         bounds
740//     }
741
742//     pub fn arbitrary_with_params_strategy<F: Clone + BoundedFn + Arbitrary>(
743//     ) -> impl Strategy<Value = (F, Vec<f64>)> {
744//         any::<F>().prop_flat_map(|f| (Just(f.clone()), f.get_bounds()))
745//     }
746// }
747
/// Python bindings for `ParamIndices`, compiled only when the `python` feature is enabled.
#[cfg(feature = "python")]
mod python {
    use super::*;
    use pyo3::prelude::*;
    use pyo3::types::PyIterator;

    /// Python wrapper for parameter indices.
    ///
    /// This provides parameter index functionality to Python, supporting both
    /// consecutive (Joint) and disjoint parameter representations.
    ///
    /// Exposed to Python under the name `ParamIndices`. The class is `frozen` and no method
    /// below takes `&mut self`, so the wrapped value never changes after construction.
    /// Python `==` and `hash()` are enabled via the `eq`/`hash` options. They presumably
    /// use the derived `PartialEq`/`Hash` impls, which forward to `ParamIndices`.
    #[pyclass(name = "ParamIndices", frozen, eq, hash)]
    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    pub struct PyParamIndices {
        // The wrapped Rust representation.
        inner: ParamIndices,
    }

    #[pymethods]
    impl PyParamIndices {
        /// Creates new parameter indices.
        ///
        /// # Arguments
        ///
        /// * `indices_or_start` - Either an iterable of parameter indices, or the starting index for a consecutive range
        /// * `length` - Optional length for consecutive range (only used if `indices_or_start` is an integer)
        ///
        /// A single integer without `length` produces a one-element consecutive range.
        ///
        /// # Examples
        ///
        /// ```python
        /// # Disjoint parameters
        /// params1 = ParamIndices([0, 2, 5, 7])
        ///
        /// # Consecutive parameters starting at index 10 with length 5
        /// params2 = ParamIndices(10, 5)  # represents [10, 11, 12, 13, 14]
        /// ```
        ///
        /// # Errors
        ///
        /// Fails if `length` is given and `indices_or_start` does not extract as `usize`.
        /// Without `length`, fails if `indices_or_start` is neither a `usize` nor iterable,
        /// or if any iterated item does not extract as `usize`.
        #[new]
        #[pyo3(signature = (indices_or_start, length = None))]
        fn new<'py>(indices_or_start: &Bound<'py, PyAny>, length: Option<usize>) -> PyResult<Self> {
            if let Some(len) = length {
                // Joint case: start + length
                let start: usize = indices_or_start.extract()?;
                Ok(PyParamIndices {
                    inner: ParamIndices::Joint(start, len),
                })
            } else {
                // Try to extract as an integer first (single parameter)
                if let Ok(single_index) = indices_or_start.extract::<usize>() {
                    Ok(PyParamIndices {
                        inner: ParamIndices::Joint(single_index, 1),
                    })
                } else {
                    // Try to extract as an iterable of indices. Any value that failed the
                    // `usize` extraction above (presumably including negative ints) lands
                    // here, so the error raised is the iteration/extraction error.
                    let iter = PyIterator::from_object(indices_or_start)?;
                    let mut indices = Vec::new();
                    for item in iter {
                        let index: usize = item?.extract()?;
                        indices.push(index);
                    }
                    Ok(PyParamIndices {
                        inner: ParamIndices::Disjoint(indices),
                    })
                }
            }
        }

        /// Returns the number of parameters.
        #[getter]
        fn num_params(&self) -> usize {
            self.inner.num_params()
        }

        /// Returns the starting index (for Joint) or first index (for Disjoint).
        #[getter]
        fn start(&self) -> usize {
            self.inner.start()
        }

        /// Returns whether the parameters are consecutive.
        #[getter]
        fn is_consecutive(&self) -> bool {
            self.inner.is_consecutive()
        }

        /// Returns whether the parameter indices are empty.
        #[getter]
        fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }

        /// Checks if the parameter indices contain the given index.
        fn contains(&self, index: usize) -> bool {
            self.inner.contains(index)
        }

        /// Returns all parameter indices as a list.
        fn to_list(&self) -> Vec<usize> {
            self.inner.as_vec()
        }

        /// Returns a sorted copy of these parameter indices.
        fn sorted(&self) -> PyParamIndices {
            PyParamIndices {
                inner: self.inner.sorted(),
            }
        }

        /// Unions with another ParamIndices.
        fn union(&self, other: &PyParamIndices) -> PyParamIndices {
            PyParamIndices {
                inner: self.inner.union(&other.inner),
            }
        }

        /// Intersects with another ParamIndices.
        fn intersect(&self, other: &PyParamIndices) -> PyParamIndices {
            PyParamIndices {
                inner: self.inner.intersect(&other.inner),
            }
        }

        /// Creates parameter indices for a constant (no parameters).
        #[staticmethod]
        fn constant() -> PyParamIndices {
            PyParamIndices {
                inner: ParamIndices::empty(),
            }
        }

        // Mirrors the constructor forms: `ParamIndices(start)` for a one-element range,
        // `ParamIndices(start, length)` otherwise, and `ParamIndices([...])` for a list.
        fn __repr__(&self) -> String {
            match &self.inner {
                ParamIndices::Joint(start, length) => {
                    if *length == 1 {
                        format!("ParamIndices({})", start)
                    } else {
                        format!("ParamIndices({}, {})", start, length)
                    }
                }
                ParamIndices::Disjoint(indices) => {
                    format!("ParamIndices({:?})", indices)
                }
            }
        }

        // Short form: up to three indices are printed in full; longer sequences are
        // abbreviated to `[first, ..., last]`.
        fn __str__(&self) -> String {
            let indices = self.inner.as_vec();
            if indices.len() <= 3 {
                format!("{:?}", indices)
            } else {
                format!(
                    "[{}, ..., {}]",
                    indices.first().unwrap(),
                    indices.last().unwrap()
                )
            }
        }

        fn __len__(&self) -> usize {
            self.inner.len()
        }

        // Snapshots the indices into an owned Vec so the Python iterator does not
        // borrow from this object.
        fn __iter__(slf: PyRef<'_, Self>) -> PyParamIndicesIterator {
            PyParamIndicesIterator {
                indices: slf.inner.as_vec(),
                index: 0,
            }
        }

        // Backs Python's `in` operator; same check as `contains`.
        fn __contains__(&self, index: usize) -> bool {
            self.inner.contains(index)
        }
    }

    // Python-side iterator returned by `PyParamIndices.__iter__`.
    #[pyclass]
    struct PyParamIndicesIterator {
        // Owned copy of the indices being iterated.
        indices: Vec<usize>,
        // Position of the next index to return.
        index: usize,
    }

    #[pymethods]
    impl PyParamIndicesIterator {
        fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
            slf
        }

        // Returning `None` signals exhaustion to Python.
        fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<usize> {
            if slf.index < slf.indices.len() {
                let result = slf.indices[slf.index];
                slf.index += 1;
                Some(result)
            } else {
                None
            }
        }
    }

    impl From<ParamIndices> for PyParamIndices {
        fn from(indices: ParamIndices) -> Self {
            PyParamIndices { inner: indices }
        }
    }

    impl From<PyParamIndices> for ParamIndices {
        fn from(py_indices: PyParamIndices) -> Self {
            py_indices.inner
        }
    }

    // Lets functions return a plain `ParamIndices` to Python; it is wrapped in a new
    // `PyParamIndices` object.
    impl<'py> IntoPyObject<'py> for ParamIndices {
        type Target = PyParamIndices;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_indices = PyParamIndices::from(self);
            Bound::new(py, py_indices)
        }
    }

    // Lets functions accept a plain `ParamIndices` argument from Python. Only
    // `PyParamIndices` objects are accepted here (not raw lists or ints).
    impl<'a, 'py> FromPyObject<'a, 'py> for ParamIndices {
        type Error = PyErr;

        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let py_indices: PyParamIndices = ob.extract()?;
            Ok(py_indices.into())
        }
    }
}