hyperparameter/
api.rs

1use std::collections::HashSet;
2use std::fmt::Debug;
3
4use const_str;
5use xxhash_rust;
6
7use crate::storage::{
8    frozen_global_storage, Entry, GetOrElse, MultipleVersion, Params, THREAD_STORAGE,
9};
10use crate::value::{Value, EMPTY};
11use crate::xxh::XXHashable;
12
13/// ParameterScope
14///
15/// `ParameterScope` is a data structure that stores the current set of named parameters
16/// and their values. `ParameterScope` is used to manage the scope of named parameters,
17/// allowing parameters to be defined and used within a specific scope,
18/// and then restored to the previous scope when the scope is exited.
19///
20/// The parameter scope can be used to implement a variety of features, such
21/// as named parameters, default parameter values, and parameter inheritance.
22#[derive(Debug, Clone)]
23pub enum ParamScope {
24    /// No parameters are defined in the current scope.
25    Nothing,
26    /// The current scope contains a set of named parameters stored in `Params`.
27    Just(Params),
28}
29
30impl Default for ParamScope {
31    fn default() -> Self {
32        ParamScope::Just(Params::new())
33    }
34}
35
36impl<T: Into<String> + Clone> From<&Vec<T>> for ParamScope {
37    fn from(value: &Vec<T>) -> Self {
38        let mut ps = ParamScope::default();
39        value.iter().for_each(|x| ps.add(x.clone()));
40        ps
41    }
42}
43
44impl ParamScope {
45    /// Get a parameter with a given hash key.
46    pub fn get_with_hash(&self, key: u64) -> Value {
47        if let ParamScope::Just(changes) = self {
48            if let Some(e) = changes.get(&key) {
49                match e.value() {
50                    Value::Empty => {}
51                    v => return v.clone(),
52                }
53            }
54        }
55        THREAD_STORAGE.with(|ts| {
56            let ts = ts.borrow();
57            ts.get_entry(key).map(|e| e.clone_value()).unwrap_or(EMPTY)
58        })
59    }
60
61    /// Get a parameter with a given key.
62    pub fn get<K>(&self, key: K) -> Value
63    where
64        K: Into<String> + Clone + XXHashable,
65    {
66        let hkey = key.xxh();
67        self.get_with_hash(hkey)
68    }
69
70    pub fn add<T: Into<String>>(&mut self, expr: T) {
71        let expr: String = expr.into();
72        if let Some((k, v)) = expr.split_once('=') {
73            self.put(k.to_string(), v.to_string())
74        }
75    }
76
77    /// Get a list of all parameter keys.
78    pub fn keys(&self) -> Vec<String> {
79        let mut retval: HashSet<String> = THREAD_STORAGE.with(|ts| {
80            let ts = ts.borrow();
81            ts.keys().iter().cloned().collect()
82        });
83        if let ParamScope::Just(changes) = self {
84            retval.extend(changes.values().map(|e| e.key.clone()));
85        }
86        retval.iter().cloned().collect()
87    }
88
89    /// Enter a new parameter scope.
90    pub fn enter(&mut self) {
91        THREAD_STORAGE.with(|ts| {
92            let mut ts = ts.borrow_mut();
93            ts.enter();
94            if let ParamScope::Just(changes) = self {
95                for v in changes.values() {
96                    ts.put(v.key.clone(), v.value().clone());
97                }
98            }
99        });
100        *self = ParamScope::Nothing;
101    }
102
103    /// Exit the current parameter scope.
104    pub fn exit(&mut self) {
105        THREAD_STORAGE.with(|ts| {
106            let tree = ts.borrow_mut().exit();
107            *self = ParamScope::Just(tree);
108        })
109    }
110}
111
/// Read and write access to the parameters of a scope.
///
/// `K` is the type used to address a parameter and `V` the type of the value
/// that is read or written.
pub trait ParamScopeOps<K, V> {
    /// Get the parameter stored under `key`, or `default` if it doesn't exist.
    fn get_or_else(&self, key: K, default: V) -> V;

