1use std::collections::HashSet;
2use std::fmt::Debug;
3
4use const_str;
5use xxhash_rust;
6
7use crate::storage::{
8 frozen_global_storage, Entry, GetOrElse, MultipleVersion, Params, THREAD_STORAGE,
9};
10use crate::value::{Value, EMPTY};
11use crate::xxh::XXHashable;
12
/// A set of parameter overrides that can be applied to the current thread's storage.
///
/// A freshly created scope is `Just(params)`: its parameters are only staged locally and
/// are not visible through `THREAD_STORAGE`. [`ParamScope::enter`] copies them into the
/// thread storage and turns the scope into `Nothing`; [`ParamScope::exit`] turns it back
/// into `Just` with the parameters returned by the storage.
#[derive(Debug, Clone)]
pub enum ParamScope {
    /// The scope has been entered; its parameters now live in the thread storage.
    Nothing,
    /// Locally staged parameters, keyed by the xxhash of the parameter name.
    Just(Params),
}
29
30impl Default for ParamScope {
31 fn default() -> Self {
32 ParamScope::Just(Params::new())
33 }
34}
35
36impl<T: Into<String> + Clone> From<&Vec<T>> for ParamScope {
37 fn from(value: &Vec<T>) -> Self {
38 let mut ps = ParamScope::default();
39 value.iter().for_each(|x| ps.add(x.clone()));
40 ps
41 }
42}
43
44impl ParamScope {
45 pub fn get_with_hash(&self, key: u64) -> Value {
47 if let ParamScope::Just(changes) = self {
48 if let Some(e) = changes.get(&key) {
49 match e.value() {
50 Value::Empty => {}
51 v => return v.clone(),
52 }
53 }
54 }
55 THREAD_STORAGE.with(|ts| {
56 let ts = ts.borrow();
57 ts.get_entry(key).map(|e| e.clone_value()).unwrap_or(EMPTY)
58 })
59 }
60
61 pub fn get<K>(&self, key: K) -> Value
63 where
64 K: Into<String> + Clone + XXHashable,
65 {
66 let hkey = key.xxh();
67 self.get_with_hash(hkey)
68 }
69
70 pub fn add<T: Into<String>>(&mut self, expr: T) {
71 let expr: String = expr.into();
72 if let Some((k, v)) = expr.split_once('=') {
73 self.put(k.to_string(), v.to_string())
74 }
75 }
76
77 pub fn keys(&self) -> Vec<String> {
79 let mut retval: HashSet<String> = THREAD_STORAGE.with(|ts| {
80 let ts = ts.borrow();
81 ts.keys().iter().cloned().collect()
82 });
83 if let ParamScope::Just(changes) = self {
84 retval.extend(changes.values().map(|e| e.key.clone()));
85 }
86 retval.iter().cloned().collect()
87 }
88
89 pub fn enter(&mut self) {
91 THREAD_STORAGE.with(|ts| {
92 let mut ts = ts.borrow_mut();
93 ts.enter();
94 if let ParamScope::Just(changes) = self {
95 for v in changes.values() {
96 ts.put(v.key.clone(), v.value().clone());
97 }
98 }
99 });
100 *self = ParamScope::Nothing;
101 }
102
103 pub fn exit(&mut self) {
105 THREAD_STORAGE.with(|ts| {
106 let tree = ts.borrow_mut().exit();
107 *self = ParamScope::Just(tree);
108 })
109 }
110}
111
/// Typed get/put access to parameters.
///
/// In this module it is implemented for `ParamScope` both with name keys
/// (`K: Into<String>`) and with precomputed `u64` hash keys.
pub trait ParamScopeOps<K, V> {
    /// Looks up `key` and returns it as `V`, using `default` as the fallback.
    fn get_or_else(&self, key: K, default: V) -> V;
    /// Stores `val` under `key`.
    fn put(&mut self, key: K, val: V);
}
117
118impl<V> ParamScopeOps<u64, V> for ParamScope
119where
120 V: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value>,
121{
122 fn get_or_else(&self, key: u64, default: V) -> V {
123 if let ParamScope::Just(changes) = self {
124 if let Some(val) = changes.get(&key) {
125 let r = val.value().clone().try_into();
126 if r.is_ok() {
127 return r.ok().unwrap();
128 }
129 }
130 }
131 THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(key, default))
132 }
133
134 fn put(&mut self, key: u64, val: V) {
136 println!(
137 "hyperparameter warning: put parameter with hashed key {}",
138 key
139 );
140 if let ParamScope::Just(changes) = self {
141 if let std::collections::btree_map::Entry::Vacant(e) = changes.entry(key) {
142 e.insert(Entry::new("", val));
143 } else {
144 changes.update(key, val);
145 }
146 }
147 }
148}
149
150impl<K, V> ParamScopeOps<K, V> for ParamScope
151where
152 K: Into<String> + Clone + XXHashable + Debug,
153 V: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value> + Clone,
154{
155 fn get_or_else(&self, key: K, default: V) -> V {
157 let hkey = key.xxh();
158 self.get_or_else(hkey, default)
159 }
160
161 fn put(&mut self, key: K, val: V) {
163 let hkey = key.xxh();
164 if let ParamScope::Just(changes) = self {
165 if let std::collections::btree_map::Entry::Vacant(e) = changes.entry(hkey) {
172 let key: String = key.into();
173 e.insert(Entry::new(key, val));
174 } else {
175 changes.update(hkey, val);
176 }
177 } else {
178 THREAD_STORAGE.with(|ts| ts.borrow_mut().put(key, val))
179 }
180 }
181}
182
183pub fn frozen() {
184 frozen_global_storage();
185}
186
/// Reads a parameter from the current thread's storage, falling back to `$default`.
///
/// The key is the path written at the call site: `get_param!(a.b.c, 0)` reads
/// `"a.b.c"`, with any `;` stripped. The key string and its xxh64 hash (seed 42) are
/// computed at compile time.
///
/// The three-argument form `get_param!(a.b.c, 0, "help")` additionally registers
/// `(key, help)` in the `PARAMS` `linkme` distributed slice.
///
/// The expansion refers to `THREAD_STORAGE`, the trait providing `get_or_else` on it,
/// `PARAMS`, `const_str` and `xxhash_rust` unqualified, so they must be reachable at the
/// call site.
#[macro_export]
macro_rules! get_param {
    ($name:expr, $default:expr) => {{
        const CONST_KEY: &str = const_str::replace!(stringify!($name), ";", "");
        const CONST_HASH: u64 = xxhash_rust::const_xxh64::xxh64(CONST_KEY.as_bytes(), 42);
        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(CONST_HASH, $default))
    }};

    ($name:expr, $default:expr, $help: expr) => {{
        const CONST_KEY: &str = const_str::replace!(stringify!($name), ";", "");
        const CONST_HASH: u64 = xxhash_rust::const_xxh64::xxh64(CONST_KEY.as_bytes(), 42);
        {
            const CONST_HELP: &str = $help;
            // Upper-case name: statics expanded at the call site must satisfy
            // `non_upper_case_globals`.
            #[::linkme::distributed_slice(PARAMS)]
            static HELP: (&str, &str) = (CONST_KEY, CONST_HELP);
        }
        THREAD_STORAGE.with(|ts| ts.borrow_mut().get_or_else(CONST_HASH, $default))
    }};
}
208
/// Runs a body with parameters applied, and/or reads parameters into local bindings.
///
/// Leading statements (each terminated by `;`) are processed in order:
/// * `set a.b.c = expr;` — stages the parameter in a fresh `ParamScope`. That scope is
///   entered before the rest of the body runs and exited afterwards.
/// * `get name = a.b.c or default;` — binds `name` via `get_param!`. It may be preceded
///   by `///` lines, which are concatenated and registered as the parameter's help text.
/// * `params scope;` — enters an existing `ParamScope` around the rest of the body.
///
/// Expansions refer to `ParamScope`, `ParamScopeOps` and everything `get_param!` needs
/// (`THREAD_STORAGE`, ...) unqualified, so those must be in scope at the call site. The
/// helper macros are reached through `$crate::`, so they do not need importing.
#[macro_export]
macro_rules! with_params {
    (
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) =>{
        let mut ps = ParamScope::default();
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            ps.put(CONST_KEY, $val);
        }
        $crate::with_params!(params ps; $($body)*)
    };

    (
        params $ps:expr;
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) => {
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            $ps.put(CONST_KEY, $val);
        }
        $crate::with_params!(params $ps; $($body)*)
    };

    (
        params $ps:expr;
        params $nested:expr;

        $($body:tt)*
    ) => {
        $ps.enter();
        let ret = $crate::with_params!(params $nested; $($body)*);
        $ps.exit();
        ret
    };

    (
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = $crate::get_param!($($key).+, $default);
        $crate::with_params_readonly!($($body)*)
    };

    (
        $(#[doc = $doc:expr])*
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        // `concat!` turns one or more doc lines into a single string literal; passing
        // them through bare (`"a" "b"`) is not a valid expression and fails to compile.
        let $name = $crate::get_param!($($key).+, $default, concat!($($doc),*));
        $crate::with_params_readonly!($($body)*)
    };

    (
        params $ps:expr;
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        $ps.enter();
        let ret = {
            let $name = $crate::get_param!($($key).+, $default);

            $crate::with_params_readonly!($($body)*)
        };
        $ps.exit();
        ret
    };

    (
        params $ps:expr;

        $($body:tt)*
    ) => {{
        $ps.enter();
        let ret = {$($body)*};
        $ps.exit();
        ret
    }};

    ($($body:tt)*) => {{
        let ret = {$($body)*};
        ret
    }};
}
326
/// Continuation of `with_params!` once a `get` has been processed.
///
/// It handles further `get` statements, optionally preceded by `///` help lines. A
/// `set` statement switches back to `with_params!` with a fresh `ParamScope`. Anything
/// else is treated as the body and evaluated as a block.
///
/// Helper macros are reached through `$crate::`. `ParamScope` and `ParamScopeOps` must
/// be in scope at the call site.
#[macro_export]
macro_rules! with_params_readonly {
    (
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        let $name = $crate::get_param!($($key).+, $default);
        $crate::with_params_readonly!($($body)*)
    };

    (
        $(#[doc = $doc:expr])+
        get $name:ident = $($key:ident).+ or $default:expr;

        $($body:tt)*
    ) => {
        // Mirrors the documented `get` form of `with_params!`.
        let $name = $crate::get_param!($($key).+, $default, concat!($($doc),*));
        $crate::with_params_readonly!($($body)*)
    };

    (
        set $($key:ident).+ = $val:expr;

        $($body:tt)*
    ) =>{
        let mut ps = ParamScope::default();
        {
            const CONST_KEY: &str = const_str::replace!(stringify!($($key).+), ";", "");
            ps.put(CONST_KEY, $val);
        }
        $crate::with_params!(params ps; $($body)*)
    };

    ($($body:tt)*) => {{
        let ret = {$($body)*};
        ret
    }};
}
356
#[cfg(test)]
mod tests {
    use crate::get_param;
    use crate::storage::{GetOrElse, THREAD_STORAGE};
    use crate::with_params;

    use super::{ParamScope, ParamScopeOps};

    #[test]
    fn test_param_scope_create() {
        let _ = ParamScope::default();
    }

    #[test]
    fn test_param_scope_put_get() {
        let mut ps = ParamScope::default();
        ps.put("1", 1);
        ps.put("2.0", 2.0);

        // Parameters staged on an un-entered scope must not be visible in the thread storage.
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });

        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));
    }

    #[test]
    fn test_param_scope_enter() {
        let mut ps = ParamScope::default();
        ps.put("1", 1);
        ps.put("2.0", 2.0);

        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });

        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));

        ps.enter();

        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(1, ts.get_or_else("1", 0));
            assert_eq!(2.0, ts.get_or_else("2.0", 0.0));
        });

        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));

        ps.exit();
        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("1", 0));
            assert_eq!(0.0, ts.get_or_else("2.0", 0.0));
        });
        assert_eq!(1, ps.get_or_else("1", 0));
        assert_eq!(2.0, ps.get_or_else("2.0", 0.0));
    }

    #[test]
    fn test_param_scope_get_param() {
        let mut ps = ParamScope::default();
        ps.put("a.b.c", 1);

        THREAD_STORAGE.with(|ts| {
            let ts = ts.borrow();
            assert_eq!(0, ts.get_or_else("a.b.c", 0));
        });

        assert_eq!(1, ps.get_or_else("a.b.c", 0));

        ps.enter();
        // `get_param!` hashes the dotted path at compile time; it must hit the same key
        // that `put("a.b.c", ..)` staged.
        assert_eq!(1, get_param!(a.b.c, 0));
        ps.exit();

        // Leave the thread storage as we found it.
        assert_eq!(0, get_param!(a.b.c, 0));
    }

    #[test]
    fn test_param_scope_with_param_set() {
        with_params! {
            set a.b.c=1;
            set a.b =2;

            assert_eq!(1, get_param!(a.b.c, 0));
            assert_eq!(2, get_param!(a.b, 0));

            with_params! {
                set a.b.c=2.0;

                assert_eq!(2.0, get_param!(a.b.c, 0.0));
                assert_eq!(2, get_param!(a.b, 0));
            };

            assert_eq!(1, get_param!(a.b.c, 0));
            assert_eq!(2, get_param!(a.b, 0));
        }

        assert_eq!(0, get_param!(a.b.c, 0));
        assert_eq!(0, get_param!(a.b, 0));
    }

    #[test]
    fn test_param_scope_with_param_get() {
        with_params! {
            set a.b.c=1;

            with_params! {
                get a_b_c = a.b.c or 0;

                assert_eq!(1, a_b_c);
            };
        }
    }

    #[test]
    fn test_param_scope_with_param_set_get() {
        with_params! {
            set a.b.c = 1;
            set a.b = 2;

            with_params! {
                get a_b_c = a.b.c or 0;
                get a_b = a.b or 0;

                assert_eq!(1, a_b_c);
                assert_eq!(2, a_b);
            };
        }
    }

    #[test]
    fn test_param_scope_with_param_readonly() {
        with_params! {
            get a_b_c = a.b.c or 1;

            assert_eq!(1, a_b_c);
        }
    }

    #[test]
    fn test_param_scope_with_param_mixed_get_set() {
        with_params! {
            get a_b_c_default = a.b.c or 1;
            set a.b.c = 3;
            get a_b_c = a.b.c or 2;

            // The first `get` runs before `set`, so it sees the default.
            assert_eq!(1, a_b_c_default);
            assert_eq!(3, a_b_c);
        }
    }
}
522
#[cfg(test)]
mod test_param_scope {
    use super::*;
    use std::convert::TryInto;

