Skip to main content

deke_types/
validator.rs

1use std::{fmt::Debug, sync::Arc};
2
3use bitvec::vec::BitVec;
4use wide::{f32x8, f64x4};
5
6use crate::{DekeError, DekeResult, KinScalar, SRobotQ, SRobotQLike};
7
/// Marker for the per-call context value a [`Validator`] receives alongside
/// each configuration. Validators that need no context use `()`.
pub trait ValidatorContext
where
    Self: Sized,
{
}

/// Marker for context types that may sit at a single leaf of a nested-pair
/// context. Only `Leaf` types appear as generic leaves in the
/// [`FromFlattened`] impls generated in this module.
#[doc(hidden)]
pub trait Leaf: ValidatorContext {}

// `()` is the "no context" context. It is not a `Leaf`; the generated
// `FromFlattened` impls use it to fill leaf positions omitted from the flat
// tuple.
impl ValidatorContext for () {}
14
15type ExtraCheck<const N: usize, F> = Box<dyn Fn(&SRobotQ<N, F>) -> bool + Send + Sync>;
16
/// Marks each listed type as both a `ValidatorContext` and a `Leaf`, i.e. an
/// atomic context that may occupy a leaf position in a nested-pair context.
///
/// Any type syntax without generic parameters is accepted (plain names or
/// paths), separated by commas, with an optional trailing comma:
///
/// ```ignore
/// validator_context_type_impl!(MyCtx, other_mod::OtherCtx,);
/// ```
#[macro_export]
macro_rules! validator_context_type_impl {
    ($($ty:ty),* $(,)?) => {
        $(
            impl $crate::ValidatorContext for $ty {}
            impl $crate::Leaf for $ty {}
        )*
    };
}
26
27impl<A: ValidatorContext, B: ValidatorContext> ValidatorContext for (A, B) {}
28
29pub trait FromFlattened<Flattened>: ValidatorContext {
30    fn nest(flattened: Flattened) -> Self;
31}
32
33macro_rules! validator_context_tuple_impl {
34    ($tup:tt) => {
35        validator_context_tuple_impl!(@rewrite $tup [emit_impl]);
36    };
37
38    (@rewrite ($l:tt, $r:tt) [$($cb:tt)*]) => {
39        validator_context_tuple_impl!(@rewrite $l [pair_right $r [$($cb)*]]);
40    };
41    (@rewrite $ident:ident [$($cb:tt)*]) => {
42        validator_context_tuple_impl!(@invoke [$($cb)*] [$ident] $ident);
43        validator_context_tuple_impl!(@invoke [$($cb)*] [] ());
44    };
45
46    (@invoke [pair_right $r:tt [$($cb:tt)*]] [$($kept_l:ident)*] $rew_l:tt) => {
47        validator_context_tuple_impl!(@rewrite $r [pair_combine [$($kept_l)*] $rew_l [$($cb)*]]);
48    };
49    (@invoke [pair_combine [$($kept_l:ident)*] $rew_l:tt [$($cb:tt)*]] [$($kept_r:ident)*] $rew_r:tt) => {
50        validator_context_tuple_impl!(@invoke [$($cb)*] [$($kept_l)* $($kept_r)*] ($rew_l, $rew_r));
51    };
52    (@invoke [emit_impl] [$($kept:ident)*] $shape:tt) => {
53        #[allow(non_snake_case)]
54        impl<$($kept: Leaf),*> FromFlattened<($($kept,)*)> for $shape {
55            #[inline]
56            fn nest(flattened: ($($kept,)*)) -> Self {
57                let ($($kept,)*) = flattened;
58                $shape
59            }
60        }
61    };
62}
63
64validator_context_tuple_impl! { (A, (B, C)) }
65validator_context_tuple_impl! { ((A, B), C) }
66
67validator_context_tuple_impl! { (A, (B, (C, D))) }
68validator_context_tuple_impl! { (A, ((B, C), D)) }
69validator_context_tuple_impl! { ((A, B), (C, D)) }
70validator_context_tuple_impl! { ((A, (B, C)), D) }
71validator_context_tuple_impl! { (((A, B), C), D) }
72
73validator_context_tuple_impl! { (A, (B, (C, (D, E)))) }
74validator_context_tuple_impl! { (A, (B, ((C, D), E))) }
75validator_context_tuple_impl! { (A, ((B, C), (D, E))) }
76validator_context_tuple_impl! { (A, ((B, (C, D)), E)) }
77validator_context_tuple_impl! { (A, (((B, C), D), E)) }
78validator_context_tuple_impl! { ((A, B), (C, (D, E))) }
79validator_context_tuple_impl! { ((A, B), ((C, D), E)) }
80validator_context_tuple_impl! { ((A, (B, C)), (D, E)) }
81validator_context_tuple_impl! { (((A, B), C), (D, E)) }
82validator_context_tuple_impl! { ((A, (B, (C, D))), E) }
83validator_context_tuple_impl! { ((A, ((B, C), D)), E) }
84validator_context_tuple_impl! { (((A, B), (C, D)), E) }
85validator_context_tuple_impl! { (((A, (B, C)), D), E) }
86validator_context_tuple_impl! { ((((A, B), C), D), E) }
87
88validator_context_tuple_impl! { ((A, B), ((C, D), (E, F))) }
89validator_context_tuple_impl! { (((A, B), (C, D)), (E, F)) }
90validator_context_tuple_impl! { ((A, (B, C)), ((D, E), F)) }
91validator_context_tuple_impl! { (((A, B), C), ((D, E), F)) }
92
#[doc(hidden)]
mod sealed {
    /// Supertrait of [`super::ValidatorRet`]. It lives in a private module,
    /// so only the three impls below exist.
    pub trait Sealed {}

