furtif_core/structs/assignment_tools/
assignment.rs

1// This program is free software: you can redistribute it and/or modify
2// it under the terms of the Lesser GNU General Public License as published
3// by the Free Software Foundation, either version 3 of the License, or
4// (at your option) any later version.
5
6// This program is distributed in the hope that it will be useful,
7// but WITHOUT ANY WARRANTY; without even the implied warranty of
8// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
9// Lesser GNU General Public License for more details.
10
11// You should have received a copy of the Lesser GNU General Public License
12// along with this program.  If not, see <https://www.gnu.org/licenses/>.
13
14// Copyright 2024 Frederic Dambreville, Jean Dezert Developers.
15
16
17use std::{ collections::{BTreeSet, HashMap}, fmt::{Debug, Display}, hash::Hash, ops::{ Add, Index }, sync::OnceLock };
18
19use hashed_type_def::HashedTypeDef;
20// #[cfg(not(feature = "silx-types"))] use crate::fake_slx::{ u128slx, u32slx, f64slx, FakeSlx, };
21// #[cfg(feature = "silx-types")] use silx_types::{ u128slx, u32slx, f64slx, SlxInto, IntoSlx, Float, };
22use crate::types::{ u128slx, u32slx, f64slx, SlxInto, IntoSlx, };
23#[cfg(feature = "silx-types")] use silx_types::Float;
24
25#[cfg(feature = "serde")] use serde::{ Serialize as SerdeSerialize, Deserialize as SerdeDeserialize, };
26#[cfg(feature = "rkyv")] use rkyv::{ Archive, Serialize as RkyvSerialize, Deserialize as RkyvDeserialize, };
27use self::hidden::OrdMap;
28
29use super::SafeElement;
30
/// Weight threshold of the assignments
/// * A weight is kept by the builder only if it is strictly greater than this threshold
pub const ASSIGNMENT_EPSILON: f64 = 1.0e-8;
32static ZERO_F64_SLX: OnceLock<f64slx> = OnceLock::new();
33static ONE_F64_SLX: OnceLock<f64slx> = OnceLock::new();
34
35pub fn zero_f64slx() -> &'static f64slx { ZERO_F64_SLX.get_or_init(|| 0.0.slx()) }
36pub fn one_f64slx() -> &'static f64slx { ONE_F64_SLX.get_or_init(|| 1.0.slx()) }
37
38#[derive(Clone,HashedTypeDef)]
39/// Intermediate structure for building assignments
40/// * In particular, this structrure will handle mass discounting
41/// * There is no constructor for `AssignmentBuilder`: methods `init_assignment` or `init_assignment_with_capacity` of trait `Lattice` are used
42/// * `X` : type of lattice element encoding
43pub struct AssignmentBuilder<X> where X: Eq + Ord + Hash {
44    pub (crate) elements: hidden::OrdMap<X>,
45    pub (crate) lattice_hash: u128slx,
46    pub (crate) length_mid: u32slx,
47    pub (crate) length_max: u32slx,
48}
49
50#[derive(Clone,HashedTypeDef,)]
51#[cfg_attr(feature = "rkyv", derive(Archive,RkyvSerialize,RkyvDeserialize))]
52/// Mass assignment: should be build from AssignmentBuilder
53/// * Assignment contains a lattice hash and a sequence of weighted encoded element from this lattice
54/// * Assignments are used in order to encode basic belief function and other belief functions 
55/// * There is no constructor for `Assignment`: use method `From::from(...)` to convert from a builder `AssignmentBuilder`
56/// * `X` : type of lattice element encoding
57pub struct Assignment<X> where X: Eq + Hash {
58    pub elements: HashMap<X,f64slx>,
59    pub lattice_hash: u128slx,
60}
61
62impl <X> From<AssignmentBuilder<X>> for Assignment<X> where X: Eq + Ord + Hash {
63    fn from(value: AssignmentBuilder<X>) -> Self {
64        let AssignmentBuilder { 
65            elements: OrdMap { elements, .. }, lattice_hash, .. 
66        } = value;
67        Self { elements, lattice_hash, }
68    }
69}
70
// Serde support: `Assignment` is serialized and deserialized through a plain mirror structure
#[cfg(feature = "serde")]
mod serding {
    use std::collections::BTreeMap;
    use super::{ Assignment as SerdingAssignment, SerdeSerialize, SerdeDeserialize, Hash, };
    // #[cfg(feature = "silx-types")] use super::{ IntoSlx, SlxInto, };
    // #[cfg(not(feature = "silx-types"))] use crate::fake_slx::FakeSlx;
    use crate::types::{ SlxInto, IntoSlx, };