    #[test]
    fn test_param_scope_default() {
        let ps = ParamScope::default();
        assert!(
            matches!(ps, ParamScope::Just(_)),
            "Default ParamScope should be ParamScope::Just"
        );
    }

    #[test]
    fn test_param_scope_from_vec() {
        let vec = vec!["param1=value1", "param2=value2"];
        let ps: ParamScope = (&vec).into();
        match ps {
            ParamScope::Just(params) => {
                assert_eq!(params.get(&"param1".xxh()).unwrap().value(), &Value::from("value1"));
                assert_eq!(params.get(&"param2".xxh()).unwrap().value(), &Value::from("value2"));
            }
            ParamScope::Nothing => panic!("ParamScope should be ParamScope::Just"),
        }
    }

    #[test]
    fn test_param_scope_get_with_hash() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let value = ps.get_with_hash("param".xxh());
        assert_eq!(value, Value::from("value"));
    }

    #[test]
    fn test_param_scope_get() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let value: String = ps.get("param").try_into().unwrap();
        assert_eq!(value, "value");
    }

    #[test]
    fn test_param_scope_add() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        match ps {
            ParamScope::Just(params) => {
                assert_eq!(params.get(&"param".xxh()).unwrap().value(), &Value::from("value"));
            }
            ParamScope::Nothing => panic!("ParamScope should be ParamScope::Just"),
        }
    }

    #[test]
    fn test_param_scope_keys() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        let keys = ps.keys();
        assert_eq!(keys, vec!["param"]);
    }

    #[test]
    fn test_param_scope_enter_exit() {
        let mut ps = ParamScope::default();
        ps.add("param=value");
        ps.enter();
        assert!(
            matches!(ps, ParamScope::Nothing),
            "ParamScope should be ParamScope::Nothing after enter"
        );
        ps.exit();
        assert!(
            matches!(ps, ParamScope::Just(_)),
            "ParamScope should be ParamScope::Just after exit"
        );
    }
}
605
606