    /// Put the parameter `val` under `key`.
    fn put(&mut self, key: K, val: V);
}
117
118impl<V> ParamScopeOps<u64, V> for ParamScope
119where
120    V: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value>,
121{
122    fn get_or_else(&self, key: u64, default: V) -> V {
123        if let ParamScope::Just(changes) = self {
124            if let Some(val) = changes.get(&key) {
125                let r = val.value().clone().try_into();
126                if r.is_ok() {
127                    return r.ok().unwrap();
128                }
129            }
130        }
131        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(key, default))
132    }
133
134    /// Put a parameter.
135    fn put(&mut self, key: u64, val: V) {
136        println!(
137            "hyperparameter warning: put parameter with hashed key {}",
138            key
139        );
140        if let ParamScope::Just(changes) = self {
141            if let std::collections::btree_map::Entry::Vacant(e) = changes.entry(key) {
142                e.insert(Entry::new("", val));
143            } else {
144                changes.update(key, val);
145            }
146        }
147    }
148}
149
150impl<K, V> ParamScopeOps<K, V> for ParamScope
151where
152    K: Into<String> + Clone + XXHashable + Debug,
153    V: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value> + Clone,
154{
155    /// Get a parameter or the default value if it doesn't exist.
156    fn get_or_else(&self, key: K, default: V) -> V {
157        let hkey = key.xxh();
158        self.get_or_else(hkey, default)
159    }
160
161    /// Put a parameter.
162    fn put(&mut self, key: K, val: V) {
163        let hkey = key.xxh();
164        if let ParamScope::Just(changes) = self {
165            // if changes.contains_key(&hkey) {
166            //     changes.update(hkey, val);
167            // } else {
168            //     let key: String = key.into();
169            //     changes.insert(hkey, Entry::new(key, val));
170            // }
171            if let std::collections::btree_map::Entry::Vacant(e) = changes.entry(hkey) {
172                let key: String = key.into();
173                e.insert(Entry::new(key, val));
174            } else {
175                changes.update(hkey, val);
176            }
177        } else {
178            THREAD_STORAGE.with(|ts| ts.borrow_mut().put(key, val))
179        }
180    }
181}
182
183pub fn frozen() {
184    frozen_global_storage();
185}
186
/// Read a named parameter from the thread-local parameter storage.
///
/// * `get_param!(name, default)` returns the value of `name`, or `default`.
/// * `get_param!(name, default, help)` additionally registers the pair
///   `(name, help)` in the `PARAMS` distributed slice.
///
/// The key is the stringified `name` and is hashed at compile time.
/// `default` is evaluated exactly once, before the storage is borrowed, so it
/// may itself contain a `get_param!` invocation.
///
/// The expansion refers to `THREAD_STORAGE`, `const_str`, `xxhash_rust` and,
/// for the three-argument form, `PARAMS` and `linkme` without qualification,
/// so those names (and the trait that provides `get_or_else` on the storage)
/// must be resolvable at the call site.
#[macro_export]
macro_rules! get_param {
    ($name:expr, $default:expr) => {{
        const CONST_KEY: &str = const_str::replace!(stringify!($name), ";", "");
        const CONST_HASH: u64 = xxhash_rust::const_xxh64::xxh64(CONST_KEY.as_bytes(), 42);
        // Evaluate the default first: while the storage is mutably borrowed a
        // nested `get_param!` in the default would panic on a second borrow.
        let default_value = $default;
        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(CONST_HASH, default_value))
    }};

    ($name:expr, $default:expr, $help: expr) => {{
        const CONST_KEY: &str = const_str::replace!(stringify!($name), ";", "");
        const CONST_HASH: u64 = xxhash_rust::const_xxh64::xxh64(CONST_KEY.as_bytes(), 42);
        // Register the help text for this parameter; the inner block keeps
        // the generated items out of the surrounding scope.
        {
            const CONST_HELP: &str = $help;
            #[::linkme::distributed_slice(PARAMS)]
            static help: (&str, &str) = (CONST_KEY, CONST_HELP);
        }
        // Same ordering as above: default first, storage borrow second.
        let default_value = $default;
        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(CONST_HASH, default_value))
    }};
}
208
/// Define or use `hyperparameters` in a code block.
///
/// Hyperparameters are named parameters whose values control the learning process of
/// an ML model or the behaviors of an underlying machine learning system.
///
/// Hyperparameters are designed to be as convenient as global variables while avoiding
/// their two major drawbacks: non-thread safety and global scope.
///
/// A block may start with any number of `set key = value;` and
/// `get name = key or default;` statements, followed by ordinary code. The
/// parameters that were set are visible to the code that follows them and
/// are removed again when the block ends. The closing `exit` is an ordinary
/// statement after the body, so it does not run if the body panics or leaves
/// the enclosing function early.
///
/// # A quick example
/// ```
/// use hyperparameter::*;
///
/// with_params! {   // with_params begins a new parameter scope
///     set a.b = 1; // set the value of named parameter `a.b`
///     set a.b.c = 2.0; // `a.b.c` is another parameter.
///
///     assert_eq!(1, get_param!(a.b, 0));
///
///     with_params! {   // start a new parameter scope that inherits parameters from the previous scope
///         set a.b = 2; // override parameter `a.b`
///
///         let a_b = get_param!(a.b, 0); // read parameter `a.b`, return the default value (0) if not defined
///         assert_eq!(2, a_b);
///     }
/// }
/// ```
#[macro_export]
macro_rules! with_params {
    // Leading `set`: create a fresh scope object, record the value and
    // continue with the remaining statements.
    (
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) => {
        let mut ps = ParamScope::default();
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            ps.put(CONST_KEY, $val);
        }
        $crate::with_params!(params ps; $($body)*)
    };

    // Further `set` while a scope object is already being filled.
    (
        params $ps:expr;
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) => {
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            $ps.put(CONST_KEY, $val);
        }
        $crate::with_params!(params $ps; $($body)*)
    };

    // Two scope objects: enter the outer one, then process the nested one.
    (
        params $ps:expr;
        params $nested:expr;

        $($body:tt)*
    ) => {
        $ps.enter();
        let ret = $crate::with_params!(params $nested; $($body)*);
        $ps.exit();
        ret
    };

    // Leading `get` without any scope object: bind the value and continue
    // with the read-only variant of the macro.
    (
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = $crate::get_param!($($key).+, $default);
        $crate::with_params_readonly!($($body)*)
    };

    // Leading `get` with doc comments: the doc strings become the help text.
    // They are joined with `concat!` so that several doc lines still form a
    // single expression for `get_param!`.
    (
        $(#[doc = $doc:expr])*
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = $crate::get_param!($($key).+, $default, concat!($($doc),*));
        $crate::with_params_readonly!($($body)*)
    };

    // `get` after one or more `set`s: enter the scope first so that the
    // values that were just set are visible to the lookup.
    (
        params $ps:expr;
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        $ps.enter();
        let ret = {
            let $name = $crate::get_param!($($key).+, $default);

            $crate::with_params_readonly!($($body)*)
        };
        $ps.exit();
        ret
    };

    // Plain body with a scope object: run the body between enter and exit.
    (
        params $ps:expr;

        $($body:tt)*
    ) => {{
            $ps.enter();
            let ret = {$($body)*};
            $ps.exit();
            ret
    }};

    // Plain body without any parameters: evaluate it as a block.
    ($($body:tt)*) => {{
        let ret = {$($body)*};
        ret
    }};
}
326
/// Continuation of `with_params!` for blocks that have not set a parameter yet.
///
/// Handles the statements that follow a leading `get`: further `get`
/// statements are bound directly, the first `set` creates a scope object and
/// hands over to `with_params!`, and anything else is evaluated as a block.
#[macro_export]
macro_rules! with_params_readonly {
    // Another `get`: bind the value and keep going.
    (
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = $crate::get_param!($($key).+, $default);
        $crate::with_params_readonly!($($body)*)
    };

    // First `set`: from here on a scope object is required.
    (
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) => {
        let mut ps = ParamScope::default();
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            ps.put(CONST_KEY, $val);
        }
        $crate::with_params!(params ps; $($body)*)
    };

    // Plain body: evaluate it as a block.
    ($($body:tt)*) => {{
            let ret = {$($body)*};
            ret
    }};
}
356
/// Tests for `ParamScope` and the `get_param!` / `with_params!` macros.
#[cfg(test)]
mod tests {
    use crate::get_param;
    use crate::storage::{GetOrElse, THREAD_STORAGE};
    use crate::with_params;