    impl Sealed for () {}
    impl Sealed for f32 {}
    impl Sealed for f64 {}
}

/// Value a successful [`Validator`] check produces: `()` for plain pass/fail
/// validators, or an `f32`/`f64` margin.
pub trait ValidatorRet: Sized + sealed::Sealed + Copy {
    /// The result as an `f64` margin; the unit type always reports `+∞`.
    fn as_f64(&self) -> f64;

    /// Maximal-margin value used when a validator is disabled (no constraint
    /// applied). It is `+∞` for the float types, matching `as_f64()` of `()`.
    fn passing() -> Self;
}

impl ValidatorRet for () {
    #[inline]
    fn as_f64(&self) -> f64 {
        f64::INFINITY
    }

    #[inline]
    fn passing() -> Self {}
}

impl ValidatorRet for f32 {
    #[inline]
    fn as_f64(&self) -> f64 {
        // Widening f32 -> f64 is exact.
        f64::from(*self)
    }

    #[inline]
    fn passing() -> Self {
        Self::INFINITY
    }
}

impl ValidatorRet for f64 {
    #[inline]
    fn as_f64(&self) -> f64 {
        *self
    }

    #[inline]
    fn passing() -> Self {
        Self::INFINITY
    }
}
139
140/// SIMD joint-limit batch check, sealed to the scalar types [`KinScalar`]
141/// permits (`f32` → `f32x8`, `f64` → `f64x4`). Bit `i` of `out` is set when
142/// `qs[i]` lies outside `[lower, upper]` on any axis. It is a supertrait of
143/// [`KinScalar`] so the generic [`JointValidator`] batch path vectorises
144/// without narrowing the impl with an extra bound.
145#[doc(hidden)]
146pub trait BatchLimits: num_traits::Float {
147    fn fill_oob<const N: usize>(
148        qs: &[SRobotQ<N, Self>],
149        lower: &[Self; N],
150        upper: &[Self; N],
151        out: &mut BitVec,
152    );
153}
154
155macro_rules! impl_batch_limits {
156    ($scalar:ty, $simd:ty, $lanes:literal) => {
157        impl BatchLimits for $scalar {
158            fn fill_oob<const N: usize>(
159                qs: &[SRobotQ<N, $scalar>],
160                lower: &[$scalar; N],
161                upper: &[$scalar; N],
162                out: &mut BitVec,
163            ) {
164                let n = qs.len();
165                out.clear();
166                out.resize(n, false);
167                let mut i = 0usize;
168                while i + $lanes <= n {
169                    let mut fail = <$simd>::new([0.0; $lanes]);
170                    let mut j = 0usize;
171                    while j < N {
172                        let mut col = [0.0; $lanes];
173                        let mut l = 0usize;
174                        while l < $lanes {
175                            col[l] = qs[i + l].0[j];
176                            l += 1;
177                        }
178                        let cv = <$simd>::new(col);
179                        fail = fail | cv.simd_lt(lower[j]) | cv.simd_gt(upper[j]);
180                        j += 1;
181                    }
182                    let arr = fail.to_array();
183                    let mut l = 0usize;
184                    while l < $lanes {
185                        if arr[l].to_bits() != 0 {
186                            out.set(i + l, true);
187                        }
188                        l += 1;
189                    }
190                    i += $lanes;
191                }
192                while i < n {
193                    let q = &qs[i].0;
194                    let mut j = 0usize;
195                    while j < N {
196                        if q[j] < lower[j] || q[j] > upper[j] {
197                            out.set(i, true);
198                            break;
199                        }
200                        j += 1;
201                    }
202                    i += 1;
203                }
204            }
205        }
206    };
207}
208impl_batch_limits!(f32, f32x8, 8);
209impl_batch_limits!(f64, f64x4, 4);
210
211pub trait Validator<const N: usize, R: ValidatorRet = (), F: KinScalar = f32>:
212    Sized + Clone + Debug + Send + Sync + 'static
213{
214    type Context<'ctx>: ValidatorContext;
215    const VALIDATE_MOTION_IS_CONTINUOUS: bool = false;
216
217    fn validate<'ctx, E: Into<DekeError>, A: SRobotQLike<N, E, F>>(
218        &self,
219        q: A,
220        ctx: &Self::Context<'ctx>,
221    ) -> DekeResult<R>;
222    fn validate_motion<'ctx>(
223        &self,
224        qs: &[SRobotQ<N, F>],
225        ctx: &Self::Context<'ctx>,
226    ) -> DekeResult<R>;
227
228    /// Validate a batch of configurations at once, returning a bitvec whose
229    /// `i`-th bit is set iff `qs[i]` is **invalid** (rejected). The default
230    /// runs [`Validator::validate`] per config; implementors with a batched
231    /// fast path (SIMD, GPU) override it.
232    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, F>], ctx: &Self::Context<'ctx>) -> BitVec {
233        qs.iter().map(|q| self.validate(*q, ctx).is_err()).collect()
234    }
235}
236
/// Logical AND of two validators: both must accept.
#[derive(Clone, Debug)]
pub struct ValidatorAnd<A, B>(pub A, pub B);

