Skip to main content

hyperparameter/
api.rs

1use std::collections::HashSet;
2use std::fmt::Debug;
3
4use crate::storage::{
5    frozen_global_storage, Entry, GetOrElse, MultipleVersion, Params, THREAD_STORAGE,
6};
7use crate::value::{Value, EMPTY};
8use crate::xxh::XXHashable;
9
/// ParamScope
///
/// `ParamScope` is a data structure that stores the current set of named parameters
/// and their values. `ParamScope` is used to manage the scope of named parameters,
/// allowing parameters to be defined and used within a specific scope,
/// and then restored to the previous scope when the scope is exited.
///
/// The parameter scope can be used to implement a variety of features, such
/// as named parameters, default parameter values, and parameter inheritance.
#[derive(Debug, Clone)]
pub enum ParamScope {
    /// No parameters are held by this value. `ParamScope::enter` leaves the
    /// scope in this state after copying its parameters into the thread-local
    /// storage.
    Nothing,
    /// The current scope contains a set of named parameters stored in `Params`.
    Just(Params),
}
26
27impl Default for ParamScope {
28    fn default() -> Self {
29        ParamScope::Just(Params::new())
30    }
31}
32
33impl<T: Into<String> + Clone> From<&Vec<T>> for ParamScope {
34    fn from(value: &Vec<T>) -> Self {
35        let mut ps = ParamScope::default();
36        value.iter().for_each(|x| ps.add(x.clone()));
37        ps
38    }
39}
40
41impl ParamScope {
42    /// Get a parameter with a given hash key.
43    pub fn get_with_hash(&self, key: u64) -> Value {
44        if let ParamScope::Just(changes) = self {
45            if let Some(e) = changes.get(&key) {
46                match e.value() {
47                    Value::Empty => {}
48                    v => return v.clone(),
49                }
50            }
51        }
52        THREAD_STORAGE.with(|ts| {
53            let ts = ts.borrow();
54            ts.get_entry(key).map(|e| e.clone_value()).unwrap_or(EMPTY)
55        })
56    }
57
58    /// Get a parameter with a given key.
59    pub fn get<K>(&self, key: K) -> Value
60    where
61        K: Into<String> + Clone + XXHashable,
62    {
63        let hkey = key.xxh();
64        self.get_with_hash(hkey)
65    }
66
67    pub fn add<T: Into<String>>(&mut self, expr: T) {
68        let expr: String = expr.into();
69        if let Some((k, v)) = expr.split_once('=') {
70            self.put(k.to_string(), v.to_string())
71        }
72    }
73
74    /// Get a list of all parameter keys.
75    pub fn keys(&self) -> Vec<String> {
76        let mut retval: HashSet<String> = THREAD_STORAGE.with(|ts| {
77            let ts = ts.borrow();
78            ts.keys().iter().cloned().collect()
79        });
80        if let ParamScope::Just(changes) = self {
81            retval.extend(changes.values().map(|e| e.key.clone()));
82        }
83        retval.iter().cloned().collect()
84    }
85
86    /// Enter a new parameter scope.
87    pub fn enter(&mut self) {
88        THREAD_STORAGE.with(|ts| {
89            let mut ts = ts.borrow_mut();
90            ts.enter();
91            if let ParamScope::Just(changes) = self {
92                for v in changes.values() {
93                    ts.put(v.key.clone(), v.value().clone());
94                }
95            }
96        });
97        *self = ParamScope::Nothing;
98    }
99
100    /// Exit the current parameter scope.
101    pub fn exit(&mut self) {
102        THREAD_STORAGE.with(|ts| {
103            let tree = ts.borrow_mut().exit();
104            *self = ParamScope::Just(tree);
105        })
106    }
107}
108
/// Parameter scope operations.
///
/// `K` is the key type and `V` the value type. This file implements the trait
/// for `ParamScope` with pre-hashed `u64` keys and with string-like keys.
pub trait ParamScopeOps<K, V> {
    /// Get the parameter stored under `key`, or `default` if it doesn't exist.
    fn get_or_else(&self, key: K, default: V) -> V;
    /// Put a parameter: store `val` under `key`.
    fn put(&mut self, key: K, val: V);
}
114
115impl<V> ParamScopeOps<u64, V> for ParamScope
116where
117    V: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value>,
118{
119    fn get_or_else(&self, key: u64, default: V) -> V {
120        if let ParamScope::Just(changes) = self {
121            if let Some(val) = changes.get(&key) {
122                let r = val.value().clone().try_into();
123                if let Ok(v) = r {
124                    return v;
125                }
126            }
127        }
128        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(key, default))
129    }
130
131    /// Put a parameter.
132    fn put(&mut self, key: u64, val: V) {
133        if let ParamScope::Just(changes) = self {
134            if let std::collections::btree_map::Entry::Vacant(e) = changes.entry(key) {
135                e.insert(Entry::new("", val));
136            } else {
137                changes.update(key, val);
138            }
139        }
140    }
141}
142
143impl<K, V> ParamScopeOps<K, V> for ParamScope
144where
145    K: Into<String> + Clone + XXHashable + Debug,
146    V: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value> + Clone,
147{
148    /// Get a parameter or the default value if it doesn't exist.
149    fn get_or_else(&self, key: K, default: V) -> V {
150        let hkey = key.xxh();
151        self.get_or_else(hkey, default)
152    }
153
154    /// Put a parameter.
155    fn put(&mut self, key: K, val: V) {
156        let hkey = key.xxh();
157        if let ParamScope::Just(changes) = self {
158            // if changes.contains_key(&hkey) {
159            //     changes.update(hkey, val);
160            // } else {
161            //     let key: String = key.into();
162            //     changes.insert(hkey, Entry::new(key, val));
163            // }
164            if let std::collections::btree_map::Entry::Vacant(e) = changes.entry(hkey) {
165                let key: String = key.into();
166                e.insert(Entry::new(key, val));
167            } else {
168                changes.update(hkey, val);
169            }
170        } else {
171            THREAD_STORAGE.with(|ts| ts.borrow_mut().put(key, val))
172        }
173    }
174}
175
/// Freeze the global parameter storage.
///
/// Thin wrapper that delegates to `frozen_global_storage` from `crate::storage`.
/// NOTE(review): the exact effect is defined in `crate::storage`, which is not
/// visible here — presumably it snapshots the current thread's parameters as
/// the global defaults; confirm against that module.
pub fn frozen() {
    frozen_global_storage();
}
179
/// Read a named parameter from the thread-local storage, falling back to a default.
///
/// Forms:
/// - `get_param!(name, default)`
/// - `get_param!(name, default, help)` — additionally registers `(key, help)`
///   in the `PARAMS` distributed slice.
///
/// The key is the stringified `$name` with every `;` removed, hashed at compile
/// time with xxh64 and seed 42. The expansion refers to `THREAD_STORAGE` (and
/// `PARAMS` in the three-argument form) by bare name, so the call site needs
/// them in scope — the tests in this file import `THREAD_STORAGE` and
/// `GetOrElse` for that reason.
#[macro_export]
macro_rules! get_param {
    ($name:expr, $default:expr) => {{
        // Key and hash are `const`, so they are computed at compile time.
        const CONST_KEY: &str = const_str::replace!(stringify!($name), ";", "");
        const CONST_HASH: u64 = xxhash_rust::const_xxh64::xxh64(CONST_KEY.as_bytes(), 42);
        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(CONST_HASH, $default))
    }};

    ($name:expr, $default:expr, $help: expr) => {{
        const CONST_KEY: &str = const_str::replace!(stringify!($name), ";", "");
        const CONST_HASH: u64 = xxhash_rust::const_xxh64::xxh64(CONST_KEY.as_bytes(), 42);
        // Inner block keeps the registration items from leaking into the
        // surrounding expression's scope.
        {
            // `$help` must be usable in a const context.
            const CONST_HELP: &str = $help;
            // Registers the (key, help text) pair in the `PARAMS` slice.
            #[::linkme::distributed_slice(PARAMS)]
            static help: (&str, &str) = (CONST_KEY, CONST_HELP);
        }
        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(CONST_HASH, $default))
    }};
}
201
/// Define or use `hyperparameters` in a code block.
///
/// Hyperparameters are named parameters whose values control the learning process of
/// an ML model or the behaviors of an underlying machine learning system.
///
/// Hyperparameter is designed to be as user-friendly as global variables while overcoming
/// two major drawbacks of global variables: non-thread safety and global scope.
///
/// # A quick example
/// ```
/// use hyperparameter::*;
///
/// with_params! {   // with_params begins a new parameter scope
///     set a.b = 1; // set the value of named parameter `a.b`
///     set a.b.c = 2.0; // `a.b.c` is another parameter.
///
///     assert_eq!(1, get_param!(a.b, 0));
///
///     with_params! {   // start a new parameter scope that inherits parameters from the previous scope
///         set a.b = 2; // override parameter `a.b`
///
///         let a_b = get_param!(a.b, 0); // read parameter `a.b`, return the default value (0) if not defined
///         assert_eq!(2, a_b);
///     }
/// }
/// ```
#[macro_export]
macro_rules! with_params {
    // Leading `set key = value;` with no scope yet: create a fresh
    // `ParamScope` named `ps`, record the value, and recurse with `params ps;`.
    (
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) =>{
        let mut ps = ParamScope::default();
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            ps.put(CONST_KEY, $val);
        }
        with_params!(params ps; $($body)*)
    };

    // `set key = value;` while a scope is already being built: record the
    // value in `$ps` and keep consuming the body.
    (
        params $ps:expr;
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) => {
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            $ps.put(CONST_KEY, $val);
        }
        with_params!(params $ps; $($body)*)
    };

    // Two scopes given explicitly: enter `$ps`, run the rest under `$nested`,
    // then exit `$ps` and yield the body's value.
    (
        params $ps:expr;
        params $nested:expr;

        $($body:tt)*
    ) => {
        $ps.enter();
        let ret = with_params!(params $nested; $($body)*);
        $ps.exit();
        ret
    };

    // Leading `get name = key or default;` with no scope: bind `name` via
    // `get_param!` and hand the remaining body to `with_params_readonly!`.
    (
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = get_param!($($key).+, $default);
        with_params_readonly!($($body)*)
    };

    // Same as the previous arm, but preceded by doc attributes whose text is
    // forwarded to `get_param!` as the help string.
    // NOTE(review): `$($doc)*` pastes all captured doc expressions into a
    // single macro argument; with more than one `#[doc = ...]` this likely
    // does not match `get_param!` — confirm.
    (
        $(#[doc = $doc:expr])*
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = get_param!($($key).+, $default, $($doc)*);
        with_params_readonly!($($body)*)
    };

    // `get` after one or more `set`s: enter the scope first so the read sees
    // the values just recorded, then exit once the body has run.
    (
        params $ps:expr;
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        $ps.enter();
        let ret = {
            let $name = get_param!($($key).+, $default);

            with_params_readonly!($($body)*)
        };
        $ps.exit();
        ret
    };

    // End of the `set`/`get` prefix with a scope: enter it, evaluate the body,
    // exit, and yield the body's value.
    (
        params $ps:expr;

        $($body:tt)*
    ) => {{
            $ps.enter();
            let ret = {$($body)*};
            $ps.exit();
            ret
    }};

    // Fallback: no `set`/`get` prefix and no scope — just evaluate the body.
    ($($body:tt)*) => {{
        let ret = {$($body)*};
        ret
    }};
}
319
/// Continuation of `with_params!` used after a leading `get` statement.
///
/// It keeps consuming `get name = key or default;` statements without creating
/// a scope; the first `set` statement switches back to `with_params!` with a
/// fresh `ParamScope`. Anything else is evaluated as the body.
#[macro_export]
macro_rules! with_params_readonly {
    // `get name = key or default;`: bind `name` via `get_param!` and continue.
    (
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = get_param!($($key).+, $default);
        with_params_readonly!($($body)*)
    };

    // `set key = value;`: create a fresh scope named `ps`, record the value,
    // and let `with_params!` process the remaining body.
    (
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) =>{
        let mut ps = ParamScope::default();
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            ps.put(CONST_KEY, $val);
        }
        with_params!(params ps; $($body)*)
    };

    // Fallback: evaluate the remaining body and yield its value.
    ($($body:tt)*) => {{
            let ret = {$($body)*};
            ret
    }};
}
349
// Unit tests for `ParamScope` put/get/enter/exit and for the `get_param!` /
// `with_params!` macros.
#[cfg(test)]
mod tests {
    // `THREAD_STORAGE` and `GetOrElse` are referenced by bare name inside the
    // `get_param!` expansion, so they must be in scope here.
    use crate::storage::{GetOrElse, THREAD_STORAGE};