    /// Mirror of the assignment built on native types, used as serde intermediate
    /// * elements are stored within a `BTreeMap`, so that they are serialized following the order of `X`
    #[derive(SerdeSerialize,SerdeDeserialize)]
    pub struct Assignment<X> where X: Eq + Ord, {
        elements: BTreeMap<X,f64>,
        lattice_hash: u128,
    }

    impl<'de, X> SerdeDeserialize<'de> for SerdingAssignment<X>
                                        where X: Clone + Eq + Ord + Hash + SerdeDeserialize<'de>, {
        /// Deserialize the mirror structure, then convert weights and lattice hash to their slx counterparts
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where D: serde::Deserializer<'de> {
            let mirror = Assignment::<X>::deserialize(deserializer)?;
            let lattice_hash = mirror.lattice_hash.slx();
            let elements = mirror.elements.into_iter()
                .map(|(code,weight)| (code,weight.slx()))
                .collect();
            Ok(Self { elements, lattice_hash, })
        }
    }

    impl<X> SerdeSerialize for SerdingAssignment<X>  where X: Clone + Eq + Ord + Hash + SerdeSerialize, {
        /// Copy the assignment into the mirror structure (encoded elements are cloned), then serialize the mirror
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where S: serde::Serializer {
            let mut elements = BTreeMap::new();
            for (code,weight) in self.elements.iter() {
                elements.insert(code.clone(), (*weight).unslx());
            }
            let lattice_hash = (self.lattice_hash).unslx();
            let mirror = Assignment { elements, lattice_hash, };
            mirror.serialize(serializer)
        }
    }
}
113
114impl<X,T> Add<(SafeElement<X>,T)> for AssignmentBuilder<X> where X: Eq + Ord + Hash + Clone, T: Into<f64slx> {
115    type Output = Self;
116
117    fn add(mut self, (x,w): (SafeElement<X>,T)) -> Self::Output {
118        self.push(x,w.into()).unwrap(); self
119    }
120}
121
122impl<X> Add<()> for AssignmentBuilder<X> where X: Eq + Ord + Hash + Clone {
123    type Output = Assignment<X>;
124
125    fn add(mut self, _: ()) -> Self::Output {
126        self.normalize().unwrap(); self.into()
127    }
128}
129
130impl<X> Debug for Assignment<X> where X: Eq + Hash + Debug, {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        let value = self.elements.iter()
133            .fold(" ".to_string(),|acc,(u,w)| format!("{acc}{u:?} -> {w}, "));
134        f.debug_struct("Assignment").field("elements", &value).field("lattice_hash", &self.lattice_hash).finish()
135    }
136}
137
138impl<X> Display for Assignment<X> where X: Eq + Hash + Display, {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        let value = self.elements.iter()
141            .fold("[ ".to_string(),|acc,(u,w)| format!("{acc}{u} -> {w:.4}, "));
142        f.write_str(&value)?;
143        f.write_str("]")
144    }
145}
146
147impl<X> AssignmentBuilder<X> where X: Clone + Eq + Ord + Hash, {
148    /// Unsafe push weighted element within assignment; weights smaller than `ASSIGNMENT_EPSILON` are discarded (result in `Ok(false)`)
149    /// * method is unsafe since there is no consistency check that the encoded element comes from the lattice of the builder
150    /// * `element: X` : encoded element
151    /// * `weight: f64slx` : weight of the element
152    /// * Output: 
153    ///   * `true` if weighted element is inserted
154    ///   * `false` if the weighted element is discarded
155    ///   * error if the weight is non finite or negative
156    pub unsafe fn unsafe_push(&mut self, element: X, weight:f64slx) -> Result<bool, String> {
157        let Self { elements, .. } = self;
158        let w_native = weight.unslx();
159        if w_native.is_finite() && !w_native.is_sign_negative() {
160            if w_native > ASSIGNMENT_EPSILON { Ok({ elements.push(element,weight); true }) } else {Ok(false)}
161        } else { Err("non finite or negative weights are forbidden".to_string()) }
162    }
163
164    /// Push weighted element within assignment; weights smaller than `ASSIGNMENT_EPSILON` are discarded (result in `Ok(false)`)
165    /// * `safe_element: SafeElement<X>` : element with safe encoding
166    /// * `weight: f64slx` : weight of the element
167    /// * Output: 
168    ///   * `true` if weighted element is inserted
169    ///   * `false` if the weighted element is discarded
170    ///   * error if the weight is non finite or negative
171    pub fn push(&mut self, safe_element: SafeElement<X>, weight: f64slx) -> Result<bool, String> {
172        let Self { elements, lattice_hash, .. } = self;
173        let w_native = weight.unslx();
174        let assign_lattice_hash = *lattice_hash;
175        let SafeElement { code: element, lattice_hash } = safe_element;
176        if lattice_hash == assign_lattice_hash {
177            if w_native.is_finite() && !w_native.is_sign_negative() {
178                if w_native > ASSIGNMENT_EPSILON { Ok({ elements.push(element,weight); true }) } 
179                else {Ok(false)}
180            } else { Err("non finite or negative weights are forbidden".to_string()) }
181        } else { Err("lattice hash mismatch".to_string()) } 
182    }
183
184    /// Remove element from assignment
185    /// * `safe_element: &SafeElement<X>` : reference to lement to be removed
186    /// * Output: 
187    ///   * Some safe element `Some(se)` if element was found
188    ///   * `None` if element was not found
189    ///   * error in case of lattice hash mismatch
190    pub fn remove(&mut self, safe_element: &SafeElement<X>) -> Result<Option<(SafeElement<X>,f64slx)>, String> {
191        let Self { elements, lattice_hash: assign_lattice_hash, .. } = self;
192        let SafeElement { code: element, lattice_hash } = safe_element;
193        if lattice_hash == assign_lattice_hash { match elements.remove(element) {
194            Some((element,w)) => Ok(Some((SafeElement { code: element, lattice_hash: *lattice_hash },w))),
195            None => Ok(None),
196        } } else { Err("lattice hash mismatch".to_string()) } 
197    }
198
199    /// Prune the assignment in order to reduce its size within acceptable range
200    /// * `pruner: F` : prunning function
201    ///   * two weighted encoded elements `(x,wx)` and `(y,wy)` to be prunned will be replaced by (pruner(x,y),wx + wy)
202    /// * `F` : type of pruner 
203    /// * Output: nothing
204    pub fn prune<F>(&mut self, pruner: F) where F: Fn(X,X) -> X {
205        let Self { elements, length_mid, length_max, .. } = self;
206        let length_mid = (*length_mid).unslx() as usize;
207        let length_max = (*length_max).unslx() as usize;
208        if elements.len() > length_max {
209            let length_mid = length_mid.max(1); // impossible to prune over 0 or 1 element
210            while elements.len() > length_mid {
211                let (x,v) = elements.pop_first().unwrap();
212                let (y,w) = elements.pop_first().unwrap();
213                elements.push(pruner(x,y), v + w);
214            }    
215        }
216    }
217
218    /// Scale the assignment weights
219    /// * `scaler: f64slx` : scale multiplier
220    /// * Output: nothing or error
221    pub fn scale(&mut self, scaler: f64slx) -> Result<(), String> {
222        self.self_map(|w| w * scaler)
223    }
224
225    /// Shift the assignment negatively
226    /// * Weighted element `(x,w)` is replaced by `(x,w - neg_shift)`
227    /// * `neg_shift: f64slx` : negative shift
228    /// * Output: nothing or error
229    pub fn neg_shift(&mut self, neg_shift: f64slx) -> Result<(), String> {
230        self.self_map(|w| w - neg_shift)
231    }
232
233    /// Compute the cumulative weight of the assignment
234    /// * Output: the cumulative weight or an error if some weights are non finite or negative
235    pub fn cumul_weight(&self) -> Result<f64slx, String> {
236        let mut cumul = 0.0;
237        let elements = &self.elements.elements;
238        for (_,rw) in elements.iter() {
239            let w = (*rw).unslx();
240            if w.is_finite() && w >= 0.0 {
241                cumul += w;
242            } else { return Err("weight is not finite or not positive".to_string()); }
243        }
244        Ok(cumul.slx())
245    }
246
247    /// Normalize the assignment
248    /// * Output: nothing or error
249    pub fn normalize(&mut self) -> Result<(), String> {
250        let norm = self.cumul_weight()?;
251        if &norm == zero_f64slx() {
252            Err("Cumulative weight is zero, cannot be normalized".to_string())
253        } else { self.scale(norm.recip()) }
254    }
255
256    /// Map the assignment with a closure
257    /// * `f: F` : a closure
258    /// * `F` : type of the closure
259    /// * Output: mapped assignment or an error if some mapped weights are non finite or negative
260    pub fn map<F>(self, mut f: F) -> Result<Self,String> where F: FnMut(f64slx) -> f64slx {
261        let Self { lattice_hash, length_mid, length_max, elements, } = self;
262        let mapped = elements.ord_elements.into_iter()
263            .map(|hidden::OrdData((x,w))| (x,f(w))).collect::<Vec<_>>();
264        for (_,w) in &mapped { 
265            let w = (*w).unslx();
266            if !w.is_finite() || w.is_sign_negative() {
267                return Err(format!("mapped weight {w} is not finite or is sign negative"));
268            } 
269        }
270        let mapped_filtered = mapped.into_iter().filter(|(_,w)| (*w).unslx() > ASSIGNMENT_EPSILON).collect::<Vec<_>>();
271        let elements = mapped_filtered.iter().cloned().collect::<HashMap<_,_>>();
272        let ord_elements = mapped_filtered.into_iter()
273                                    .map(|xw| hidden::OrdData(xw)).collect::<BTreeSet<_>>();
274        let elements = OrdMap { elements, ord_elements, };
275        Ok(Self { lattice_hash, length_mid, length_max, elements })
276    }
277
278    /// Self-map the assignment with a closure
279    /// * `f: F` : a closure
280    /// * `F` : type of the closure
281    /// * Output: nothing or an error if some mapped weights are non finite or negative
282    pub fn self_map<F>(&mut self, f: F) -> Result<(),String> where F: FnMut(f64slx) -> f64slx {
283        let mut assign_tmp = AssignmentBuilder { 
284            elements: OrdMap::new(), lattice_hash: 0u128.slx(), length_mid: 0u32.slx(), length_max: 0u32.slx() 
285        };
286        std::mem::swap(self, &mut assign_tmp);
287        *self = assign_tmp.map(f)?;
288        Ok(())
289    }
290}
291
292
293impl<X> Index<&SafeElement<X>> for Assignment<X> where X: Eq + Ord + Hash, {
294    type Output = f64slx;
295
296    fn index(&self, SafeElement { code: element, lattice_hash }: &SafeElement<X>) -> &Self::Output {
297        if lattice_hash != &self.lattice_hash { panic!("mismatching lattice hash"); }
298        self.elements.index(element)
299    }
300}
301
302impl<X> IntoIterator for Assignment<X> where X: Eq + Ord + Hash, {
303    type Item = (SafeElement<X>,f64slx);
304
305    type IntoIter = std::vec::IntoIter<Self::Item>;
306
307    fn into_iter(self) -> Self::IntoIter {
308        let Self { elements, lattice_hash, .. } = self;
309        elements.into_iter().map(move |(rce,w)| (SafeElement { 
310            code: rce, lattice_hash 
311        },w)).collect::<Vec<_>>().into_iter()   
312    }
313}
314
315impl<'a, X> IntoIterator for &'a Assignment<X> where X: Eq + Ord + Hash, {
316    type Item = (SafeElement<&'a X>,f64slx);
317
318    type IntoIter = std::vec::IntoIter<Self::Item>;
319
320    fn into_iter(self) -> Self::IntoIter {
321        let Assignment { elements,  lattice_hash, .. } = self;
322        elements.iter().map(move |(rce,w)| (SafeElement { 
323            code: rce, lattice_hash: *lattice_hash
324        },*w)).collect::<Vec<_>>().into_iter()   
325    }
326}
327
328pub (crate) mod hidden{
329    use std::ops::Index;
330
331    use super::*;
332
333    #[derive(PartialEq,HashedTypeDef,)]
334    #[cfg_attr(feature = "rkyv", derive(Archive,RkyvSerialize,RkyvDeserialize))]
335    #[repr(transparent)]
336    /// a structure for internal use
337    pub struct OrdData<X>(pub (X, f64slx,));
338
339    impl<X> Clone for OrdData<X> where X: Clone {
340        fn clone(&self) -> Self {
341            let Self((x,w)) = self;
342            Self((x.clone(),*w))
343        }
344    }
345
346    impl<X> PartialOrd for OrdData<X> where X: PartialOrd {
347        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
348            match self.0.1.partial_cmp(&other.0.1) {
349                Some(core::cmp::Ordering::Equal) => {}, ord => return ord,
350            } self.0.0.partial_cmp(&other.0.0)
351        }
352    }
353
354    impl<X> Eq for OrdData<X> where X: PartialEq { }
355
356    impl<X> Ord for OrdData<X> where X: Ord + Eq {
357        fn cmp(&self, other: &Self) -> std::cmp::Ordering { 
358            match self.0.1.partial_cmp(&other.0.1) {
359                Some(core::cmp::Ordering::Equal) => {}, Some(ord) => return ord,
360                None => panic!("Comparison of {} and {} failed", self.0.1, other.0.1),
361            } self.0.0.cmp(&other.0.0)
362        }
363    }
364
365    #[derive(HashedTypeDef,)]
366    #[cfg_attr(feature = "rkyv", derive(Archive,RkyvSerialize,RkyvDeserialize))]
367    /// a map structure for internal use; element are sorted with weight
368    pub struct OrdMap<X> where X: Eq + Ord + Hash, {
369        pub (crate) elements: HashMap<X,f64slx>,
370        pub (crate) ord_elements: BTreeSet<OrdData<X>>,    
371    }
372
373    impl<X> Clone for OrdMap<X> where X: Eq + Ord + Hash + Clone, {
374        fn clone(&self) -> Self {
375            let Self { ord_elements, .. } = self;
376            // build a deep clone of the data (cloning OrdData is deep)
377            let weighted = ord_elements.iter().cloned().collect::<Vec<_>>();
378            // collect on weak clone (OrdData is removed first) 
379            let elements = weighted.iter().map(|OrdData((x,w))| (x.clone(),*w)).collect();
380            let ord_elements = weighted.into_iter().collect();
381            Self { elements, ord_elements, }
382        }
383    }
384
385    impl<X> OrdMap<X> where X: Clone + Eq + Ord + Hash, {
386        /// Constructor
387        pub fn new() -> Self {
388            Self { elements: HashMap::new(), ord_elements: BTreeSet::new(), }
389        }
390        /// Constructor with given capacity
391        pub fn with_capacity(capacity: usize) -> Self {
392            Self { elements: HashMap::with_capacity(capacity), ord_elements: BTreeSet::new(), }
393        }
394        /// Take first element (i.e. with the smallest weight)
395        pub fn pop_first(&mut self) -> Option<(X,f64slx)> {
396            match self.ord_elements.pop_first() {
397                Some(OrdData((xx,_))) => match self.elements.remove_entry(&xx) {
398                    None => panic!("unexpected missing entry"), 
399                    some => some.map(|(x,w)| {
400                        drop(xx); (x,w)
401                    }),
402                }, None => None,
403            }
404        }
405        /// remove element and get its weight
406        pub fn remove(&mut self, x: &X) -> Option<(X,f64slx)> {
407            match self.elements.remove_entry(x) {
408                Some(xw) => {
409                    match self.ord_elements.take(unsafe{ std::mem::transmute(&xw)}) {
410                        None => panic!("unexpected missing element"), 
411                        Some(OrdData((xx,w))) => {
412                            drop(xw);
413                            Some((xx, w))
414                        },
415                    }
416                },
417                None => None,
418            }
419        }
420        /// push new weighted element
421        pub fn push(&mut self, x: X, w: f64slx) {
422            let (x,w) = match self.remove(&x) { Some((_,v)) => (x,v+w), None => (x,w), };
423            self.ord_elements.insert(OrdData((x.clone(),w))); // Nota: weak cloning here
424            self.elements.insert(x, w);
425        }
426        /// Collection length
427        pub fn len(&self) -> usize {
428            let len1 = self.elements.len();
429            let len2 = self.ord_elements.len();
430            if len1 != len2 { panic!("unexpected error: mismatching lens") }
431            len1
432        }
433    }
434
435    impl<X> Index<&X> for OrdMap<X> where X: Eq + Ord + Hash, {
436        type Output = f64slx;
437
438        fn index(&self, index: &X) -> &Self::Output {
439            match self.elements.get(index) { Some(x) => x, None => zero_f64slx(), }
440        }
441    }
442}
443
444/// Experimentation for internal test
445pub fn exp_hidden() {
446    let mut om = hidden::OrdMap::<&'static str>::new();
447    om.push("A",0.125.slx());
448    om.push("B",0.25.slx());
449    om.push("C",0.25.slx());
450    om.push("B",0.25.slx());
451    om.push("C",0.125.slx());
452    println!("om.len() -> {}", om.len());
453    println!("om.remove(&\"B\") -> {:?}", om.remove(&"B"));
454    println!("om.len() -> {}", om.len());
455    println!("om.pop_first() -> {:?}", om.pop_first());
456    println!("om.len() -> {}", om.len());
457    println!("om.remove(&\"B\") -> {:?}", om.remove(&"B"));
458    println!("om.len() -> {}", om.len());
459    println!("om.pop_first() -> {:?}", om.pop_first());
460    println!("om.len() -> {}", om.len());
461    println!("om.pop_first() -> {:?}", om.pop_first());
462    println!("om.len() -> {}", om.len());
463}