burn_central_runtime/params/
args.rs1use burn::prelude::Backend;
2use derive_more::{Deref, From};
3use json_patch::merge;
4use serde::{Deserialize, Serialize};
5
6use crate::{executor::ExecutionContext, params::RoutineParam};
7
8pub trait ExperimentArgs: Serialize + for<'de> Deserialize<'de> + Default {}
12impl<T> ExperimentArgs for T where T: Serialize + for<'de> Deserialize<'de> + Default {}
13
14pub fn deserialize_and_merge_with_default<T: ExperimentArgs>(
15 args: &serde_json::Value,
16) -> Result<T, serde_json::Error> {
17 let mut merged = serde_json::to_value(T::default())?;
18
19 merge(&mut merged, args);
20
21 serde_json::from_value(merged)
22}
23
24#[derive(From, Deref)]
30pub struct Args<T: ExperimentArgs>(pub T);
31
32impl<B: Backend, C: ExperimentArgs> RoutineParam<ExecutionContext<B>> for Args<C> {
33 type Item<'new> = Args<C>;
34
35 fn try_retrieve(ctx: &ExecutionContext<B>) -> anyhow::Result<Self::Item<'_>> {
36 let cfg = ctx.use_merged_args();
37 Ok(Args(cfg))
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Nested fixture whose defaults differ from the field types' own defaults.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Nested {
        x: bool,
        y: u64,
    }

    impl Default for Nested {
        fn default() -> Self {
            Self { x: true, y: 10 }
        }
    }

    /// Top-level fixture covering a scalar, an optional, a nested struct and a list.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct MyArgs {
        a: i32,
        b: Option<String>,
        nested: Nested,
        list: Vec<i32>,
    }

    impl Default for MyArgs {
        fn default() -> Self {
            Self {
                a: 5,
                b: Some(String::from("hello")),
                nested: Nested::default(),
                list: vec![1, 2, 3],
            }
        }
    }

    /// Runs the function under test against the `MyArgs` fixture.
    fn merge_into_my_args(overrides: serde_json::Value) -> Result<MyArgs, serde_json::Error> {
        deserialize_and_merge_with_default(&overrides)
    }

    #[test]
    fn empty_override_returns_default() {
        let parsed = merge_into_my_args(json!({})).unwrap();
        assert_eq!(parsed, MyArgs::default());
    }

    #[test]
    fn override_top_level_field() {
        let parsed = merge_into_my_args(json!({ "a": 42 })).unwrap();

        // Only `a` changes; every other field keeps its default.
        let mut expected = MyArgs::default();
        expected.a = 42;
        assert_eq!(parsed, expected);
    }

    #[test]
    fn deep_override_nested_field() {
        let parsed = merge_into_my_args(json!({ "nested": { "y": 99 } })).unwrap();

        // The sibling field `nested.x` must survive the partial override.
        let expected = MyArgs {
            nested: Nested {
                y: 99,
                ..Nested::default()
            },
            ..MyArgs::default()
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn null_becomes_json_null_for_optional() {
        // A null override on an `Option` field yields `None`.
        let parsed = merge_into_my_args(json!({ "b": null })).unwrap();
        assert!(parsed.b.is_none());
    }

    #[test]
    fn null_becomes_json_null_for_required() {
        // A null override on a non-optional field is reported as a data error.
        let failure = merge_into_my_args(json!({ "a": null })).unwrap_err();
        assert!(failure.is_data());
    }

    #[test]
    fn override_list_replaces_array() {
        // Arrays are replaced wholesale rather than combined element by element.
        let parsed = merge_into_my_args(json!({ "list": [9, 8, 7] })).unwrap();
        assert_eq!(parsed.list, [9, 8, 7]);
    }

    #[test]
    fn type_mismatch_in_nested_errors_data() {
        let failure = merge_into_my_args(json!({ "nested": { "x": "not_a_bool" } })).unwrap_err();
        assert!(failure.is_data());
    }

    #[test]
    fn patch_application_error_propagates() {
        // An array where `nested.y` expects an integer is reported as a data error.
        let failure = merge_into_my_args(json!({ "nested": { "y": [1, 2, 3] } })).unwrap_err();
        assert!(failure.is_data());
    }
}