1use std::{ collections::{BTreeSet, HashMap}, fmt::{Debug, Display}, hash::Hash, ops::{ Add, Index }, sync::OnceLock };
18
19use hashed_type_def::HashedTypeDef;
20use crate::types::{ u128slx, u32slx, f64slx, SlxInto, IntoSlx, };
23#[cfg(feature = "silx-types")] use silx_types::Float;
24
25#[cfg(feature = "serde")] use serde::{ Serialize as SerdeSerialize, Deserialize as SerdeDeserialize, };
26#[cfg(feature = "rkyv")] use rkyv::{ Archive, Serialize as RkyvSerialize, Deserialize as RkyvDeserialize, };
27use self::hidden::OrdMap;
28
29use super::SafeElement;
30
31pub const ASSIGNMENT_EPSILON: f64 = 1e-8;
32static ZERO_F64_SLX: OnceLock<f64slx> = OnceLock::new();
33static ONE_F64_SLX: OnceLock<f64slx> = OnceLock::new();
34
35pub fn zero_f64slx() -> &'static f64slx { ZERO_F64_SLX.get_or_init(|| 0.0.slx()) }
36pub fn one_f64slx() -> &'static f64slx { ONE_F64_SLX.get_or_init(|| 1.0.slx()) }
37
38#[derive(Clone,HashedTypeDef)]
39pub struct AssignmentBuilder<X> where X: Eq + Ord + Hash {
44 pub (crate) elements: hidden::OrdMap<X>,
45 pub (crate) lattice_hash: u128slx,
46 pub (crate) length_mid: u32slx,
47 pub (crate) length_max: u32slx,
48}
49
50#[derive(Clone,HashedTypeDef,)]
51#[cfg_attr(feature = "rkyv", derive(Archive,RkyvSerialize,RkyvDeserialize))]
52pub struct Assignment<X> where X: Eq + Hash {
58 pub elements: HashMap<X,f64slx>,
59 pub lattice_hash: u128slx,
60}
61
62impl <X> From<AssignmentBuilder<X>> for Assignment<X> where X: Eq + Ord + Hash {
63 fn from(value: AssignmentBuilder<X>) -> Self {
64 let AssignmentBuilder {
65 elements: OrdMap { elements, .. }, lattice_hash, ..
66 } = value;
67 Self { elements, lattice_hash, }
68 }
69}
70
#[cfg(feature = "serde")]
/// Serde support for [`Assignment`], going through a plain-data mirror type.
mod serding {
    use std::collections::BTreeMap;
    use super::{ Assignment as SerdingAssignment, SerdeSerialize, SerdeDeserialize, Hash, };
    use crate::types::{ SlxInto, IntoSlx, };

    /// Wire representation of an assignment: native `f64`/`u128` values, and a
    /// `BTreeMap` so that elements are emitted in sorted key order.
    #[derive(SerdeSerialize, SerdeDeserialize)]
    pub struct Assignment<X> where X: Eq + Ord, {
        elements: BTreeMap<X, f64>,
        lattice_hash: u128,
    }

