1use crate::{BeatInfo, NodeManifest, ParamType, ParamUpdate, ParamValue, WidgetConfig};
24
25#[derive(Debug)]
41pub struct ActionContext<'a> {
42 pub params: &'a [u8],
44
45 pub beat_info: Option<&'a BeatInfo>,
47
48 pub is_high_frequency: bool,
53}
54
55pub fn calculate_reset_values(
75 manifest: &NodeManifest,
76 _ctx: &ActionContext,
77) -> Result<Vec<ParamUpdate>, String> {
78 let mut updates = Vec::new();
79
80 for (index, param) in manifest.parameters.iter().enumerate() {
81 let default_value = get_default_value(¶m.data_type);
82 updates.push(ParamUpdate {
83 param_index: index as u32,
84 value: default_value,
85 });
86 }
87
88 Ok(updates)
89}
90
91pub fn calculate_random_values(
111 manifest: &NodeManifest,
112 _ctx: &ActionContext,
113 seed: u64,
114) -> Result<Vec<ParamUpdate>, String> {
115 use rand::{Rng, SeedableRng};
116 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
117 let mut updates = Vec::new();
118
119 for (index, param) in manifest.parameters.iter().enumerate() {
120 let random_value = match (¶m.widget, ¶m.data_type) {
122 (WidgetConfig::Slider(config), ParamType::ScalarF32) => {
124 ParamValue::ScalarF32(rng.gen_range(config.min..=config.max))
125 }
126
127 (WidgetConfig::Checkbox, ParamType::ScalarBool) => {
129 ParamValue::ScalarBool(rng.gen_bool(0.5))
130 }
131
132 (WidgetConfig::ColorPicker, ParamType::Vec3F32) => {
134 ParamValue::Vec3F32((rng.gen(), rng.gen(), rng.gen()))
135 }
136 (WidgetConfig::ColorPicker, ParamType::Vec4F32) => {
137 ParamValue::Vec4F32((rng.gen(), rng.gen(), rng.gen(), 1.0))
138 }
139
140 _ => get_default_value(¶m.data_type),
142 };
143
144 updates.push(ParamUpdate {
145 param_index: index as u32,
146 value: random_value,
147 });
148 }
149
150 Ok(updates)
151}
152
153fn get_default_value(param_type: &ParamType) -> ParamValue {
157 match param_type {
158 ParamType::ScalarF32 => ParamValue::ScalarF32(0.0),
159 ParamType::ScalarI32 => ParamValue::ScalarI32(0),
160 ParamType::ScalarU32 => ParamValue::ScalarU32(0),
161 ParamType::ScalarBool => ParamValue::ScalarBool(false),
162 ParamType::Vec2F32 => ParamValue::Vec2F32((0.0, 0.0)),
163 ParamType::Vec3F32 => ParamValue::Vec3F32((0.0, 0.0, 0.0)),
164 ParamType::Vec4F32 => ParamValue::Vec4F32((0.0, 0.0, 0.0, 0.0)),
165 ParamType::Mat4F32 => {
167 ParamValue::Mat4F32(vec![
169 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
170 ])
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExecutionModel, NodeCategory, ShaderParam, SliderConfig};

    /// Manifest with one slider (0..=1), one checkbox and one RGB color picker.
    fn create_test_manifest() -> NodeManifest {
        NodeManifest {
            api_version: 1,
            display_name: "Test Node".to_string(),
            version: "1.0.0".to_string(),
            author: "Test".to_string(),
            description: "Test node".to_string(),
            category: NodeCategory::Effector,
            tags: vec![],
            model: ExecutionModel::FragmentShader,
            parameters: vec![
                ShaderParam {
                    name: "strength".to_string(),
                    data_type: ParamType::ScalarF32,
                    widget: WidgetConfig::Slider(SliderConfig {
                        min: 0.0,
                        max: 1.0,
                        step: 0.01,
                    }),
                },
                ShaderParam {
                    name: "enabled".to_string(),
                    data_type: ParamType::ScalarBool,
                    widget: WidgetConfig::Checkbox,
                },
                ShaderParam {
                    name: "color".to_string(),
                    data_type: ParamType::Vec3F32,
                    widget: WidgetConfig::ColorPicker,
                },
            ],
            ports: vec![],
            output_resolution_scale: 1.0,
            output_hint: None,
            actions: vec![],
            embedded_textures: vec![],
        }
    }

    #[test]
    fn test_calculate_reset_values() {
        let manifest = create_test_manifest();
        let ctx = ActionContext {
            params: &[],
            beat_info: None,
            is_high_frequency: false,
        };

        let updates = calculate_reset_values(&manifest, &ctx).unwrap();

        assert_eq!(updates.len(), 3);
        assert_eq!(updates[0].param_index, 0);
        assert!(matches!(updates[0].value, ParamValue::ScalarF32(v) if v == 0.0));
        assert_eq!(updates[1].param_index, 1);
        assert!(matches!(updates[1].value, ParamValue::ScalarBool(false)));
        assert_eq!(updates[2].param_index, 2);
        assert!(matches!(
            updates[2].value,
            ParamValue::Vec3F32((0.0, 0.0, 0.0))
        ));
    }

    #[test]
    fn test_calculate_random_values() {
        let manifest = create_test_manifest();
        let ctx = ActionContext {
            params: &[],
            beat_info: None,
            is_high_frequency: false,
        };

        let seed = 42;
        let updates = calculate_random_values(&manifest, &ctx, seed).unwrap();

        assert_eq!(updates.len(), 3);

        assert_eq!(updates[0].param_index, 0);
        // Slider values must stay inside the configured [min, max] range.
        if let ParamValue::ScalarF32(v) = updates[0].value {
            assert!(
                (0.0..=1.0).contains(&v),
                "Random slider value {} out of range",
                v
            );
        } else {
            panic!("Expected ScalarF32 for slider parameter");
        }

        assert_eq!(updates[1].param_index, 1);
        // Either boolean is acceptable; only the variant is checked.
        assert!(matches!(updates[1].value, ParamValue::ScalarBool(_)));

        assert_eq!(updates[2].param_index, 2);
        // Color channels are sampled in the unit interval.
        if let ParamValue::Vec3F32((r, g, b)) = updates[2].value {
            assert!((0.0..=1.0).contains(&r), "Red channel out of range");
            assert!((0.0..=1.0).contains(&g), "Green channel out of range");
            assert!((0.0..=1.0).contains(&b), "Blue channel out of range");
        } else {
            panic!("Expected Vec3F32 for color picker parameter");
        }
    }

    #[test]
    fn test_random_values_deterministic() {
        let manifest = create_test_manifest();
        let ctx = ActionContext {
            params: &[],
            beat_info: None,
            is_high_frequency: false,
        };

        let seed = 123;
        let updates1 = calculate_random_values(&manifest, &ctx, seed).unwrap();
        let updates2 = calculate_random_values(&manifest, &ctx, seed).unwrap();

        assert_eq!(updates1.len(), updates2.len());
        // Same seed must reproduce both the indices and the actual values.
        for (u1, u2) in updates1.iter().zip(updates2.iter()) {
            assert_eq!(u1.param_index, u2.param_index);
            match (&u1.value, &u2.value) {
                (ParamValue::ScalarF32(a), ParamValue::ScalarF32(b)) => assert_eq!(a, b),
                (ParamValue::ScalarBool(a), ParamValue::ScalarBool(b)) => assert_eq!(a, b),
                (ParamValue::Vec3F32(a), ParamValue::Vec3F32(b)) => assert_eq!(a, b),
                _ => panic!(
                    "unexpected or mismatched value variants for param {}",
                    u1.param_index
                ),
            }
        }
    }

    #[test]
    fn test_get_default_value_all_types() {
        assert!(matches!(
            get_default_value(&ParamType::ScalarF32),
            ParamValue::ScalarF32(0.0)
        ));
        assert!(matches!(
            get_default_value(&ParamType::ScalarI32),
            ParamValue::ScalarI32(0)
        ));
        assert!(matches!(
            get_default_value(&ParamType::ScalarU32),
            ParamValue::ScalarU32(0)
        ));
        assert!(matches!(
            get_default_value(&ParamType::ScalarBool),
            ParamValue::ScalarBool(false)
        ));
        assert!(matches!(
            get_default_value(&ParamType::Vec2F32),
            ParamValue::Vec2F32((0.0, 0.0))
        ));
        assert!(matches!(
            get_default_value(&ParamType::Vec3F32),
            ParamValue::Vec3F32((0.0, 0.0, 0.0))
        ));
        assert!(matches!(
            get_default_value(&ParamType::Vec4F32),
            ParamValue::Vec4F32((0.0, 0.0, 0.0, 0.0))
        ));
        // Matrices reset to identity rather than all zeros.
        if let ParamValue::Mat4F32(m) = get_default_value(&ParamType::Mat4F32) {
            assert_eq!(
                m,
                vec![
                    1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
                ]
            );
        } else {
            panic!("Expected Mat4F32 for Mat4F32 parameter type");
        }
    }

    #[test]
    fn test_action_context_with_beat_info() {
        let beat = BeatInfo {
            bpm: 120.0,
            phase: 0.5,
            bar_position: 0.25,
        };

        let ctx = ActionContext {
            params: &[],
            beat_info: Some(&beat),
            is_high_frequency: false,
        };

        assert!(ctx.beat_info.is_some());
        assert_eq!(ctx.beat_info.unwrap().bpm, 120.0);
    }
}