    use super::{ParamScope, ParamScopeOps};

    /// A default scope can be constructed.
    #[test]
    fn test_param_scope_create() {
        let _ = ParamScope::default();
    }

    /// Values put into a scope are readable from the scope but not from the
    /// thread storage until the scope is entered.
    #[test]
    fn test_param_scope_put_get() {
        let mut ps = ParamScope::default();
        ps.put("1", 1);
        ps.put("2.0", 2.0);

        // check thread storage is not affected
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));
    }

    /// `enter` publishes the scope's values to the thread storage and `exit`
    /// withdraws them again, handing them back to the scope.
    #[test]
    fn test_param_scope_enter() {
        let mut ps = ParamScope::default();
        ps.put("1", 1);
        ps.put("2.0", 2.0);

        // check thread storage is not affected
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));

        ps.enter();

        // check thread storage is affected after enter
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(1, ts.get_or_else("1", 0));
            assert_eq!(2.0, ts.get_or_else("2.0", 0.0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));

        ps.exit();
        // check thread storage is not affected after exit
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));
    }

    /// `get_param!` can be invoked after a scope has been entered.
    // NOTE(review): the value read by `get_param!` is only printed, never
    // asserted, and the scope is entered without a matching `exit` — confirm
    // whether that is intentional.
    #[test]
    fn test_param_scope_get_param() {
        let mut ps = ParamScope::default();
        ps.put("a.b.c", 1);

        // check thread storage is not affected
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("a.b.c", 0));
        });

        // check changes in param_scope
        assert_eq!(1, ps.get_or_else("a.b.c", 0));

        ps.enter();

        let x = get_param!(a.b.c, 0);
        println!("x={}", x);
    }

    /// `set` statements are visible inside the block, can be overridden in a
    /// nested block, and are gone once the outer block ends.
    #[test]
    fn test_param_scope_with_param_set() {
        with_params! {
            set a.b.c=1;
            set a.b =2;

            assert_eq!(1, get_param!(a.b.c, 0));
            assert_eq!(2, get_param!(a.b, 0));

            with_params! {
                set a.b.c=2.0;

                assert_eq!(2.0, get_param!(a.b.c, 0.0));
                assert_eq!(2, get_param!(a.b, 0));
            };

            // the nested override no longer applies
            assert_eq!(1, get_param!(a.b.c, 0));
            assert_eq!(2, get_param!(a.b, 0));
        }

        // outside the block the defaults are returned
        assert_eq!(0, get_param!(a.b.c, 0));
        assert_eq!(0, get_param!(a.b, 0));
    }

    /// A nested `get` statement reads a value set by the enclosing block.
    #[test]
    fn test_param_scope_with_param_get() {
        with_params! {
            set a.b.c=1;

            with_params! {
                get a_b_c = a.b.c or 0;

                assert_eq!(1, a_b_c);
            };
        }
    }

    /// Several nested `get` statements read values set by the enclosing block.
    #[test]
    fn test_param_scope_with_param_set_get() {
        with_params! {
            set a.b.c = 1;
            set a.b = 2;

            with_params! {
                get a_b_c = a.b.c or 0;
                get a_b = a.b or 0;

                assert_eq!(1, a_b_c);
                assert_eq!(2, a_b);
            };
        }
    }

    /// A `get` for a parameter that was never set yields its default.
    #[test]
    fn test_param_scope_with_param_readonly() {
        with_params! {
            get a_b_c = a.b.c or 1;

            assert_eq!(1, a_b_c);
        }
    }

    /// A `get` that follows a `set` in the same block sees the value just set.
    #[test]
    fn test_param_scope_with_param_mixed_get_set() {
        with_params! {
            get _a_b_c = a.b.c or 1;
            set a.b.c = 3;
            get a_b_c = a.b.c or 2;

            assert_eq!(3, a_b_c);
        }
    }
}
513
514// FILEPATH: /home/reiase/workspace/hyperparameter/core/src/api.rs
515// BEGIN: test_code
516
// Unit tests for the inherent `ParamScope` API: construction, `add`, the
// getters, `keys`, and the `enter`/`exit` state transitions.
#[cfg(test)]
mod test_param_scope {
    use super::*;
    use std::convert::TryInto;