    impl<'de, X> SerdeDeserialize<'de> for SerdingAssignment<X>
    where X: Eq + Ord + Hash + SerdeDeserialize<'de>, {
        /// Reads the wire representation and converts weights and hash back to slx types.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: serde::Deserializer<'de> {
            let Assignment { elements, lattice_hash } = Assignment::<X>::deserialize(deserializer)?;
            let lattice_hash = lattice_hash.slx();
            let elements = elements.into_iter().map(|(x, w)| (x, w.slx())).collect();
            Ok(Self { elements, lattice_hash, })
        }
    }

    impl<X> SerdeSerialize for SerdingAssignment<X> where X: Eq + Ord + Hash + SerdeSerialize, {
        /// Writes the wire representation (elements sorted by key).
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: serde::Serializer {
            // Borrow the keys rather than cloning them: `&X` serializes exactly like
            // `X` and orders like `X`, so the output is unchanged.
            let elements: BTreeMap<&X, f64> = self.elements.iter()
                .map(|(x, w)| (x, (*w).unslx()))
                .collect();
            let lattice_hash = self.lattice_hash.unslx();
            Assignment { elements, lattice_hash, }.serialize(serializer)
        }
    }
}
113
114impl<X,T> Add<(SafeElement<X>,T)> for AssignmentBuilder<X> where X: Eq + Ord + Hash + Clone, T: Into<f64slx> {
115 type Output = Self;
116
117 fn add(mut self, (x,w): (SafeElement<X>,T)) -> Self::Output {
118 self.push(x,w.into()).unwrap(); self
119 }
120}
121
122impl<X> Add<()> for AssignmentBuilder<X> where X: Eq + Ord + Hash + Clone {
123 type Output = Assignment<X>;
124
125 fn add(mut self, _: ()) -> Self::Output {
126 self.normalize().unwrap(); self.into()
127 }
128}
129
130impl<X> Debug for Assignment<X> where X: Eq + Hash + Debug, {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 let value = self.elements.iter()
133 .fold(" ".to_string(),|acc,(u,w)| format!("{acc}{u:?} -> {w}, "));
134 f.debug_struct("Assignment").field("elements", &value).field("lattice_hash", &self.lattice_hash).finish()
135 }
136}
137
138impl<X> Display for Assignment<X> where X: Eq + Hash + Display, {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 let value = self.elements.iter()
141 .fold("[ ".to_string(),|acc,(u,w)| format!("{acc}{u} -> {w:.4}, "));
142 f.write_str(&value)?;
143 f.write_str("]")
144 }
145}
146
147impl<X> AssignmentBuilder<X> where X: Clone + Eq + Ord + Hash, {
148 pub unsafe fn unsafe_push(&mut self, element: X, weight:f64slx) -> Result<bool, String> {
157 let Self { elements, .. } = self;
158 let w_native = weight.unslx();
159 if w_native.is_finite() && !w_native.is_sign_negative() {
160 if w_native > ASSIGNMENT_EPSILON { Ok({ elements.push(element,weight); true }) } else {Ok(false)}
161 } else { Err("non finite or negative weights are forbidden".to_string()) }
162 }
163
164 pub fn push(&mut self, safe_element: SafeElement<X>, weight: f64slx) -> Result<bool, String> {
172 let Self { elements, lattice_hash, .. } = self;
173 let w_native = weight.unslx();
174 let assign_lattice_hash = *lattice_hash;
175 let SafeElement { code: element, lattice_hash } = safe_element;
176 if lattice_hash == assign_lattice_hash {
177 if w_native.is_finite() && !w_native.is_sign_negative() {
178 if w_native > ASSIGNMENT_EPSILON { Ok({ elements.push(element,weight); true }) }
179 else {Ok(false)}
180 } else { Err("non finite or negative weights are forbidden".to_string()) }
181 } else { Err("lattice hash mismatch".to_string()) }
182 }
183
184 pub fn remove(&mut self, safe_element: &SafeElement<X>) -> Result<Option<(SafeElement<X>,f64slx)>, String> {
191 let Self { elements, lattice_hash: assign_lattice_hash, .. } = self;
192 let SafeElement { code: element, lattice_hash } = safe_element;
193 if lattice_hash == assign_lattice_hash { match elements.remove(element) {
194 Some((element,w)) => Ok(Some((SafeElement { code: element, lattice_hash: *lattice_hash },w))),
195 None => Ok(None),
196 } } else { Err("lattice hash mismatch".to_string()) }
197 }
198
199 pub fn prune<F>(&mut self, pruner: F) where F: Fn(X,X) -> X {
205 let Self { elements, length_mid, length_max, .. } = self;
206 let length_mid = (*length_mid).unslx() as usize;
207 let length_max = (*length_max).unslx() as usize;
208 if elements.len() > length_max {
209 let length_mid = length_mid.max(1); while elements.len() > length_mid {
211 let (x,v) = elements.pop_first().unwrap();
212 let (y,w) = elements.pop_first().unwrap();
213 elements.push(pruner(x,y), v + w);
214 }
215 }
216 }
217
218 pub fn scale(&mut self, scaler: f64slx) -> Result<(), String> {
222 self.self_map(|w| w * scaler)
223 }
224
225 pub fn neg_shift(&mut self, neg_shift: f64slx) -> Result<(), String> {
230 self.self_map(|w| w - neg_shift)
231 }
232
233 pub fn cumul_weight(&self) -> Result<f64slx, String> {
236 let mut cumul = 0.0;
237 let elements = &self.elements.elements;
238 for (_,rw) in elements.iter() {
239 let w = (*rw).unslx();
240 if w.is_finite() && w >= 0.0 {
241 cumul += w;
242 } else { return Err("weight is not finite or not positive".to_string()); }
243 }
244 Ok(cumul.slx())
245 }
246
247 pub fn normalize(&mut self) -> Result<(), String> {
250 let norm = self.cumul_weight()?;
251 if &norm == zero_f64slx() {
252 Err("Cumulative weight is zero, cannot be normalized".to_string())
253 } else { self.scale(norm.recip()) }
254 }
255
256 pub fn map<F>(self, mut f: F) -> Result<Self,String> where F: FnMut(f64slx) -> f64slx {
261 let Self { lattice_hash, length_mid, length_max, elements, } = self;
262 let mapped = elements.ord_elements.into_iter()
263 .map(|hidden::OrdData((x,w))| (x,f(w))).collect::<Vec<_>>();
264 for (_,w) in &mapped {
265 let w = (*w).unslx();
266 if !w.is_finite() || w.is_sign_negative() {
267 return Err(format!("mapped weight {w} is not finite or is sign negative"));
268 }
269 }
270 let mapped_filtered = mapped.into_iter().filter(|(_,w)| (*w).unslx() > ASSIGNMENT_EPSILON).collect::<Vec<_>>();
271 let elements = mapped_filtered.iter().cloned().collect::<HashMap<_,_>>();
272 let ord_elements = mapped_filtered.into_iter()
273 .map(|xw| hidden::OrdData(xw)).collect::<BTreeSet<_>>();
274 let elements = OrdMap { elements, ord_elements, };
275 Ok(Self { lattice_hash, length_mid, length_max, elements })
276 }
277
278 pub fn self_map<F>(&mut self, f: F) -> Result<(),String> where F: FnMut(f64slx) -> f64slx {
283 let mut assign_tmp = AssignmentBuilder {
284 elements: OrdMap::new(), lattice_hash: 0u128.slx(), length_mid: 0u32.slx(), length_max: 0u32.slx()
285 };
286 std::mem::swap(self, &mut assign_tmp);
287 *self = assign_tmp.map(f)?;
288 Ok(())
289 }
290}
291
292
293impl<X> Index<&SafeElement<X>> for Assignment<X> where X: Eq + Ord + Hash, {
294 type Output = f64slx;
295
296 fn index(&self, SafeElement { code: element, lattice_hash }: &SafeElement<X>) -> &Self::Output {
297 if lattice_hash != &self.lattice_hash { panic!("mismatching lattice hash"); }
298 self.elements.index(element)
299 }
300}
301
302impl<X> IntoIterator for Assignment<X> where X: Eq + Ord + Hash, {
303 type Item = (SafeElement<X>,f64slx);
304
305 type IntoIter = std::vec::IntoIter<Self::Item>;
306
307 fn into_iter(self) -> Self::IntoIter {
308 let Self { elements, lattice_hash, .. } = self;
309 elements.into_iter().map(move |(rce,w)| (SafeElement {
310 code: rce, lattice_hash
311 },w)).collect::<Vec<_>>().into_iter()
312 }
313}
314
315impl<'a, X> IntoIterator for &'a Assignment<X> where X: Eq + Ord + Hash, {
316 type Item = (SafeElement<&'a X>,f64slx);
317
318 type IntoIter = std::vec::IntoIter<Self::Item>;
319
320 fn into_iter(self) -> Self::IntoIter {
321 let Assignment { elements, lattice_hash, .. } = self;
322 elements.iter().map(move |(rce,w)| (SafeElement {
323 code: rce, lattice_hash: *lattice_hash
324 },*w)).collect::<Vec<_>>().into_iter()
325 }
326}
327
328pub (crate) mod hidden{
329 use std::ops::Index;
330
331 use super::*;
332
333 #[derive(PartialEq,HashedTypeDef,)]
334 #[cfg_attr(feature = "rkyv", derive(Archive,RkyvSerialize,RkyvDeserialize))]
335 #[repr(transparent)]
336 pub struct OrdData<X>(pub (X, f64slx,));
338
339 impl<X> Clone for OrdData<X> where X: Clone {
340 fn clone(&self) -> Self {
341 let Self((x,w)) = self;
342 Self((x.clone(),*w))
343 }
344 }
345
346 impl<X> PartialOrd for OrdData<X> where X: PartialOrd {
347 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
348 match self.0.1.partial_cmp(&other.0.1) {
349 Some(core::cmp::Ordering::Equal) => {}, ord => return ord,
350 } self.0.0.partial_cmp(&other.0.0)
351 }
352 }
353
354 impl<X> Eq for OrdData<X> where X: PartialEq { }
355
356 impl<X> Ord for OrdData<X> where X: Ord + Eq {
357 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
358 match self.0.1.partial_cmp(&other.0.1) {
359 Some(core::cmp::Ordering::Equal) => {}, Some(ord) => return ord,
360 None => panic!("Comparison of {} and {} failed", self.0.1, other.0.1),
361 } self.0.0.cmp(&other.0.0)
362 }
363 }
364
365 #[derive(HashedTypeDef,)]
366 #[cfg_attr(feature = "rkyv", derive(Archive,RkyvSerialize,RkyvDeserialize))]
367 pub struct OrdMap<X> where X: Eq + Ord + Hash, {
369 pub (crate) elements: HashMap<X,f64slx>,
370 pub (crate) ord_elements: BTreeSet<OrdData<X>>,
371 }
372
373 impl<X> Clone for OrdMap<X> where X: Eq + Ord + Hash + Clone, {
374 fn clone(&self) -> Self {
375 let Self { ord_elements, .. } = self;
376 let weighted = ord_elements.iter().cloned().collect::<Vec<_>>();
378 let elements = weighted.iter().map(|OrdData((x,w))| (x.clone(),*w)).collect();
380 let ord_elements = weighted.into_iter().collect();
381 Self { elements, ord_elements, }
382 }
383 }
384
385 impl<X> OrdMap<X> where X: Clone + Eq + Ord + Hash, {
386 pub fn new() -> Self {
388 Self { elements: HashMap::new(), ord_elements: BTreeSet::new(), }
389 }
390 pub fn with_capacity(capacity: usize) -> Self {
392 Self { elements: HashMap::with_capacity(capacity), ord_elements: BTreeSet::new(), }
393 }
394 pub fn pop_first(&mut self) -> Option<(X,f64slx)> {
396 match self.ord_elements.pop_first() {
397 Some(OrdData((xx,_))) => match self.elements.remove_entry(&xx) {
398 None => panic!("unexpected missing entry"),
399 some => some.map(|(x,w)| {
400 drop(xx); (x,w)
401 }),
402 }, None => None,
403 }
404 }
405 pub fn remove(&mut self, x: &X) -> Option<(X,f64slx)> {
407 match self.elements.remove_entry(x) {
408 Some(xw) => {
409 match self.ord_elements.take(unsafe{ std::mem::transmute(&xw)}) {
410 None => panic!("unexpected missing element"),
411 Some(OrdData((xx,w))) => {
412 drop(xw);
413 Some((xx, w))
414 },
415 }
416 },
417 None => None,
418 }
419 }
420 pub fn push(&mut self, x: X, w: f64slx) {
422 let (x,w) = match self.remove(&x) { Some((_,v)) => (x,v+w), None => (x,w), };
423 self.ord_elements.insert(OrdData((x.clone(),w))); self.elements.insert(x, w);
425 }
426 pub fn len(&self) -> usize {
428 let len1 = self.elements.len();
429 let len2 = self.ord_elements.len();
430 if len1 != len2 { panic!("unexpected error: mismatching lens") }
431 len1
432 }
433 }
434
435 impl<X> Index<&X> for OrdMap<X> where X: Eq + Ord + Hash, {
436 type Output = f64slx;
437
438 fn index(&self, index: &X) -> &Self::Output {
439 match self.elements.get(index) { Some(x) => x, None => zero_f64slx(), }
440 }
441 }
442}
443
444pub fn exp_hidden() {
446 let mut om = hidden::OrdMap::<&'static str>::new();
447 om.push("A",0.125.slx());
448 om.push("B",0.25.slx());
449 om.push("C",0.25.slx());
450 om.push("B",0.25.slx());
451 om.push("C",0.125.slx());
452 println!("om.len() -> {}", om.len());
453 println!("om.remove(&\"B\") -> {:?}", om.remove(&"B"));
454 println!("om.len() -> {}", om.len());
455 println!("om.pop_first() -> {:?}", om.pop_first());
456 println!("om.len() -> {}", om.len());
457 println!("om.remove(&\"B\") -> {:?}", om.remove(&"B"));
458 println!("om.len() -> {}", om.len());
459 println!("om.pop_first() -> {:?}", om.pop_first());
460 println!("om.len() -> {}", om.len());
461 println!("om.pop_first() -> {:?}", om.pop_first());
462 println!("om.len() -> {}", om.len());
463}