// qudit_core/function.rs
1use std::ops::Range;
2
3use bit_set::BitSet;
4use std::hash::{Hash, Hasher};
5
6use crate::RealScalar;
7
/// A data structure representing indices of parameters (e.g. for a function).
///
/// Two representations are supported: [`ParamIndices::Joint`] stores a consecutive range
/// compactly as `(start, length)`, while [`ParamIndices::Disjoint`] stores an explicit list.
/// There are optimized methods for consecutive and disjoint parameters.
#[derive(Clone, Debug)]
pub enum ParamIndices {
    /// A consecutive range of parameters, stored as `(start, length)`.
    ///
    /// `Joint(s, n)` represents the indices `s, s + 1, ..., s + n - 1`; a length of zero
    /// represents no parameters at all.
    Joint(usize, usize), // start, length

    /// An explicit list of parameter indices, iterated in the order stored.
    Disjoint(Vec<usize>),
}
18
19impl PartialEq for ParamIndices {
20 fn eq(&self, other: &Self) -> bool {
21 if self.len() != other.len() {
22 return false;
23 }
24
25 match (self, other) {
26 // Same variants - delegate to inner equality
27 (ParamIndices::Joint(start1, len1), ParamIndices::Joint(start2, len2)) => {
28 start1 == start2 && len1 == len2
29 }
30 (ParamIndices::Disjoint(v1), ParamIndices::Disjoint(v2)) => v1 == v2,
31 // Cross variants - compare iterators element by element
32 _ => self.iter().eq(other.iter()),
33 }
34 }
35}
36
37impl Eq for ParamIndices {}
38
39impl Hash for ParamIndices {
40 fn hash<H: Hasher>(&self, state: &mut H) {
41 self.len().hash(state);
42 for index in self.iter() {
43 index.hash(state);
44 }
45 }
46}
47
48impl ParamIndices {
49 /// Converts the `ParamIndices` to a `BitSet`.
50 ///
51 /// This function creates a `BitSet` where each bit corresponds to a parameter index.
52 /// If a parameter index is present in the `ParamIndices`, the corresponding bit in the
53 /// `BitSet` is set to 1.
54 ///
55 /// # Returns
56 ///
57 /// A `BitSet` representing the parameter indices.
58 pub fn to_bitset(&self) -> BitSet {
59 let mut bitset = BitSet::new();
60 match self {
61 ParamIndices::Joint(start, length) => {
62 for i in 0..*length {
63 bitset.insert(*start + i);
64 }
65 }
66 ParamIndices::Disjoint(indices) => {
67 for index in indices {
68 bitset.insert(*index);
69 }
70 }
71 }
72 bitset
73 }
74
75 /// Creates a `ParamIndices` representing a constant value (no parameters).
76 ///
77 /// # Returns
78 ///
79 /// A `ParamIndices` representing an empty set of parameter indices.
80 pub fn empty() -> ParamIndices {
81 ParamIndices::Joint(0, 0)
82 }
83
84 /// Checks if the `ParamIndices` is empty (contains no parameters).
85 ///
86 /// # Returns
87 ///
88 /// `true` if the `ParamIndices` is empty, `false` otherwise.
89 pub fn is_empty(&self) -> bool {
90 match self {
91 ParamIndices::Joint(_, length) => *length == 0,
92 ParamIndices::Disjoint(v) => v.is_empty(),
93 }
94 }
95
96 /// Checks if the `ParamIndices` represents a consecutive range of parameters.
97 ///
98 /// # Returns
99 ///
100 /// `true` if the `ParamIndices` is consecutive, `false` otherwise.
101 pub fn is_consecutive(&self) -> bool {
102 match self {
103 ParamIndices::Joint(_, _) => true,
104 ParamIndices::Disjoint(v) => {
105 if v.len() <= 1 {
106 return true;
107 }
108 let mut clone_v = v.clone();
109 clone_v.sort();
110 for i in 1..clone_v.len() {
111 if clone_v[i] != clone_v[i - 1] + 1 {
112 return false;
113 }
114 }
115 true
116 }
117 }
118 }
119
120 /// Returns the number of parameters represented by the `ParamIndices`.
121 ///
122 /// # Returns
123 ///
124 /// The number of parameters.
125 ///
126 pub fn num_params(&self) -> usize {
127 match self {
128 ParamIndices::Joint(_, length) => *length,
129 ParamIndices::Disjoint(v) => v.len(),
130 }
131 }
132
133 /// Returns the starting index of the parameters.
134 ///
135 /// For `ParamIndices::Joint`, this is the index of the first parameter in the consecutive range.
136 /// For `ParamIndices::Disjoint`, this is the index of the first parameter in the vector, or 0 if the vector is empty.
137 ///
138 /// # Returns
139 ///
140 /// The starting index of the parameters.
141 pub fn start(&self) -> usize {
142 match self {
143 ParamIndices::Joint(start, _) => *start,
144 ParamIndices::Disjoint(v) => *v.first().unwrap_or(&0),
145 }
146 }
147
148 /// Concatenates two `ParamIndices` into a single `ParamIndices`.
149 ///
150 /// This function combines the parameter indices from two `ParamIndices` instances.
151 /// If the two `ParamIndices` represent overlapping ranges, they are merged into a single `ParamIndices::Joint` if possible.
152 /// Otherwise, the parameter indices are combined into a `ParamIndices::Disjoint`.
153 ///
154 /// # Arguments
155 ///
156 /// * `other` - The other `ParamIndices` to concatenate with.
157 ///
158 /// # Returns
159 ///
160 /// A new `ParamIndices` representing the concatenation of the two input `ParamIndices`.
161 pub fn union(&self, other: &ParamIndices) -> ParamIndices {
162 match (self, other) {
163 (ParamIndices::Joint(start1, length1), ParamIndices::Joint(start2, length2)) => {
164 if *start2 > *start1 && *start2 < *start1 + *length1 {
165 ParamIndices::Joint(*start1, *start2 + *length2 - *start1)
166 } else if *start1 > *start2 && *start1 < *start2 + *length2 {
167 ParamIndices::Joint(*start2, *start1 + *length1 - *start2)
168 } else {
169 let mut indices = Vec::new();
170 for i in *start1..*start1 + *length1 {
171 indices.push(i);
172 }
173 for i in *start2..*start2 + *length2 {
174 indices.push(i);
175 }
176 ParamIndices::Disjoint(indices)
177 }
178 }
179 (ParamIndices::Joint(start, length), ParamIndices::Disjoint(v)) => {
180 let mut indices = Vec::new();
181 for i in *start..*start + *length {
182 if v.contains(&i) {
183 continue;
184 }
185 indices.push(i);
186 }
187 indices.extend(v.iter());
188 ParamIndices::Disjoint(indices)
189 }
190 (ParamIndices::Disjoint(v), ParamIndices::Joint(start, length)) => {
191 let mut indices = Vec::new();
192 for i in *start..*start + *length {
193 if v.contains(&i) {
194 continue;
195 }
196 indices.push(i);
197 }
198 indices.extend(v.iter());
199 ParamIndices::Disjoint(indices)
200 }
201 (ParamIndices::Disjoint(v1), ParamIndices::Disjoint(v2)) => {
202 let mut indices = Vec::new();
203 for i in v1 {
204 if v2.contains(i) {
205 continue;
206 }
207 indices.push(*i);
208 }
209 indices.extend(v2.iter());
210 ParamIndices::Disjoint(indices)
211 }
212 }
213 }
214
215 /// Computes the intersection of two `ParamIndices`.
216 ///
217 /// This function returns a new `ParamIndices` containing only the parameter indices that are present in both input `ParamIndices`.
218 /// The result is always a `ParamIndices::Disjoint`.
219 ///
220 /// # Arguments
221 ///
222 /// * `other` - The other `ParamIndices` to intersect with.
223 ///
224 /// # Returns
225 ///
226 /// A new `ParamIndices` representing the intersection of the two input `ParamIndices`.
227 pub fn intersect(&self, other: &ParamIndices) -> ParamIndices {
228 match (self, other) {
229 (ParamIndices::Joint(start1, length1), ParamIndices::Joint(start2, length2)) => {
230 let mut indices = Vec::new();
231 for i in *start1..*start1 + *length1 {
232 if i >= *start2 && i < *start2 + *length2 {
233 indices.push(i);
234 }
235 }
236 ParamIndices::Disjoint(indices)
237 }
238 (ParamIndices::Joint(start, length), ParamIndices::Disjoint(v)) => {
239 let mut indices = Vec::new();
240 for i in *start..*start + *length {
241 if v.contains(&i) {
242 indices.push(i);
243 }
244 }
245 ParamIndices::Disjoint(indices)
246 }
247 (ParamIndices::Disjoint(v), ParamIndices::Joint(start, length)) => {
248 let mut indices = Vec::new();
249 for i in *start..*start + *length {
250 if v.contains(&i) {
251 indices.push(i);
252 }
253 }
254 ParamIndices::Disjoint(indices)
255 }
256 (ParamIndices::Disjoint(v1), ParamIndices::Disjoint(v2)) => {
257 let mut indices = Vec::new();
258 for i in v1 {
259 if v2.contains(i) {
260 indices.push(*i);
261 }
262 }
263 ParamIndices::Disjoint(indices)
264 }
265 }
266 }
267
268 /// Sorts the indices.
269 ///
270 /// If the `ParamIndices` is `Joint`, this method does nothing.
271 ///
272 /// # Returns
273 ///
274 /// A mutable reference to `self` for chaining.
275 pub fn sort(&mut self) -> &mut Self {
276 match self {
277 ParamIndices::Joint(_, _) => {}
278 ParamIndices::Disjoint(indices) => indices.sort(),
279 }
280 self
281 }
282
283 /// Returns a new `ParamIndices` with the indices sorted.
284 pub fn sorted(&self) -> Self {
285 match self {
286 ParamIndices::Joint(s, l) => ParamIndices::Joint(*s, *l),
287 ParamIndices::Disjoint(indices) => {
288 let mut indices_out = indices.clone();
289 indices_out.sort();
290 ParamIndices::Disjoint(indices_out)
291 }
292 }
293 }
294
295 /// Creates an iterator over the parameter indices.
296 ///
297 /// # Returns
298 ///
299 /// A `ParamIndicesIter` that yields the parameter indices.
300 pub fn iter<'a>(&'a self) -> ParamIndicesIter<'a> {
301 match self {
302 ParamIndices::Joint(start, length) => ParamIndicesIter::Joint {
303 start: *start,
304 length: *length,
305 current: 0,
306 },
307 ParamIndices::Disjoint(indices) => ParamIndicesIter::Disjoint {
308 indices,
309 current: 0,
310 },
311 }
312 }
313
314 /// Checks if the `ParamIndices` contains the given index.
315 ///
316 /// # Arguments
317 ///
318 /// * `index` - The index to check.
319 ///
320 /// # Returns
321 ///
322 /// `true` if the `ParamIndices` contains the index, `false` otherwise.
323 pub fn contains(&self, index: usize) -> bool {
324 match self {
325 ParamIndices::Joint(start, length) => index >= *start && index < *start + *length,
326 ParamIndices::Disjoint(indices) => indices.contains(&index),
327 }
328 }
329
330 /// Convert the indices to a vector
331 pub fn to_vec(self) -> Vec<usize> {
332 match self {
333 ParamIndices::Joint(_, _) => self.iter().collect(),
334 ParamIndices::Disjoint(vec) => vec,
335 }
336 }
337
338 /// Convert the indices to a vector without consuming the indices object.
339 pub fn as_vec(&self) -> Vec<usize> {
340 self.iter().collect()
341 }
342
343 /// Returns the number of indices tracked by this object; alias for num_params()
344 pub fn len(&self) -> usize {
345 match self {
346 ParamIndices::Joint(_, length) => *length,
347 ParamIndices::Disjoint(indices) => indices.len(),
348 }
349 }
350}
351
/// An iterator over the parameter indices in a `ParamIndices`.
///
/// Created by [`ParamIndices::iter`]. Yields each index by value. The number of
/// remaining items is always known exactly, so it also implements `ExactSizeIterator`.
pub enum ParamIndicesIter<'a> {
    /// Iterator for `ParamIndices::Joint`.
    Joint {
        /// The starting index of the consecutive range.
        start: usize,
        /// The length of the consecutive range.
        length: usize,
        /// Offset from `start` of the next index to yield.
        current: usize,
    },
    /// Iterator for `ParamIndices::Disjoint`.
    Disjoint {
        /// A slice of the indices.
        indices: &'a [usize],
        /// Position in `indices` of the next index to yield.
        current: usize,
    },
}

