1use std::marker::PhantomData;
2use std::ops::Index;
3
4use oximo_expr::Expr;
5use rustc_hash::FxHashMap;
6
7use crate::set::{Axis, FromIndexKey, IndexKey};
8
/// Backing store for an indexed variable.
///
/// * `Dense` keeps the expressions in a flat `data` vector alongside a parallel
///   `keys` vector (iteration pairs `keys[i]` with `data[i]`) and one [`Axis`]
///   per dimension. Lookups turn a key into a flat offset in which the last
///   axis varies fastest.
/// * `Sparse` keeps the expressions in a hash map keyed directly by
///   [`IndexKey`]; it carries no axis information.
#[derive(Clone)]
pub(crate) enum Storage<'a> {
    /// Grid-shaped storage: flat expressions, their keys in the same order,
    /// and the axes describing each dimension.
    Dense { data: Vec<Expr<'a>>, keys: Vec<IndexKey>, axes: Box<[Axis]> },
    /// Hash-map storage for key sets that are not addressed through axes.
    Sparse(FxHashMap<IndexKey, Expr<'a>>),
}
19
/// A collection of expressions addressed by index keys.
///
/// `K` is a compile-time tag for the typed key produced by the typed
/// iteration API; no value of type `K` is ever stored. It defaults to the
/// untyped [`IndexKey`].
pub struct IndexedVar<'a, K = IndexKey> {
    /// Dense or sparse container actually holding the expressions.
    pub(crate) storage: Storage<'a>,
    /// Marker tying the value to `K` without owning a `K`; `fn() -> K` keeps
    /// the marker free of any ownership or drop-check implications for `K`.
    pub(crate) _k: PhantomData<fn() -> K>,
}
34
35impl<'a, K> Clone for IndexedVar<'a, K> {
36 fn clone(&self) -> Self {
37 Self { storage: self.storage.clone(), _k: PhantomData }
38 }
39}
40
41impl<'a, K> std::fmt::Debug for IndexedVar<'a, K> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("IndexedVar")
44 .field("len", &self.len())
45 .field("dense", &self.is_dense())
46 .finish()
47 }
48}
49
50impl<'a, K> IndexedVar<'a, K> {
51 pub fn len(&self) -> usize {
52 match &self.storage {
53 Storage::Dense { data, .. } => data.len(),
54 Storage::Sparse(m) => m.len(),
55 }
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.len() == 0
60 }
61
62 pub fn is_dense(&self) -> bool {
65 matches!(self.storage, Storage::Dense { .. })
66 }
67
68 pub fn shape(&self) -> Option<Box<[usize]>> {
70 match &self.storage {
71 Storage::Dense { axes, .. } => Some(axes.iter().map(|a| a.len).collect()),
72 Storage::Sparse(_) => None,
73 }
74 }
75
76 pub fn iter(&self) -> impl Iterator<Item = (&IndexKey, &Expr<'a>)> + '_ {
77 let it: Box<dyn Iterator<Item = (&IndexKey, &Expr<'a>)>> = match &self.storage {
78 Storage::Dense { data, keys, .. } => Box::new(keys.iter().zip(data.iter())),
79 Storage::Sparse(m) => Box::new(m.iter()),
80 };
81 it
82 }
83
84 pub fn get<Q: Into<IndexKey>>(&self, key: Q) -> Option<Expr<'a>> {
85 match &self.storage {
86 Storage::Sparse(m) => m.get(&key.into()).copied(),
87 Storage::Dense { data, axes, .. } => {
88 grid_offset(axes, &key.into()).map(|off| data[off])
89 }
90 }
91 }
92
93 pub fn at<const N: usize>(&self, coords: [usize; N]) -> Expr<'a> {
100 *self.get_ref(&coords).expect("IndexedVar: coordinates not present")
101 }
102
103 pub fn get_at<const N: usize>(&self, coords: [usize; N]) -> Option<Expr<'a>> {
105 self.get_ref(&coords).copied()
106 }
107
108 fn get_ref(&self, coords: &[usize]) -> Option<&Expr<'a>> {
109 match &self.storage {
110 Storage::Dense { data, axes, .. } => {
111 grid_offset_coords(axes, coords).map(|off| &data[off])
112 }
113 Storage::Sparse(m) => m.get(&coords_to_key(coords)),
114 }
115 }
116}
117
118impl<'a, K: FromIndexKey> IndexedVar<'a, K> {
119 pub fn keys(&self) -> impl Iterator<Item = (K, Expr<'a>)> + '_ {
121 self.iter().map(|(k, e)| (K::from_index_key(k), *e))
122 }
123}
124
125impl<'a, K, Q: Into<IndexKey>> Index<Q> for IndexedVar<'a, K> {
126 type Output = Expr<'a>;
127 fn index(&self, key: Q) -> &Self::Output {
128 match &self.storage {
129 Storage::Sparse(m) => m.get(&key.into()).expect("IndexedVar: key not present"),
130 Storage::Dense { data, axes, .. } => {
131 let off = grid_offset(axes, &key.into()).expect("IndexedVar: key not present");
132 &data[off]
133 }
134 }
135 }
136}
137
138impl<'a, K> Index<&IndexKey> for IndexedVar<'a, K> {
139 type Output = Expr<'a>;
140 fn index(&self, key: &IndexKey) -> &Self::Output {
141 match &self.storage {
142 Storage::Sparse(m) => m.get(key).expect("IndexedVar: key not present"),
143 Storage::Dense { data, axes, .. } => {
144 let off = grid_offset(axes, key).expect("IndexedVar: key not present");
145 &data[off]
146 }
147 }
148 }
149}
150
151impl<'a, K, const N: usize> Index<[usize; N]> for IndexedVar<'a, K> {
152 type Output = Expr<'a>;
153 fn index(&self, coords: [usize; N]) -> &Self::Output {
154 self.get_ref(&coords).expect("IndexedVar: coordinates not present")
155 }
156}
157
158fn axis_index(a: &Axis, v: i64) -> Option<usize> {
160 let d = v.checked_sub(a.start)?;
161 let u = usize::try_from(d).ok()?;
162 (u < a.len).then_some(u)
163}
164
165pub(crate) fn grid_offset(axes: &[Axis], key: &IndexKey) -> Option<usize> {
168 match (axes, key) {
169 ([a], IndexKey::Int(v)) => axis_index(a, *v),
170 (axes, IndexKey::Tuple(parts)) if parts.len() == axes.len() => {
171 let mut off = 0usize;
172 for (a, p) in axes.iter().zip(parts.iter()) {
173 off = off.checked_mul(a.len)?.checked_add(axis_index(a, p.as_i64()?)?)?;
174 }
175 Some(off)
176 }
177 _ => None,
178 }
179}
180
181fn grid_offset_coords(axes: &[Axis], coords: &[usize]) -> Option<usize> {
183 if coords.len() != axes.len() {
184 return None;
185 }
186 let mut off = 0usize;
187 for (a, &c) in axes.iter().zip(coords) {
188 off = off * a.len + axis_index(a, i64::try_from(c).ok()?)?;
189 }
190 Some(off)
191}
192
193fn coords_to_key(coords: &[usize]) -> IndexKey {
196 if let [single] = coords {
197 IndexKey::from(*single)
198 } else {
199 IndexKey::Tuple(coords.iter().map(|&c| IndexKey::from(c)).collect())
200 }
201}