Skip to main content

burn_central_runtime/params/
args.rs

1use burn::prelude::Backend;
2use derive_more::{Deref, From};
3use json_patch::merge;
4use serde::{Deserialize, Serialize};
5
6use crate::{executor::ExecutionContext, params::RoutineParam};
7
8/// Trait for experiments arguments. It specify that the type must be serializable, deserializable
9/// and implement default. The reason it must implement default is that when you override a value
10/// it will only override the value you provide, the rest will be filled with the default value.
11pub trait ExperimentArgs: Serialize + for<'de> Deserialize<'de> + Default {}
12impl<T> ExperimentArgs for T where T: Serialize + for<'de> Deserialize<'de> + Default {}
13
14pub fn deserialize_and_merge_with_default<T: ExperimentArgs>(
15    args: &serde_json::Value,
16) -> Result<T, serde_json::Error> {
17    let mut merged = serde_json::to_value(T::default())?;
18
19    merge(&mut merged, args);
20
21    serde_json::from_value(merged)
22}
23
24/// Args are wrapper around the config you want to inject.
25///
26/// The type T must implement [ExperimentArgs] trait. This trait allow us to override the
27/// configuration from the CLI arguments you can specify while given us a fallback for arguments
28/// you don't provide.
29#[derive(From, Deref)]
30pub struct Args<T: ExperimentArgs>(pub T);
31
32impl<B: Backend, C: ExperimentArgs> RoutineParam<ExecutionContext<B>> for Args<C> {
33    type Item<'new> = Args<C>;
34
35    fn try_retrieve(ctx: &ExecutionContext<B>) -> anyhow::Result<Self::Item<'_>> {
36        let cfg = ctx.use_merged_args();
37        Ok(Args(cfg))
38    }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Nested section whose defaults differ from the zero values (`false`, `0`), so a merged
    /// default can be told apart from a freshly zeroed field.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Nested {
        x: bool,
        y: u64,
    }

    impl Default for Nested {
        fn default() -> Self {
            Nested { x: true, y: 10 }
        }
    }

    /// Top-level configuration covering a required scalar, an optional field, a nested
    /// object and an array.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct MyArgs {
        a: i32,
        b: Option<String>,
        nested: Nested,
        list: Vec<i32>,
    }

    impl Default for MyArgs {
        fn default() -> Self {
            MyArgs {
                a: 5,
                b: Some("hello".to_owned()),
                nested: Nested::default(),
                list: vec![1, 2, 3],
            }
        }
    }

    #[test]
    fn empty_override_returns_default() {
        let cfg: MyArgs = deserialize_and_merge_with_default(&json!({})).unwrap();
        assert_eq!(cfg, MyArgs::default());
    }

    #[test]
    fn override_top_level_field() {
        let cfg: MyArgs = deserialize_and_merge_with_default(&json!({ "a": 42 })).unwrap();
        let expected = MyArgs {
            a: 42,
            ..Default::default()
        };
        assert_eq!(cfg, expected);
    }

    #[test]
    fn deep_override_nested_field() {
        let cfg: MyArgs =
            deserialize_and_merge_with_default(&json!({ "nested": { "y": 99 } })).unwrap();
        let mut expected = MyArgs::default();
        expected.nested.y = 99;
        assert_eq!(cfg, expected);
    }

    /// Under merge-patch semantics a `null` member removes the key; a missing `Option`
    /// field then deserializes as `None` while every other field keeps its default.
    #[test]
    fn null_removes_optional_field() {
        let cfg: MyArgs = deserialize_and_merge_with_default(&json!({ "b": null })).unwrap();
        assert_eq!(cfg.b, None);
        let expected = MyArgs {
            b: None,
            ..Default::default()
        };
        assert_eq!(cfg, expected);
    }

    /// Removing a required field with `null` leaves it missing, which is a data error.
    #[test]
    fn null_removes_required_field_and_errors() {
        let err = deserialize_and_merge_with_default::<MyArgs>(&json!({ "a": null })).unwrap_err();
        assert!(err.is_data());
        assert!(err.to_string().contains("missing field"));
    }

    /// Arrays are replaced wholesale, never merged element-wise.
    #[test]
    fn override_list_replaces_array() {
        let cfg: MyArgs =
            deserialize_and_merge_with_default(&json!({ "list": [9, 8, 7] })).unwrap();
        assert_eq!(cfg.list, vec![9, 8, 7]);
    }

    #[test]
    fn type_mismatch_in_nested_errors_data() {
        let err = deserialize_and_merge_with_default::<MyArgs>(
            &json!({ "nested": { "x": "not_a_bool" } }),
        )
        .unwrap_err();
        assert!(err.is_data());
        assert!(err.to_string().contains("invalid type"));
    }

    /// Merging itself never fails: an override with the wrong shape (an array where a
    /// scalar is expected) only surfaces as a deserialization error.
    #[test]
    fn array_in_place_of_scalar_errors_data() {
        let err =
            deserialize_and_merge_with_default::<MyArgs>(&json!({ "nested": { "y": [1, 2, 3] } }))
                .unwrap_err();
        assert!(err.is_data());
        assert!(err.to_string().contains("invalid type"));
    }
}