impl<'a> Iterator for ParamIndicesIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            ParamIndicesIter::Joint {
                start,
                length,
                current,
            } => {
                if *current < *length {
                    let result = *start + *current;
                    *current += 1;
                    Some(result)
                } else {
                    None
                }
            }
            ParamIndicesIter::Disjoint { indices, current } => {
                let result = indices.get(*current).copied();
                if result.is_some() {
                    *current += 1;
                }
                result
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Both variants know exactly how many items remain. `saturating_sub` keeps this
        // well-defined if the (public) `current` field was set past the end, in which
        // case `next` already returns `None`.
        let remaining = match self {
            ParamIndicesIter::Joint {
                length, current, ..
            } => length.saturating_sub(*current),
            ParamIndicesIter::Disjoint { indices, current } => {
                indices.len().saturating_sub(*current)
            }
        };
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for ParamIndicesIter<'_> {}
402
403impl<T: AsRef<[usize]>> From<T> for ParamIndices {
404 fn from(indices: T) -> Self {
405 ParamIndices::Disjoint(indices.as_ref().to_vec())
406 }
407}
408
409impl<'a> IntoIterator for &'a ParamIndices {
410 type Item = usize;
411 type IntoIter = ParamIndicesIter<'a>;
412
413 fn into_iter(self) -> Self::IntoIter {
414 self.iter()
415 }
416}
417
418/// Information about function parameters, including their indices and whether they are constant.
419///
420/// # Fields
421///
422/// * `indices` - The parameter indices, stored as either consecutive or disjoint ranges
423/// * `constant` - A vector indicating whether each parameter is constant (true) or variable (false)
424#[derive(Clone, Debug, PartialEq, Eq, Hash)]
425pub struct ParamInfo {
426 indices: ParamIndices,
427 constant: Vec<bool>, // TODO: change to bitmap
428}
429
430impl ParamInfo {
431 /// Creates a new `ParamInfo` with the given parameter indices and constant flags.
432 ///
433 /// # Arguments
434 ///
435 /// * `indices` - The parameter indices (can be any type that converts to `ParamIndices`)
436 /// * `constant` - A vector indicating whether each parameter is constant
437 ///
438 /// # Returns
439 ///
440 /// A new `ParamInfo` instance.
441 ///
442 /// # Panics
443 ///
444 /// The length of the `constant` vector should match the number of parameters in `indices`.
445 pub fn new(indices: impl Into<ParamIndices>, constant: Vec<bool>) -> ParamInfo {
446 ParamInfo {
447 indices: indices.into(),
448 constant,
449 }
450 }
451
452 /// Creates a new `ParamInfo` where all parameters are variable (not constant).
453 ///
454 /// This is a convenience method for creating parameter information where all
455 /// parameters can vary during optimization or analysis.
456 ///
457 /// # Arguments
458 ///
459 /// * `indices` - The parameter indices (can be any type that converts to `ParamIndices`)
460 ///
461 /// # Returns
462 ///
463 /// A new `ParamInfo` instance with all parameters marked as variable.
464 pub fn parameterized(indices: impl Into<ParamIndices>) -> ParamInfo {
465 let indices = indices.into();
466 let num_params = indices.num_params();
467 ParamInfo {
468 indices,
469 constant: vec![false; num_params],
470 }
471 }
472
473 /// Returns a new `ParamInfo` with its parameter indices sorted.
474 ///
475 /// If the `ParamIndices` are `Disjoint`, the internal vector is sorted.
476 /// If the `ParamIndices` are `Joint`, they remain `Joint` as their order is inherently sorted.
477 /// The `constant` vector is reordered to match the new order of indices.
478 pub fn to_sorted(&self) -> ParamInfo {
479 // Create a map from each parameter index to its 'constant' status.
480 let index_to_constant_map: std::collections::HashMap<usize, bool> = self
481 .indices
482 .iter()
483 .zip(self.constant.iter().cloned())
484 .collect();
485
486 // Get the sorted parameter indices. This will preserve the `Joint` variant if applicable.
487 let new_indices = self.indices.sorted();
488
489 // Build the new `constant` vector by iterating through the `new_indices`
490 // and looking up their constant status in the map.
491 let new_constant: Vec<bool> = new_indices
492 .iter()
493 .map(|index| *index_to_constant_map.get(&index).unwrap_or(&false))
494 .collect();
495
496 ParamInfo {
497 indices: new_indices,
498 constant: new_constant,
499 }
500 }
501
502 /// Unions this `ParamInfo` with another, combining their parameter indices and constant flags.
503 ///
504 /// This operation combines parameter indices from both `ParamInfo` instances. If a parameter
505 /// index appears in both instances, their constant flags must match or an assertion will fail.
506 ///
507 /// # Arguments
508 ///
509 /// * `other` - The other `ParamInfo` to concatenate with
510 ///
511 /// # Returns
512 ///
513 /// A new `ParamInfo` containing the combined parameters and their constant status.
514 pub fn union(&self, other: &ParamInfo) -> ParamInfo {
515 let combined_indices = self.indices.union(&other.indices);
516
517 let self_index_to_constant: std::collections::HashMap<usize, bool> = self
518 .indices
519 .iter()
520 .zip(self.constant.iter().cloned())
521 .collect();
522
523 let other_index_to_constant: std::collections::HashMap<usize, bool> = other
524 .indices
525 .iter()
526 .zip(other.constant.iter().cloned())
527 .collect();
528
529 let mut new_constant = Vec::new();
530 for index in combined_indices.iter() {
531 if let Some(&c) = self_index_to_constant.get(&index) {
532 if let Some(&c_other) = other_index_to_constant.get(&index) {
533 assert_eq!(c, c_other);
534 }
535 new_constant.push(c);
536 } else if let Some(&c) = other_index_to_constant.get(&index) {
537 new_constant.push(c);
538 } else {
539 unreachable!();
540 };
541 }
542
543 ParamInfo {
544 indices: combined_indices,
545 constant: new_constant,
546 }
547 }
548
549 /// Computes the intersection of this `ParamInfo` with another.
550 ///
551 /// Returns a new `ParamInfo` containing only the parameters that are present in both
552 /// instances. The constant flags must match for common parameters.
553 ///
554 /// # Arguments
555 ///
556 /// * `other` - The other `ParamInfo` to intersect with
557 ///
558 /// # Returns
559 ///
560 /// A new `ParamInfo` containing only the common parameters.
561 pub fn intersect(&self, other: &ParamInfo) -> ParamInfo {
562 let combined_indices = self.indices.intersect(&other.indices);
563
564 let self_index_to_constant: std::collections::HashMap<usize, bool> = self
565 .indices
566 .iter()
567 .zip(self.constant.iter().cloned())
568 .collect();
569
570 let other_index_to_constant: std::collections::HashMap<usize, bool> = other
571 .indices
572 .iter()
573 .zip(other.constant.iter().cloned())
574 .collect();
575
576 let mut new_constant = Vec::new();
577 for index in combined_indices.iter() {
578 let c_self = self_index_to_constant.get(&index);
579 let c_other = other_index_to_constant.get(&index);
580
581 // An index in combined_indices must exist in both self and other ParamInfo
582 match (c_self, c_other) {
583 (Some(&s_const), Some(&o_const)) => {
584 // Their constant status must be the same for the common index
585 assert_eq!(
586 s_const, o_const,
587 "Constant status mismatch for common index {}",
588 index
589 );
590 new_constant.push(s_const);
591 }
592 _ => unreachable!(
593 "Intersected index {} not found in both original ParamInfos",
594 index
595 ),
596 }
597 }
598
599 ParamInfo {
600 indices: combined_indices,
601 constant: new_constant,
602 }
603 }
604
605 /// Returns the total number of parameters.
606 ///
607 /// # Returns
608 ///
609 /// The number of parameters tracked by this `ParamInfo`.
610 pub fn num_params(&self) -> usize {
611 self.indices.num_params()
612 }
613
614 /// Returns the total number of variable parameters.
615 ///
616 /// # Returns
617 ///
618 /// The number of non-constant parameters tracked by this `ParamInfo`.
619 pub fn num_var_params(&self) -> usize {
620 self.indices.num_params() - self.constant.iter().filter(|&x| *x).count()
621 }
622
623 /// Returns the parameter indices as a vector of u64 values.
624 ///
625 /// This is useful for interfacing with external libraries that expect
626 /// parameter indices as 64-bit unsigned integers.
627 ///
628 /// # Returns
629 ///
630 /// A vector containing all parameter indices converted to u64.
631 pub fn get_param_map(&self) -> Vec<u64> {
632 self.indices.iter().map(|x| x as u64).collect()
633 }
634
635 /// Returns a copy of the constant flags for all parameters.
636 ///
637 /// The returned vector has the same length as the number of parameters,
638 /// where each boolean indicates whether the corresponding parameter is constant.
639 ///
640 /// # Returns
641 ///
642 /// A vector of boolean values indicating parameter constancy.
643 pub fn get_const_map(&self) -> Vec<bool> {
644 self.constant.clone()
645 }
646
647 /// Creates an empty `ParamInfo` with no parameters.
648 ///
649 /// This is useful as a default or starting point for parameter information.
650 ///
651 /// # Returns
652 ///
653 /// An empty `ParamInfo` instance.
654 pub fn empty() -> ParamInfo {
655 ParamInfo {
656 indices: vec![].into(),
657 constant: vec![],
658 }
659 }
660
661 /// Checks if this `ParamInfo` contains no parameters.
662 ///
663 /// # Returns
664 ///
665 /// `true` if there are no parameters, `false` otherwise.
666 pub fn is_empty(&self) -> bool {
667 self.len() == 0
668 }
669
670 /// Returns the number of parameters (alias for `num_params`).
671 ///
672 /// # Returns
673 ///
674 /// The total number of parameters.
675 pub fn len(&self) -> usize {
676 self.constant.len()
677 }
678
679 /// Returns a sorted vector of indices for parameters that are not constant.
680 ///
681 /// This is useful for optimization routines that only need to vary non-constant
682 /// parameters.
683 ///
684 /// # Returns
685 ///
686 /// A sorted vector containing the indices of all non-constant parameters.
687 pub fn sorted_non_constant(&self) -> Vec<usize> {
688 let mut non_constant_indices = Vec::new();
689 for (index, &is_constant) in self.indices.iter().zip(self.constant.iter()) {
690 if !is_constant {
691 non_constant_indices.push(index);
692 }
693 }
694 non_constant_indices.sort();
695 non_constant_indices
696 }
697}
698
/// Something whose behavior is controlled by a number of parameters.
pub trait HasParams {
    /// How many parameters this object expects.
    fn num_params(&self) -> usize;
}
704
705/// A bounded, parameterized object.
706pub trait HasBounds<R: RealScalar>: HasParams {
707 /// The bounds for each variable of the function
708 fn bounds(&self) -> Vec<Range<R>>;
709}
710
711/// A periodic, parameterized object
712pub trait HasPeriods<R: RealScalar>: HasParams {
713 /// The core period for each variable of the function
714 fn periods(&self) -> Vec<Range<R>>;
715}
716
717// #[cfg(test)]
718// pub mod strategies {
719// use std::ops::Range;
720
721// use proptest::prelude::*;
722
723// use super::BoundedFn;
724
725// pub fn params(num_params: usize) -> impl Strategy<Value = Vec<f64>> {
726// prop::collection::vec(
727// prop::num::f64::POSITIVE
728// | prop::num::f64::NEGATIVE
729// | prop::num::f64::NORMAL
730// | prop::num::f64::SUBNORMAL
731// | prop::num::f64::ZERO,
732// num_params,
733// )
734// }
735
736// pub fn params_with_bounds(
737// bounds: Vec<Range<f64>>,
738// ) -> impl Strategy<Value = Vec<f64>> {
739// bounds
740// }
741
742// pub fn arbitrary_with_params_strategy<F: Clone + BoundedFn + Arbitrary>(
743// ) -> impl Strategy<Value = (F, Vec<f64>)> {
744// any::<F>().prop_flat_map(|f| (Just(f.clone()), f.get_bounds()))
745// }
746// }
747
// Python bindings for `ParamIndices`, compiled only when the `python` feature is enabled.
// NOTE: `///` doc comments on pyo3 items become Python `__doc__` strings, so additional
// explanations in this module are written as `//` comments to leave those untouched.
#[cfg(feature = "python")]
mod python {
    use super::*;
    use pyo3::prelude::*;
    use pyo3::types::PyIterator;

    /// Python wrapper for parameter indices.
    ///
    /// This provides parameter index functionality to Python, supporting both
    /// consecutive (Joint) and disjoint parameter representations.
    // Exposed to Python as `ParamIndices`. The `eq` and `hash` options rely on the
    // `PartialEq`/`Hash` derived here, which delegate to the wrapped `ParamIndices`.
    #[pyclass(name = "ParamIndices", frozen, eq, hash)]
    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    pub struct PyParamIndices {
        // The wrapped Rust value; the methods below forward to it.
        inner: ParamIndices,
    }

    #[pymethods]
    impl PyParamIndices {
        /// Creates new parameter indices.
        ///
        /// # Arguments
        ///
        /// * `indices_or_start` - Either an iterable of parameter indices, or the starting index for a consecutive range
        /// * `length` - Optional length for consecutive range (only used if `indices_or_start` is an integer)
        ///
        /// # Examples
        ///
        /// ```python
        /// # Disjoint parameters
        /// params1 = ParamIndices([0, 2, 5, 7])
        ///
        /// # Consecutive parameters starting at index 10 with length 5
        /// params2 = ParamIndices(10, 5) # represents [10, 11, 12, 13, 14]
        /// ```
        // Resolution order: with `length` -> `Joint(start, length)`; otherwise a single
        // non-negative int -> `Joint(index, 1)`; otherwise iterate and collect `Disjoint`.
        #[new]
        #[pyo3(signature = (indices_or_start, length = None))]
        fn new<'py>(indices_or_start: &Bound<'py, PyAny>, length: Option<usize>) -> PyResult<Self> {
            if let Some(len) = length {
                // Joint case: start + length
                let start: usize = indices_or_start.extract()?;
                Ok(PyParamIndices {
                    inner: ParamIndices::Joint(start, len),
                })
            } else {
                // Try to extract as an integer first (single parameter)
                if let Ok(single_index) = indices_or_start.extract::<usize>() {
                    Ok(PyParamIndices {
                        inner: ParamIndices::Joint(single_index, 1),
                    })
                } else {
                    // Try to extract as an iterable of indices
                    // NOTE(review): a negative int fails the `usize` extraction above and
                    // falls through to here, so the caller likely gets a "not iterable"
                    // error instead of an out-of-range one — confirm this is intended.
                    let iter = PyIterator::from_object(indices_or_start)?;
                    let mut indices = Vec::new();
                    for item in iter {
                        let index: usize = item?.extract()?;
                        indices.push(index);
                    }
                    Ok(PyParamIndices {
                        inner: ParamIndices::Disjoint(indices),
                    })
                }
            }
        }

        /// Returns the number of parameters.
        #[getter]
        fn num_params(&self) -> usize {
            self.inner.num_params()
        }

        /// Returns the starting index (for Joint) or first index (for Disjoint).
        #[getter]
        fn start(&self) -> usize {
            self.inner.start()
        }

        /// Returns whether the parameters are consecutive.
        #[getter]
        fn is_consecutive(&self) -> bool {
            self.inner.is_consecutive()
        }

        /// Returns whether the parameter indices are empty.
        #[getter]
        fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }

        /// Checks if the parameter indices contain the given index.
        fn contains(&self, index: usize) -> bool {
            self.inner.contains(index)
        }

        /// Returns all parameter indices as a list.
        fn to_list(&self) -> Vec<usize> {
            self.inner.as_vec()
        }

        /// Returns a sorted copy of these parameter indices.
        fn sorted(&self) -> PyParamIndices {
            PyParamIndices {
                inner: self.inner.sorted(),
            }
        }

        /// Unions with another ParamIndices.
        fn union(&self, other: &PyParamIndices) -> PyParamIndices {
            PyParamIndices {
                inner: self.inner.union(&other.inner),
            }
        }

        /// Intersects with another ParamIndices.
        fn intersect(&self, other: &PyParamIndices) -> PyParamIndices {
            PyParamIndices {
                inner: self.inner.intersect(&other.inner),
            }
        }

        /// Creates parameter indices for a constant (no parameters).
        #[staticmethod]
        fn constant() -> PyParamIndices {
            PyParamIndices {
                inner: ParamIndices::empty(),
            }
        }

        // Produces constructor-call syntax: `ParamIndices(start)` for a length-1 range,
        // `ParamIndices(start, length)` otherwise, and `ParamIndices([..])` for a list.
        fn __repr__(&self) -> String {
            match &self.inner {
                ParamIndices::Joint(start, length) => {
                    if *length == 1 {
                        format!("ParamIndices({})", start)
                    } else {
                        format!("ParamIndices({}, {})", start, length)
                    }
                }
                ParamIndices::Disjoint(indices) => {
                    format!("ParamIndices({:?})", indices)
                }
            }
        }

        // Short form: the full list for up to 3 indices, else `[first, ..., last]`.
        fn __str__(&self) -> String {
            let indices = self.inner.as_vec();
            if indices.len() <= 3 {
                format!("{:?}", indices)
            } else {
                // `len() > 3` here, so `first`/`last` are always `Some`.
                format!(
                    "[{}, ..., {}]",
                    indices.first().unwrap(),
                    indices.last().unwrap()
                )
            }
        }

        fn __len__(&self) -> usize {
            self.inner.len()
        }

        // Iterates over a snapshot of the indices copied when `__iter__` is called.
        fn __iter__(slf: PyRef<'_, Self>) -> PyParamIndicesIterator {
            PyParamIndicesIterator {
                indices: slf.inner.as_vec(),
                index: 0,
            }
        }

        fn __contains__(&self, index: usize) -> bool {
            self.inner.contains(index)
        }
    }

    // Python iterator returned by `ParamIndices.__iter__`; walks an owned copy of the
    // indices, with `index` pointing at the next element to return.
    #[pyclass]
    struct PyParamIndicesIterator {
        indices: Vec<usize>,
        index: usize,
    }

    #[pymethods]
    impl PyParamIndicesIterator {
        fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
            slf
        }

        // Returning `None` ends the Python iteration.
        fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<usize> {
            if slf.index < slf.indices.len() {
                let result = slf.indices[slf.index];
                slf.index += 1;
                Some(result)
            } else {
                None
            }
        }
    }

    impl From<ParamIndices> for PyParamIndices {
        fn from(indices: ParamIndices) -> Self {
            PyParamIndices { inner: indices }
        }
    }

    impl From<PyParamIndices> for ParamIndices {
        fn from(py_indices: PyParamIndices) -> Self {
            py_indices.inner
        }
    }

    // Lets Rust functions exposed to Python return a bare `ParamIndices`; it is wrapped
    // in a new `PyParamIndices` object.
    impl<'py> IntoPyObject<'py> for ParamIndices {
        type Target = PyParamIndices;
        type Output = Bound<'py, Self::Target>;
        type Error = PyErr;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_indices = PyParamIndices::from(self);
            Bound::new(py, py_indices)
        }
    }

    // Lets Rust functions exposed to Python accept a `ParamIndices` argument; only
    // `PyParamIndices` objects are accepted (no implicit conversion from lists/ints).
    impl<'a, 'py> FromPyObject<'a, 'py> for ParamIndices {
        type Error = PyErr;

        fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
            let py_indices: PyParamIndices = ob.extract()?;
            Ok(py_indices.into())
        }
    }
}