leafwing_input_manager/input_processing/single_axis/custom.rs

1use std::any::Any;
2use std::fmt::Debug;
3use std::sync::{LazyLock, RwLock};
4
5use bevy::app::App;
6use bevy::prelude::{FromReflect, Reflect, ReflectDeserialize, ReflectSerialize, TypePath};
7use bevy::reflect::utility::{GenericTypePathCell, NonGenericTypeInfoCell};
8use bevy::reflect::{
9    erased_serde, FromType, GetTypeRegistration, OpaqueInfo, PartialReflect, ReflectFromPtr,
10    ReflectKind, ReflectMut, ReflectOwned, ReflectRef, TypeInfo, TypeRegistration, Typed,
11};
12use dyn_clone::DynClone;
13use dyn_eq::DynEq;
14use dyn_hash::DynHash;
15use serde::{Deserialize, Deserializer, Serialize, Serializer};
16use serde_flexitos::ser::require_erased_serialize_impl;
17use serde_flexitos::{serialize_trait_object, Registry};
18
19use crate::input_processing::AxisProcessor;
20use crate::typetag::{InfallibleMapRegistry, RegisterTypeTag};
21
/// A trait for creating custom processors that handle single-axis input values,
/// accepting an `f32` input and producing an `f32` output.
///
/// Implementors must satisfy the supertraits below. In practice this means the type is
/// `Debug`, `Clone`, `Eq`, `Hash`, `Reflect` and serde `Serialize`. The `Dyn*` and
/// `erased_serde` supertraits are the object-safe counterparts that allow the processor
/// to live inside a `Box<dyn CustomAxisProcessor>`.
///
/// # Examples
///
/// ```rust
/// use std::hash::{Hash, Hasher};
/// use bevy::prelude::*;
/// use bevy::math::FloatOrd;
/// use serde::{Deserialize, Serialize};
/// use leafwing_input_manager::prelude::*;
///
/// /// Doubles the input and takes the absolute value,
/// /// returning `0.0` instead whenever that result equals the stored value.
/// // If your processor has fields that don't implement `Eq` and `Hash` (such as `f32`),
/// // you must implement these traits manually, as shown below.
/// // Otherwise, you can derive `Eq` and `Hash` directly.
/// #[derive(Debug, Clone, Copy, PartialEq, Reflect, Serialize, Deserialize)]
/// pub struct DoubleAbsoluteValueThenIgnored(pub f32);
///
/// // Add this attribute to ensure proper serialization and deserialization.
/// #[serde_typetag]
/// impl CustomAxisProcessor for DoubleAbsoluteValueThenIgnored {
///     fn process(&self, input_value: f32) -> f32 {
///         // Implement the logic just like you would in a normal function.
///
///         // You can use other processors within this function.
///         let value = AxisProcessor::Sensitivity(2.0).process(input_value);
///
///         let value = value.abs();
///         if value == self.0 {
///             0.0
///         } else {
///             value
///         }
///     }
/// }
///
/// // Unfortunately, a manual implementation is required because of the float field.
/// impl Eq for DoubleAbsoluteValueThenIgnored {}
/// impl Hash for DoubleAbsoluteValueThenIgnored {
///     fn hash<H: Hasher>(&self, state: &mut H) {
///         // Wrap the float field in `FloatOrd`, which implements `Hash`.
///         FloatOrd(self.0).hash(state);
///     }
/// }
///
/// // Remember to register your processor - it will ensure everything works smoothly!
/// let mut app = App::new();
/// app.register_axis_processor::<DoubleAbsoluteValueThenIgnored>();
///
/// // Now you can use it!
/// let processor = DoubleAbsoluteValueThenIgnored(4.0);
///
/// // Rejected: the doubled absolute value equals the stored 4.0.
/// assert_eq!(processor.process(2.0), 0.0);
/// assert_eq!(processor.process(-2.0), 0.0);
///
/// // All other inputs simply yield the doubled absolute value.
/// assert_eq!(processor.process(6.0), 12.0);
/// assert_eq!(processor.process(4.0), 8.0);
/// assert_eq!(processor.process(0.0), 0.0);
/// assert_eq!(processor.process(-4.0), 8.0);
/// assert_eq!(processor.process(-6.0), 12.0);
///
/// // Two equivalent ways to create an `AxisProcessor`.
/// let axis_processor = AxisProcessor::Custom(Box::new(processor));
/// assert_eq!(axis_processor, AxisProcessor::from(processor));
/// ```
pub trait CustomAxisProcessor:
    Send + Sync + Debug + DynClone + DynEq + DynHash + Reflect + erased_serde::Serialize
{
    /// Computes the processed value for the given `input_value`.
    fn process(&self, input_value: f32) -> f32;
}
97
98impl<P: CustomAxisProcessor> From<P> for AxisProcessor {
99    fn from(value: P) -> Self {
100        Self::Custom(Box::new(value))
101    }
102}
103
// Forward `Clone`, `PartialEq`/`Eq` and `Hash` for `CustomAxisProcessor` trait objects to the
// object-safe `DynClone`, `DynEq` and `DynHash` supertraits, so boxed processors (such as the
// payload of `AxisProcessor::Custom`) can be cloned, compared and hashed.
dyn_clone::clone_trait_object!(CustomAxisProcessor);
dyn_eq::eq_trait_object!(CustomAxisProcessor);
dyn_hash::hash_trait_object!(CustomAxisProcessor);
107
108impl PartialReflect for Box<dyn CustomAxisProcessor> {
109    fn get_represented_type_info(&self) -> Option<&'static TypeInfo> {
110        Some(Self::type_info())
111    }
112
113    fn reflect_kind(&self) -> ReflectKind {
114        ReflectKind::Opaque
115    }
116
117    fn reflect_ref(&self) -> ReflectRef {
118        ReflectRef::Opaque(self)
119    }
120
121    fn reflect_mut(&mut self) -> ReflectMut {
122        ReflectMut::Opaque(self)
123    }
124
125    fn reflect_owned(self: Box<Self>) -> ReflectOwned {
126        ReflectOwned::Opaque(self)
127    }
128
129    fn clone_value(&self) -> Box<dyn PartialReflect> {
130        Box::new(self.clone())
131    }
132
133    fn try_apply(&mut self, value: &dyn PartialReflect) -> Result<(), bevy::reflect::ApplyError> {
134        if let Some(value) = value.try_downcast_ref::<Self>() {
135            *self = value.clone();
136            Ok(())
137        } else {
138            Err(bevy::reflect::ApplyError::MismatchedTypes {
139                from_type: self
140                    .reflect_type_ident()
141                    .unwrap_or_default()
142                    .to_string()
143                    .into_boxed_str(),
144                to_type: self
145                    .reflect_type_ident()
146                    .unwrap_or_default()
147                    .to_string()
148                    .into_boxed_str(),
149            })
150        }
151    }
152
153    fn into_partial_reflect(self: Box<Self>) -> Box<dyn PartialReflect> {
154        self
155    }
156
157    fn as_partial_reflect(&self) -> &dyn PartialReflect {
158        self
159    }
160
161    fn as_partial_reflect_mut(&mut self) -> &mut dyn PartialReflect {
162        self
163    }
164
165    fn try_into_reflect(self: Box<Self>) -> Result<Box<dyn Reflect>, Box<dyn PartialReflect>> {
166        Ok(self)
167    }
168
169    fn try_as_reflect(&self) -> Option<&dyn Reflect> {
170        Some(self)
171    }
172
173    fn try_as_reflect_mut(&mut self) -> Option<&mut dyn Reflect> {
174        Some(self)
175    }
176}
177
178impl Reflect for Box<dyn CustomAxisProcessor> {
179    fn into_any(self: Box<Self>) -> Box<dyn Any> {
180        self
181    }
182
183    fn as_any(&self) -> &dyn Any {
184        self
185    }
186
187    fn as_any_mut(&mut self) -> &mut dyn Any {
188        self
189    }
190
191    fn into_reflect(self: Box<Self>) -> Box<dyn Reflect> {
192        self
193    }
194
195    fn as_reflect(&self) -> &dyn Reflect {
196        self
197    }
198
199    fn as_reflect_mut(&mut self) -> &mut dyn Reflect {
200        self
201    }
202
203    fn set(&mut self, value: Box<dyn Reflect>) -> Result<(), Box<dyn Reflect>> {
204        *self = value.take()?;
205        Ok(())
206    }
207}
208
209impl Typed for Box<dyn CustomAxisProcessor> {
210    fn type_info() -> &'static TypeInfo {
211        static CELL: NonGenericTypeInfoCell = NonGenericTypeInfoCell::new();
212        CELL.get_or_set(|| TypeInfo::Opaque(OpaqueInfo::new::<Self>()))
213    }
214}
215
216impl TypePath for Box<dyn CustomAxisProcessor> {
217    fn type_path() -> &'static str {
218        static CELL: GenericTypePathCell = GenericTypePathCell::new();
219        CELL.get_or_insert::<Self, _>(|| {
220            {
221                format!(
222                    "std::boxed::Box<dyn {}::CustomAxisProcessor>",
223                    module_path!()
224                )
225            }
226        })
227    }
228
229    fn short_type_path() -> &'static str {
230        static CELL: GenericTypePathCell = GenericTypePathCell::new();
231        CELL.get_or_insert::<Self, _>(|| "Box<dyn CustomAxisProcessor>".to_string())
232    }
233
234    fn type_ident() -> Option<&'static str> {
235        Some("Box<dyn CustomAxisProcessor>")
236    }
237
238    fn crate_name() -> Option<&'static str> {
239        Some(module_path!().split(':').next().unwrap())
240    }
241
242    fn module_path() -> Option<&'static str> {
243        Some(module_path!())
244    }
245}
246
247impl GetTypeRegistration for Box<dyn CustomAxisProcessor> {
248    fn get_type_registration() -> TypeRegistration {
249        let mut registration = TypeRegistration::of::<Self>();
250        registration.insert::<ReflectDeserialize>(FromType::<Self>::from_type());
251        registration.insert::<ReflectFromPtr>(FromType::<Self>::from_type());
252        registration.insert::<ReflectSerialize>(FromType::<Self>::from_type());
253        registration
254    }
255}
256
257impl FromReflect for Box<dyn CustomAxisProcessor> {
258    fn from_reflect(reflect: &dyn PartialReflect) -> Option<Self> {
259        Some(reflect.try_downcast_ref::<Self>()?.clone())
260    }
261}
262
/// Serializes any [`CustomAxisProcessor`] trait object as a single-entry map. The key is the
/// concrete type's short type path and the value is the processor itself.
impl Serialize for dyn CustomAxisProcessor + '_ {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Check that `CustomAxisProcessor` has `erased_serde::Serialize` as a super trait,
        // preventing infinite recursion at runtime.
        // This function is never called. Type-checking its body *is* the check.
        const fn __check_erased_serialize_super_trait<T: ?Sized + CustomAxisProcessor>() {
            require_erased_serialize_impl::<T>();
        }
        // The key identifies the concrete type so that deserialization can select the
        // matching registered deserializer.
        serialize_trait_object(serializer, self.reflect_short_type_path(), self)
    }
}
276
277impl<'de> Deserialize<'de> for Box<dyn CustomAxisProcessor> {
278    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
279    where
280        D: Deserializer<'de>,
281    {
282        let registry = PROCESSOR_REGISTRY.read().unwrap();
283        registry.deserialize_trait_object(deserializer)
284    }
285}
286
287/// Registry of deserializers for [`CustomAxisProcessor`]s.
288static PROCESSOR_REGISTRY: LazyLock<RwLock<InfallibleMapRegistry<dyn CustomAxisProcessor>>> =
289    LazyLock::new(|| RwLock::new(InfallibleMapRegistry::new("CustomAxisProcessor")));
290
/// An extension trait for registering a specific [`CustomAxisProcessor`].
pub trait RegisterCustomAxisProcessorExt {
    /// Registers the specified [`CustomAxisProcessor`] so that it can be deserialized
    /// from a `Box<dyn CustomAxisProcessor>`.
    ///
    /// `T` must provide a type tag for `dyn CustomAxisProcessor` (`RegisterTypeTag`)
    /// and a reflection registration (`GetTypeRegistration`).
    fn register_axis_processor<'de, T>(&mut self) -> &mut Self
    where
        T: RegisterTypeTag<'de, dyn CustomAxisProcessor> + GetTypeRegistration;
}
298
299impl RegisterCustomAxisProcessorExt for App {
300    fn register_axis_processor<'de, T>(&mut self) -> &mut Self
301    where
302        T: RegisterTypeTag<'de, dyn CustomAxisProcessor> + GetTypeRegistration,
303    {
304        let mut registry = PROCESSOR_REGISTRY.write().unwrap();
305        T::register_typetag(&mut registry);
306        self.register_type::<T>();
307        self
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate as leafwing_input_manager;
    use leafwing_input_manager_macros::serde_typetag;
    use serde_test::{assert_tokens, Token};

    #[test]
    fn test_custom_axis_processor() {
        /// Negates every input value.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Reflect, Serialize, Deserialize)]
        struct CustomAxisInverted;

        #[serde_typetag]
        impl CustomAxisProcessor for CustomAxisInverted {
            fn process(&self, input_value: f32) -> f32 {
                -input_value
            }
        }

        // Registration makes the type tag known to the deserializer, which
        // `assert_tokens` exercises together with serialization.
        let mut app = App::new();
        app.register_axis_processor::<CustomAxisInverted>();

        let boxed: Box<dyn CustomAxisProcessor> = Box::new(CustomAxisInverted);
        let expected_tokens = [
            Token::Map { len: Some(1) },
            Token::BorrowedStr("CustomAxisInverted"),
            Token::UnitStruct {
                name: "CustomAxisInverted",
            },
            Token::MapEnd,
        ];
        assert_tokens(&boxed, &expected_tokens);

        // Boxing manually and converting via `From` must yield equal processors.
        let processor = AxisProcessor::Custom(boxed);
        assert_eq!(AxisProcessor::from(CustomAxisInverted), processor);

        // Sample inputs across [-3.0, 3.0) in steps of 0.01.
        for input in (-300..300).map(|step| step as f32 * 0.01) {
            assert_eq!(processor.process(input), -input);
            assert_eq!(CustomAxisInverted.process(input), -input);
        }
    }
}