leafwing_input_manager/input_processing/dual_axis/
custom.rs

1use std::any::Any;
2use std::fmt::Debug;
3use std::sync::{LazyLock, RwLock};
4
5use bevy::app::App;
6use bevy::prelude::{FromReflect, Reflect, ReflectDeserialize, ReflectSerialize, TypePath, Vec2};
7use bevy::reflect::utility::{GenericTypePathCell, NonGenericTypeInfoCell};
8use bevy::reflect::{
9    erased_serde, FromType, GetTypeRegistration, OpaqueInfo, PartialReflect, ReflectFromPtr,
10    ReflectKind, ReflectMut, ReflectOwned, ReflectRef, TypeInfo, TypeRegistration, Typed,
11};
12use dyn_clone::DynClone;
13use dyn_eq::DynEq;
14use dyn_hash::DynHash;
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16use serde_flexitos::ser::require_erased_serialize_impl;
17use serde_flexitos::{serialize_trait_object, Registry};
18
19use crate::input_processing::DualAxisProcessor;
20use crate::typetag::{InfallibleMapRegistry, RegisterTypeTag};
21
22/// A trait for creating custom processor that handles dual-axis input values,
23/// accepting a [`Vec2`] input and producing a [`Vec2`] output.
24///
25/// # Examples
26///
27/// ```rust
28/// use std::hash::{Hash, Hasher};
29/// use bevy::prelude::*;
30/// use bevy::math::FloatOrd;
31/// use serde::{Deserialize, Serialize};
32/// use leafwing_input_manager::prelude::*;
33///
34/// /// Doubles the input, takes its absolute value,
35/// /// and discards results that meet the specified condition on the X-axis.
36/// // If your processor includes fields not implemented Eq and Hash,
37/// // implementation is necessary as shown below.
38/// // Otherwise, you can derive Eq and Hash directly.
39/// #[derive(Debug, Clone, Copy, PartialEq, Reflect, Serialize, Deserialize)]
40/// pub struct DoubleAbsoluteValueThenRejectX(pub f32);
41///
42/// // Add this attribute for ensuring proper serialization and deserialization.
43/// #[serde_typetag]
44/// impl CustomDualAxisProcessor for DoubleAbsoluteValueThenRejectX {
45///     fn process(&self, input_value: Vec2) -> Vec2 {
46///         // Implement the logic just like you would in a normal function.
47///
48///         // You can use other processors within this function.
49///         let value = DualAxisSensitivity::all(2.0).scale(input_value);
50///
51///         let value = value.abs();
52///         let new_x = if value.x == self.0 {
53///             0.0
54///         } else {
55///             value.x
56///         };
57///         Vec2::new(new_x, value.y)
58///     }
59/// }
60///
61/// // Unfortunately, manual implementation is required due to the float field.
62/// impl Eq for DoubleAbsoluteValueThenRejectX {}
63/// impl Hash for DoubleAbsoluteValueThenRejectX {
64///     fn hash<H: Hasher>(&self, state: &mut H) {
65///         // Encapsulate the float field for hashing.
66///         FloatOrd(self.0).hash(state);
67///     }
68/// }
69///
70/// // Remember to register your processor - it will ensure everything works smoothly!
71/// let mut app = App::new();
72/// app.register_dual_axis_processor::<DoubleAbsoluteValueThenRejectX>();
73///
74/// // Now you can use it!
75/// let processor = DoubleAbsoluteValueThenRejectX(4.0);
76///
77/// // Rejected X!
78/// assert_eq!(processor.process(Vec2::splat(2.0)), Vec2::new(0.0, 4.0));
79/// assert_eq!(processor.process(Vec2::splat(-2.0)), Vec2::new(0.0, 4.0));
80///
81/// // Others are just doubled absolute value.
82/// assert_eq!(processor.process(Vec2::splat(6.0)), Vec2::splat(12.0));
83/// assert_eq!(processor.process(Vec2::splat(4.0)), Vec2::splat(8.0));
84/// assert_eq!(processor.process(Vec2::splat(0.0)), Vec2::splat(0.0));
85/// assert_eq!(processor.process(Vec2::splat(-4.0)), Vec2::splat(8.0));
86/// assert_eq!(processor.process(Vec2::splat(-6.0)), Vec2::splat(12.0));
87///
88/// // The ways to create a DualAxisProcessor.
89/// let dual_axis_processor = DualAxisProcessor::Custom(Box::new(processor));
90/// assert_eq!(dual_axis_processor, DualAxisProcessor::from(processor));
91/// ```
92pub trait CustomDualAxisProcessor:
93    Send + Sync + Debug + DynClone + DynEq + DynHash + Reflect + erased_serde::Serialize
94{
95    /// Computes the result by processing the `input_value`.
96    fn process(&self, input_value: Vec2) -> Vec2;
97}
98
99impl<P: CustomDualAxisProcessor> From<P> for DualAxisProcessor {
100    fn from(value: P) -> Self {
101        Self::Custom(Box::new(value))
102    }
103}
104
105dyn_clone::clone_trait_object!(CustomDualAxisProcessor);
106dyn_eq::eq_trait_object!(CustomDualAxisProcessor);
107dyn_hash::hash_trait_object!(CustomDualAxisProcessor);
108
109impl PartialReflect for Box<dyn CustomDualAxisProcessor> {
110    fn get_represented_type_info(&self) -> Option<&'static TypeInfo> {
111        Some(Self::type_info())
112    }
113
114    fn reflect_kind(&self) -> ReflectKind {
115        ReflectKind::Opaque
116    }
117
118    fn reflect_ref(&self) -> ReflectRef {
119        ReflectRef::Opaque(self)
120    }
121
122    fn reflect_mut(&mut self) -> ReflectMut {
123        ReflectMut::Opaque(self)
124    }
125
126    fn reflect_owned(self: Box<Self>) -> ReflectOwned {
127        ReflectOwned::Opaque(self)
128    }
129
130    fn clone_value(&self) -> Box<dyn PartialReflect> {
131        Box::new(self.clone())
132    }
133
134    fn try_apply(&mut self, value: &dyn PartialReflect) -> Result<(), bevy::reflect::ApplyError> {
135        if let Some(value) = value.try_downcast_ref::<Self>() {
136            *self = value.clone();
137            Ok(())
138        } else {
139            Err(bevy::reflect::ApplyError::MismatchedTypes {
140                from_type: self
141                    .reflect_type_ident()
142                    .unwrap_or_default()
143                    .to_string()
144                    .into_boxed_str(),
145                to_type: self
146                    .reflect_type_ident()
147                    .unwrap_or_default()
148                    .to_string()
149                    .into_boxed_str(),
150            })
151        }
152    }
153
154    fn into_partial_reflect(self: Box<Self>) -> Box<dyn PartialReflect> {
155        self
156    }
157
158    fn as_partial_reflect(&self) -> &dyn PartialReflect {
159        self
160    }
161
162    fn as_partial_reflect_mut(&mut self) -> &mut dyn PartialReflect {
163        self
164    }
165
166    fn try_into_reflect(self: Box<Self>) -> Result<Box<dyn Reflect>, Box<dyn PartialReflect>> {
167        Ok(self)
168    }
169
170    fn try_as_reflect(&self) -> Option<&dyn Reflect> {
171        Some(self)
172    }
173
174    fn try_as_reflect_mut(&mut self) -> Option<&mut dyn Reflect> {
175        Some(self)
176    }
177}
178
179impl Reflect for Box<dyn CustomDualAxisProcessor> {
180    fn into_any(self: Box<Self>) -> Box<dyn Any> {
181        self
182    }
183
184    fn as_any(&self) -> &dyn Any {
185        self
186    }
187
188    fn as_any_mut(&mut self) -> &mut dyn Any {
189        self
190    }
191
192    fn into_reflect(self: Box<Self>) -> Box<dyn Reflect> {
193        self
194    }
195
196    fn as_reflect(&self) -> &dyn Reflect {
197        self
198    }
199
200    fn as_reflect_mut(&mut self) -> &mut dyn Reflect {
201        self
202    }
203
204    fn set(&mut self, value: Box<dyn Reflect>) -> Result<(), Box<dyn Reflect>> {
205        *self = value.take()?;
206        Ok(())
207    }
208}
209
210impl Typed for Box<dyn CustomDualAxisProcessor> {
211    fn type_info() -> &'static TypeInfo {
212        static CELL: NonGenericTypeInfoCell = NonGenericTypeInfoCell::new();
213        CELL.get_or_set(|| TypeInfo::Opaque(OpaqueInfo::new::<Self>()))
214    }
215}
216
217impl TypePath for Box<dyn CustomDualAxisProcessor> {
218    fn type_path() -> &'static str {
219        static CELL: GenericTypePathCell = GenericTypePathCell::new();
220        CELL.get_or_insert::<Self, _>(|| {
221            {
222                format!(
223                    "std::boxed::Box<dyn {}::CustomDualAxisProcessor>",
224                    module_path!()
225                )
226            }
227        })
228    }
229
230    fn short_type_path() -> &'static str {
231        static CELL: GenericTypePathCell = GenericTypePathCell::new();
232        CELL.get_or_insert::<Self, _>(|| "Box<dyn CustomDualAxisProcessor>".to_string())
233    }
234
235    fn type_ident() -> Option<&'static str> {
236        Some("Box<dyn CustomDualAxisProcessor>")
237    }
238
239    fn crate_name() -> Option<&'static str> {
240        Some(module_path!().split(':').next().unwrap())
241    }
242
243    fn module_path() -> Option<&'static str> {
244        Some(module_path!())
245    }
246}
247
248impl GetTypeRegistration for Box<dyn CustomDualAxisProcessor> {
249    fn get_type_registration() -> TypeRegistration {
250        let mut registration = TypeRegistration::of::<Self>();
251        registration.insert::<ReflectDeserialize>(FromType::<Self>::from_type());
252        registration.insert::<ReflectFromPtr>(FromType::<Self>::from_type());
253        registration.insert::<ReflectSerialize>(FromType::<Self>::from_type());
254        registration
255    }
256}
257
258impl FromReflect for Box<dyn CustomDualAxisProcessor> {
259    fn from_reflect(reflect: &dyn PartialReflect) -> Option<Self> {
260        Some(reflect.try_downcast_ref::<Self>()?.clone())
261    }
262}
263
264impl Serialize for dyn CustomDualAxisProcessor + '_ {
265    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
266    where
267        S: Serializer,
268    {
269        // Check that `CustomDualAxisProcessor` has `erased_serde::Serialize` as a super trait,
270        // preventing infinite recursion at runtime.
271        const fn __check_erased_serialize_super_trait<T: ?Sized + CustomDualAxisProcessor>() {
272            require_erased_serialize_impl::<T>();
273        }
274        serialize_trait_object(serializer, self.reflect_short_type_path(), self)
275    }
276}
277
278impl<'de> Deserialize<'de> for Box<dyn CustomDualAxisProcessor> {
279    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
280    where
281        D: Deserializer<'de>,
282    {
283        let registry = PROCESSOR_REGISTRY.read().unwrap();
284        registry.deserialize_trait_object(deserializer)
285    }
286}
287
288/// Registry of deserializers for [`CustomDualAxisProcessor`]s.
289static PROCESSOR_REGISTRY: LazyLock<RwLock<InfallibleMapRegistry<dyn CustomDualAxisProcessor>>> =
290    LazyLock::new(|| RwLock::new(InfallibleMapRegistry::new("CustomDualAxisProcessor")));
291
292/// A trait for registering a specific [`CustomDualAxisProcessor`].
293pub trait RegisterDualAxisProcessorExt {
294    /// Registers the specified [`CustomDualAxisProcessor`].
295    fn register_dual_axis_processor<'de, T>(&mut self) -> &mut Self
296    where
297        T: RegisterTypeTag<'de, dyn CustomDualAxisProcessor> + GetTypeRegistration;
298}
299
300impl RegisterDualAxisProcessorExt for App {
301    fn register_dual_axis_processor<'de, T>(&mut self) -> &mut Self
302    where
303        T: RegisterTypeTag<'de, dyn CustomDualAxisProcessor> + GetTypeRegistration,
304    {
305        let mut registry = PROCESSOR_REGISTRY.write().unwrap();
306        T::register_typetag(&mut registry);
307        self.register_type::<T>();
308        self
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate as leafwing_input_manager;
    use leafwing_input_manager_macros::serde_typetag;
    use serde_test::{assert_tokens, Token};

    /// Checks serialization tokens, conversion into [`DualAxisProcessor`],
    /// and processing results for a custom processor that negates its input.
    #[test]
    fn test_custom_dual_axis_processor() {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect, Serialize, Deserialize)]
        struct CustomDualAxisInverted;

        #[serde_typetag]
        impl CustomDualAxisProcessor for CustomDualAxisInverted {
            fn process(&self, input_value: Vec2) -> Vec2 {
                -input_value
            }
        }

        let mut app = App::new();
        app.register_dual_axis_processor::<CustomDualAxisInverted>();

        // The trait object serializes as a single-entry map keyed by the type name.
        let expected_tokens = [
            Token::Map { len: Some(1) },
            Token::BorrowedStr("CustomDualAxisInverted"),
            Token::UnitStruct {
                name: "CustomDualAxisInverted",
            },
            Token::MapEnd,
        ];
        let custom: Box<dyn CustomDualAxisProcessor> = Box::new(CustomDualAxisInverted);
        assert_tokens(&custom, &expected_tokens);

        let processor = DualAxisProcessor::Custom(custom);
        assert_eq!(DualAxisProcessor::from(CustomDualAxisInverted), processor);

        // Sample each axis from -3.00 up to 2.99 in steps of 0.01.
        let axis_samples = || (-300_i32..300).map(|step| step as f32 * 0.01);
        for x in axis_samples() {
            for y in axis_samples() {
                let value = Vec2::new(x, y);
                let negated = -value;

                assert_eq!(processor.process(value), negated);
                assert_eq!(CustomDualAxisInverted.process(value), negated);
            }
        }
    }
}