    /// `ParamScope::default()` must produce the `Just` variant.
    #[test]
    fn test_param_scope_default() {
        let ps = ParamScope::default();
        assert!(
            matches!(ps, ParamScope::Just(_)),
            "Default ParamScope should be ParamScope::Just"
        );
    }

    /// Converting a list of `key=value` strings stores each pair in the scope.
    #[test]
    fn test_param_scope_from_vec() {
        let vec = vec!["param1=value1", "param2=value2"];
        let ps: ParamScope = (&vec).into();
        match ps {
            ParamScope::Just(params) => {
                assert_eq!(
                    params
                        .get(&"param1".xxh())
                        .expect("param1 should exist")
                        .value(),
                    &Value::from("value1")
                );
                assert_eq!(
                    params
                        .get(&"param2".xxh())
                        .expect("param2 should exist")
                        .value(),
                    &Value::from("value2")
                );
            }
            ParamScope::Nothing => panic!("ParamScope should be ParamScope::Just"),
        }
    }

    /// A value added by expression is retrievable through its hash.
    #[test]
    fn test_param_scope_get_with_hash() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let value = ps.get_with_hash("param".xxh());
        assert_eq!(value, Value::from("value"));
    }

    /// A value added by expression is retrievable by name and converts to `String`.
    #[test]
    fn test_param_scope_get() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let value: String = ps
            .get("param")
            .try_into()
            .expect("Failed to convert param to String");
        assert_eq!(value, "value");
    }

    /// `add` stores the right-hand side of the expression under the left-hand side.
    #[test]
    fn test_param_scope_add() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        match ps {
            ParamScope::Just(params) => {
                assert_eq!(
                    params
                        .get(&"param".xxh())
                        .expect("param should exist")
                        .value(),
                    &Value::from("value")
                );
            }
            ParamScope::Nothing => panic!("ParamScope should be ParamScope::Just"),
        }
    }

    /// `keys` reports the name of a parameter held by the scope.
    // This relies on the thread storage contributing no keys of its own.
    #[test]
    fn test_param_scope_keys() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let keys = ps.keys();
        assert_eq!(keys, vec!["param"]);
    }

    /// `enter` leaves the scope as `Nothing`; `exit` restores it to `Just`.
    #[test]
    fn test_param_scope_enter_exit() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        ps.enter();
        assert!(
            matches!(ps, ParamScope::Nothing),
            "ParamScope should be ParamScope::Nothing after enter"
        );
        ps.exit();
        assert!(
            matches!(ps, ParamScope::Just(_)),
            "ParamScope should be ParamScope::Just after exit"
        );
    }
}
620
621// END: test_code