    use super::{ParamScope, ParamScopeOps};

    #[test]
    fn test_param_scope_create() {
        let _ = ParamScope::default();
    }

    #[test]
    fn test_param_scope_put_get() {
        let mut ps = ParamScope::default();
        ps.put("1", 1);
        ps.put("2.0", 2.0);

        // check thread storage is not affected
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));
    }

    #[test]
    fn test_param_scope_enter() {
        let mut ps = ParamScope::default();
        ps.put("1", 1);
        ps.put("2.0", 2.0);

        // check thread storage is not affected
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));

        ps.enter();

        // check thread storage is affected after enter
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(1, ts.get_or_else("1", 0));
            assert_eq!(2.0, ts.get_or_else("2.0", 0.0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));

        ps.exit();
        // check thread storage is not affected after exit
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));
    }

    #[test]
    fn test_param_scope_get_param() {
        let mut ps = ParamScope::default();
        ps.put("a.b.c", 1);

        // check thread storage is not affected
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("a.b.c", 0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("a.b.c", 0));

        ps.enter();

        // the value must be visible through the macro once the scope is entered
        assert_eq!(1, get_param!(a.b.c, 0));

        // leave the scope again so that no state is left behind on this thread
        ps.exit();
        assert_eq!(0, get_param!(a.b.c, 0));
    }

    #[test]
    fn test_param_scope_with_param_set() {
        with_params! {
            set a.b.c=1;
            set a.b =2;

            assert_eq!(1, get_param!(a.b.c, 0));
            assert_eq!(2, get_param!(a.b, 0));

            with_params! {
                set a.b.c=2.0;

                assert_eq!(2.0, get_param!(a.b.c, 0.0));
                assert_eq!(2, get_param!(a.b, 0));
            };

            assert_eq!(1, get_param!(a.b.c, 0));
            assert_eq!(2, get_param!(a.b, 0));
        }

        assert_eq!(0, get_param!(a.b.c, 0));
        assert_eq!(0, get_param!(a.b, 0));
    }

    #[test]
    fn test_param_scope_with_param_get() {
        with_params! {
            set a.b.c=1;

            with_params! {
                get a_b_c = a.b.c or 0;

                assert_eq!(1, a_b_c);
            };
        }
    }

    #[test]
    fn test_param_scope_with_param_set_get() {
        with_params! {
            set a.b.c = 1;
            set a.b = 2;

            with_params! {
                get a_b_c = a.b.c or 0;
                get a_b = a.b or 0;

                assert_eq!(1, a_b_c);
                assert_eq!(2, a_b);
            };
        }
    }

    #[test]
    fn test_param_scope_with_param_readonly() {
        with_params! {
            get a_b_c = a.b.c or 1;

            assert_eq!(1, a_b_c);
        }
    }

    #[test]
    fn test_param_scope_with_param_mixed_get_set() {
        with_params! {
            get _a_b_c = a.b.c or 1;
            set a.b.c = 3;
            get a_b_c = a.b.c or 2;

            assert_eq!(3, a_b_c);
        }
    }
}
522
523// FILEPATH: /home/reiase/workspace/hyperparameter/core/src/api.rs
524// BEGIN: test_code
525
/// Tests for the inherent methods and conversions of `ParamScope`.
#[cfg(test)]
mod test_param_scope {
    use super::*;
    use std::convert::TryInto;