/// Logical OR of two validators: at least one must accept.
#[derive(Clone, Debug)]
pub struct ValidatorOr<A, B>(pub A, pub B);

/// Logical NOT of a pass/fail validator: accepts what the inner one rejects.
#[derive(Clone, Debug)]
pub struct ValidatorNot<A>(pub A);
245
246impl<A, B> ValidatorAnd<A, B> {
247    /// Construct an AND combinator after a compile-time assertion that
248    /// `A` and `B` share the `Validator<N, R, F>` signature passed via
249    /// turbofish or inferred at the call site.
250    ///
251    /// Direct tuple-struct construction (`ValidatorAnd(a, b)`) skips this
252    /// check and is what the [`combine_validators!`] macro emits — both
253    /// forms produce the same value, and the trait impl below only fires
254    /// for `(N, R, F)` triples both members support, so callers that
255    /// dispatch through the trait are safe either way.
256    ///
257    /// [`combine_validators!`]: deke-cricket
258    pub fn new<const N: usize, R: ValidatorRet, F: KinScalar>(a: A, b: B) -> Self
259    where
260        A: Validator<N, R, F>,
261        B: Validator<N, R, F>,
262    {
263        Self(a, b)
264    }
265}
266
267impl<A, B> ValidatorOr<A, B> {
268    /// See [`ValidatorAnd::new`].
269    pub fn new<const N: usize, R: ValidatorRet, F: KinScalar>(a: A, b: B) -> Self
270    where
271        A: Validator<N, R, F>,
272        B: Validator<N, R, F>,
273    {
274        Self(a, b)
275    }
276}
277
278impl<A> ValidatorNot<A> {
279    /// Construct a NOT combinator after a compile-time assertion that `A`
280    /// implements `Validator<N, (), F>`. `Not` is restricted to the unit
281    /// return type because inverting a scalar score isn't well-defined.
282    pub fn new<const N: usize, F: KinScalar>(a: A) -> Self
283    where
284        A: Validator<N, (), F>,
285    {
286        Self(a)
287    }
288}
289
290/// Blanket impls below cover every `(N, R, F)` triple that **both** member
291/// validators implement: a single generic `impl` over `R` and `F` (and `N`
292/// since validators are const-generic over DOF) means monomorphization
293/// fires the impl for every shared signature without manual enumeration.
294impl<const N: usize, F: KinScalar, R: ValidatorRet, A, B> Validator<N, R, F> for ValidatorAnd<A, B>
295where
296    A: Validator<N, R, F>,
297    B: Validator<N, R, F>,
298{
299    type Context<'ctx> = (A::Context<'ctx>, B::Context<'ctx>);
300
301    #[inline]
302    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, F>>(
303        &self,
304        q: Q,
305        ctx: &Self::Context<'ctx>,
306    ) -> DekeResult<R> {
307        let q = q.to_srobotq().map_err(Into::into)?;
308        self.0.validate(q, &ctx.0)?;
309        self.1.validate(q, &ctx.1)
310    }
311
312    #[inline]
313    fn validate_motion<'ctx>(
314        &self,
315        qs: &[SRobotQ<N, F>],
316        ctx: &Self::Context<'ctx>,
317    ) -> DekeResult<R> {
318        self.0.validate_motion(qs, &ctx.0)?;
319        self.1.validate_motion(qs, &ctx.1)
320    }
321
322    #[inline]
323    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, F>], ctx: &Self::Context<'ctx>) -> BitVec {
324        let a = self.0.validate_batched(qs, &ctx.0);
325        let b = self.1.validate_batched(qs, &ctx.1);
326        a.iter().zip(b.iter()).map(|(x, y)| *x | *y).collect()
327    }
328}
329
330impl<const N: usize, F: KinScalar, R: ValidatorRet, A, B> Validator<N, R, F> for ValidatorOr<A, B>
331where
332    A: Validator<N, R, F>,
333    B: Validator<N, R, F>,
334{
335    type Context<'ctx> = (A::Context<'ctx>, B::Context<'ctx>);
336
337    #[inline]
338    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, F>>(
339        &self,
340        q: Q,
341        ctx: &Self::Context<'ctx>,
342    ) -> DekeResult<R> {
343        let q = q.to_srobotq().map_err(Into::into)?;
344        match self.0.validate(q, &ctx.0) {
345            Ok(r) => Ok(r),
346            Err(_) => self.1.validate(q, &ctx.1),
347        }
348    }
349
350    #[inline]
351    fn validate_motion<'ctx>(
352        &self,
353        qs: &[SRobotQ<N, F>],
354        ctx: &Self::Context<'ctx>,
355    ) -> DekeResult<R> {
356        match self.0.validate_motion(qs, &ctx.0) {
357            Ok(r) => Ok(r),
358            Err(_) => self.1.validate_motion(qs, &ctx.1),
359        }
360    }
361
362    #[inline]
363    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, F>], ctx: &Self::Context<'ctx>) -> BitVec {
364        let a = self.0.validate_batched(qs, &ctx.0);
365        let b = self.1.validate_batched(qs, &ctx.1);
366        a.iter().zip(b.iter()).map(|(x, y)| *x & *y).collect()
367    }
368}
369
370/// `Not` is only meaningful for `R = ()` (a pass/fail validator). For
371/// scalar-returning validators the inversion of the return value isn't
372/// well-defined, so the impl is restricted to the unit case.
373impl<const N: usize, F: KinScalar, A> Validator<N, (), F> for ValidatorNot<A>
374where
375    A: Validator<N, (), F>,
376{
377    type Context<'ctx> = A::Context<'ctx>;
378
379    #[inline]
380    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, F>>(
381        &self,
382        q: Q,
383        ctx: &Self::Context<'ctx>,
384    ) -> DekeResult<()> {
385        let q = q.to_srobotq().map_err(Into::into)?;
386        match self.0.validate(q, ctx) {
387            Ok(()) => Err(DekeError::SuperError),
388            Err(_) => Ok(()),
389        }
390    }
391
392    #[inline]
393    fn validate_motion<'ctx>(
394        &self,
395        qs: &[SRobotQ<N, F>],
396        ctx: &Self::Context<'ctx>,
397    ) -> DekeResult<()> {
398        match self.0.validate_motion(qs, ctx) {
399            Ok(()) => Err(DekeError::SuperError),
400            Err(_) => Ok(()),
401        }
402    }
403
404    #[inline]
405    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, F>], ctx: &Self::Context<'ctx>) -> BitVec {
406        self.0
407            .validate_batched(qs, ctx)
408            .iter()
409            .map(|x| !*x)
410            .collect()
411    }
412}
413
/// A validator that can be switched off at runtime.
#[derive(Clone, Debug)]
pub enum MaybeValidator<V> {
    /// Delegate every check to the wrapped validator.
    Active(V),
    /// Accept everything.
    Disabled,
}
419
420impl<const N: usize, F: KinScalar, R: ValidatorRet, V> Validator<N, R, F> for MaybeValidator<V>
421where
422    V: Validator<N, R, F>,
423{
424    type Context<'ctx> = V::Context<'ctx>;
425
426    #[inline]
427    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, F>>(
428        &self,
429        q: Q,
430        ctx: &Self::Context<'ctx>,
431    ) -> DekeResult<R> {
432        match self {
433            MaybeValidator::Active(v) => v.validate(q, ctx),
434            MaybeValidator::Disabled => Ok(R::passing()),
435        }
436    }
437
438    #[inline]
439    fn validate_motion<'ctx>(
440        &self,
441        qs: &[SRobotQ<N, F>],
442        ctx: &Self::Context<'ctx>,
443    ) -> DekeResult<R> {
444        match self {
445            MaybeValidator::Active(v) => v.validate_motion(qs, ctx),
446            MaybeValidator::Disabled => Ok(R::passing()),
447        }
448    }
449
450    #[inline]
451    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, F>], ctx: &Self::Context<'ctx>) -> BitVec {
452        match self {
453            MaybeValidator::Active(v) => v.validate_batched(qs, ctx),
454            MaybeValidator::Disabled => std::iter::repeat_n(false, qs.len()).collect(),
455        }
456    }
457}
458
459#[derive(Clone)]
460pub struct JointValidator<const N: usize, F: KinScalar = f32> {
461    lower: SRobotQ<N, F>,
462    upper: SRobotQ<N, F>,
463    extras: Option<Arc<[ExtraCheck<N, F>]>>,
464}
465
466impl<const N: usize, F: KinScalar> Debug for JointValidator<N, F> {
467    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468        f.debug_struct("JointValidator")
469            .field("lower", &self.lower)
470            .field("upper", &self.upper)
471            .field(
472                "extras",
473                &format!(
474                    "[{} extra checks]",
475                    self.extras.as_ref().map(|e| e.len()).unwrap_or(0)
476                ),
477            )
478            .finish()
479    }
480}
481
482impl<const N: usize, F: KinScalar> JointValidator<N, F> {
483    pub fn new(lower: SRobotQ<N, F>, upper: SRobotQ<N, F>) -> Self {
484        Self {
485            lower,
486            upper,
487            extras: None,
488        }
489    }
490
491    pub fn new_with_extras(
492        lower: SRobotQ<N, F>,
493        upper: SRobotQ<N, F>,
494        extras: Vec<ExtraCheck<N, F>>,
495    ) -> Self {
496        Self {
497            lower,
498            upper,
499            extras: Some(extras.into()),
500        }
501    }
502}
503
504impl<const N: usize, F: KinScalar> Validator<N, (), F> for JointValidator<N, F> {
505    type Context<'ctx> = ();
506
507    #[inline]
508    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, F>>(
509        &self,
510        q: Q,
511        _ctx: &Self::Context<'ctx>,
512    ) -> DekeResult<()> {
513        let q = q.to_srobotq().map_err(Into::into)?;
514        if q.any_lt(&self.lower) || q.any_gt(&self.upper) {
515            return Err(DekeError::ExceedJointLimits);
516        }
517        if let Some(extras) = &self.extras {
518            for check in extras.iter() {
519                if !check(&q) {
520                    return Err(DekeError::ExceedJointLimits);
521                }
522            }
523        }
524        Ok(())
525    }
526
527    #[inline]
528    fn validate_motion<'ctx>(
529        &self,
530        qs: &[SRobotQ<N, F>],
531        _ctx: &Self::Context<'ctx>,
532    ) -> DekeResult<()> {
533        for q in qs {
534            self.validate(*q, _ctx)?;
535        }
536        Ok(())
537    }
538
539    #[inline]
540    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, F>], _ctx: &Self::Context<'ctx>) -> BitVec {
541        let mut out = BitVec::with_capacity(qs.len());
542        F::fill_oob::<N>(qs, &self.lower.0, &self.upper.0, &mut out);
543        if let Some(extras) = &self.extras {
544            for (i, q) in qs.iter().enumerate() {
545                if out[i] {
546                    continue;
547                }
548                if extras.iter().any(|check| !check(q)) {
549                    out.set(i, true);
550                }
551            }
552        }
553        out
554    }
555}
556
557/// Cross-precision entry point: f32-storage `JointValidator` accepting f64
558/// inputs. The input is narrowed to f32 at the boundary so the same limits
559/// govern both precisions; comparison is done in storage precision.
560impl<const N: usize> Validator<N, (), f64> for JointValidator<N, f32> {
561    type Context<'ctx> = ();
562
563    #[inline]
564    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, f64>>(
565        &self,
566        q: Q,
567        ctx: &Self::Context<'ctx>,
568    ) -> DekeResult<()> {
569        let q64 = q.to_srobotq().map_err(Into::into)?;
570        let q32: SRobotQ<N, f32> = q64.into();
571        <Self as Validator<N, (), f32>>::validate(self, q32, ctx)
572    }
573
574    #[inline]
575    fn validate_motion<'ctx>(
576        &self,
577        qs: &[SRobotQ<N, f64>],
578        ctx: &Self::Context<'ctx>,
579    ) -> DekeResult<()> {
580        for q in qs {
581            let q32: SRobotQ<N, f32> = (*q).into();
582            <Self as Validator<N, (), f32>>::validate(self, q32, ctx)?;
583        }
584        Ok(())
585    }
586
587    #[inline]
588    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, f64>], ctx: &Self::Context<'ctx>) -> BitVec {
589        let q32: Vec<SRobotQ<N, f32>> = qs.iter().map(|q| (*q).into()).collect();
590        <Self as Validator<N, (), f32>>::validate_batched(self, &q32, ctx)
591    }
592}
593
594/// Cross-precision entry point: f64-storage `JointValidator` accepting f32
595/// inputs. The f32 input is widened to f64 (lossless) before comparison.
596impl<const N: usize> Validator<N, (), f32> for JointValidator<N, f64> {
597    type Context<'ctx> = ();
598
599    #[inline]
600    fn validate<'ctx, E: Into<DekeError>, Q: SRobotQLike<N, E, f32>>(
601        &self,
602        q: Q,
603        ctx: &Self::Context<'ctx>,
604    ) -> DekeResult<()> {
605        let q32 = q.to_srobotq().map_err(Into::into)?;
606        let q64: SRobotQ<N, f64> = q32.into();
607        <Self as Validator<N, (), f64>>::validate(self, q64, ctx)
608    }
609
610    #[inline]
611    fn validate_motion<'ctx>(
612        &self,
613        qs: &[SRobotQ<N, f32>],
614        ctx: &Self::Context<'ctx>,
615    ) -> DekeResult<()> {
616        for q in qs {
617            let q64: SRobotQ<N, f64> = (*q).into();
618            <Self as Validator<N, (), f64>>::validate(self, q64, ctx)?;
619        }
620        Ok(())
621    }
622
623    #[inline]
624    fn validate_batched<'ctx>(&self, qs: &[SRobotQ<N, f32>], ctx: &Self::Context<'ctx>) -> BitVec {
625        let q64: Vec<SRobotQ<N, f64>> = qs.iter().map(|q| (*q).into()).collect();
626        <Self as Validator<N, (), f64>>::validate_batched(self, &q64, ctx)
627    }
628}