    #[test]
    fn test_param_scope_default() {
        let ps = ParamScope::default();
        assert!(
            matches!(ps, ParamScope::Just(_)),
            "Default ParamScope should be ParamScope::Just"
        );
    }

    #[test]
    fn test_param_scope_from_vec() {
        let vec = vec!["param1=value1", "param2=value2"];
        let ps: ParamScope = (&vec).into();
        let params = match ps {
            ParamScope::Just(params) => params,
            ParamScope::Nothing => panic!("ParamScope should be ParamScope::Just"),
        };
        assert_eq!(params.get(&"param1".xxh()).unwrap().value(), &Value::from("value1"));
        assert_eq!(params.get(&"param2".xxh()).unwrap().value(), &Value::from("value2"));
    }

    #[test]
    fn test_param_scope_get_with_hash() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let value = ps.get_with_hash("param".xxh());
        assert_eq!(value, Value::from("value"));
    }

    #[test]
    fn test_param_scope_get() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let value: String = ps.get("param").try_into().unwrap();
        assert_eq!(value, "value");
    }

    #[test]
    fn test_param_scope_add() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let params = match ps {
            ParamScope::Just(params) => params,
            ParamScope::Nothing => panic!("ParamScope should be ParamScope::Just"),
        };
        assert_eq!(params.get(&"param".xxh()).unwrap().value(), &Value::from("value"));
    }

    #[test]
    fn test_param_scope_keys() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let keys = ps.keys();
        assert_eq!(keys, vec!["param"]);
    }

    #[test]
    fn test_param_scope_enter_exit() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        ps.enter();
        assert!(
            matches!(ps, ParamScope::Nothing),
            "ParamScope should be ParamScope::Nothing after enter"
        );
        ps.exit();
        assert!(
            matches!(ps, ParamScope::Just(_)),
            "ParamScope should be ParamScope::Just after exit"
        );
    }
}
605
606// END: test_code