1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use std::ops::*;
use std::mem::size_of;

/// Transparent newtype around a value of type `T`.
///
/// The wrapped field is private. [`Ct::new`] and [`Ct::into_inner`] are therefore the only
/// way for code outside this module to build a `Ct` or get the value back out.
#[derive(Clone, Copy, Debug)]
pub struct Ct<T>(T);

impl<T> Ct<T> {
    /// Wraps `value` in a `Ct`.
    pub const fn new(value: T) -> Self {
        Ct(value)
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T> Ct<T>
where
    T: Neg<Output = T> + BitAnd<Output = T> + BitXor<Output = T> + Copy,
{
    /// Chooses between `a` and `b` without branching on the wrapped value.
    ///
    /// The wrapped value is negated into a mask. The mask is ANDed with `a ^ b`, and the
    /// result is XORed back with `a`. For a two's-complement signed integer wrapping `0`,
    /// the mask is all zeros and `a` is returned. Wrapping `1`, the mask is all ones and
    /// `b` is returned. Any other wrapped value yields a bitwise blend of the two, so
    /// callers are expected to wrap only `0` or `1`.
    pub fn select(self, a: T, b: T) -> Self {
        (-self & (a ^ b)) ^ a
    }
}

/// Generates comparison and operator impls for [`Ct`].
///
/// Three forms are accepted:
///
/// * `implement!(Eq for u8, u16, ...)`: for each listed integer type, implements
///   `PartialEq`, `Eq` and `PartialEq<$t>` on `Ct<$t>` via the helpers `const_ne` / `const_eq`.
/// * `implement!(binary Op for Ct<T> with method)`: implements `Op` for both
///   `Ct<T> op Ct<T>` and `Ct<T> op T`, forwarding to `T`'s impl and re-wrapping the result.
/// * `implement!(unary Op for Ct<T> with method)`: implements `Op` for `Ct<T>`,
///   forwarding to `T`'s impl and re-wrapping the result.
macro_rules! implement {
    (Eq for $($t:ty),*) => {
        $(
            impl Ct<$t> {
                /// Returns `1` when `a != b` and `0` when `a == b`, with no conditional
                /// branches in the source.
                ///
                /// With `d = a - b` (wrapping), `d | -d` has its most significant bit set
                /// exactly when `d != 0`. That bit is shifted down to bit 0 and then masked
                /// with `1`. On signed types `>>` is an arithmetic shift, so without the mask
                /// an inequality would produce `-1` (all ones) rather than `1`.
                fn const_ne(a: $t, b: $t) -> $t {
                    let either_diff = a.wrapping_sub(b) | b.wrapping_sub(a);
                    (either_diff >> (size_of::<$t>() * 8 - 1)) & 1
                }

                /// Returns `1` when `a == b` and `0` otherwise; the complement of `const_ne`.
                fn const_eq(a: $t, b: $t) -> $t {
                    1 ^ Self::const_ne(a, b)
                }
            }

            impl PartialEq for Ct<$t> {
                fn eq(&self, other: &Self) -> bool {
                    Self::const_eq(self.0, other.0) == 1
                }
                fn ne(&self, other: &Self) -> bool {
                    Self::const_ne(self.0, other.0) == 1
                }
            }

            impl Eq for Ct<$t> {}

            /// Compares a wrapped value directly against a bare value of the same type.
            impl PartialEq<$t> for Ct<$t> {
                fn eq(&self, other: &$t) -> bool {
                    Self::const_eq(self.0, *other) == 1
                }
                fn ne(&self, other: &$t) -> bool {
                    Self::const_ne(self.0, *other) == 1
                }
            }
        )*
    };
    // Binary operator: both `Ct<T> op Ct<T>` and `Ct<T> op T` forward to `T: Op`.
    (binary $op:ident for Ct<$t:ident> with $fun:ident) => {
        impl<$t> $op for Ct<$t> where $t: $op {
            type Output = Ct<$t::Output>;

            fn $fun(self, other: Self) -> Self::Output { Ct($t::$fun(self.0, other.0)) }
        }
        impl<$t> $op<$t> for Ct<$t> where $t: $op {
            type Output = Ct<$t::Output>;

            fn $fun(self, other: $t) -> Self::Output { Ct($t::$fun(self.0, other)) }
        }
    };
    // Unary operator: `op Ct<T>` forwards to `T: Op`.
    (unary $op:ident for Ct<$t:ident> with $fun:ident) => {
        impl<$t> $op for Ct<$t> where $t: $op {
            type Output = Ct<$t::Output>;

            fn $fun(self) -> Self::Output { Ct($t::$fun(self.0)) }
        }
    };
}

// Branch-free equality for every primitive integer type (signed and unsigned, all widths).
implement!(Eq for u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize);

// Arithmetic operators, forwarded to the wrapped type.
implement!(binary Add for Ct<T> with add);
implement!(binary Sub for Ct<T> with sub);
implement!(binary Mul for Ct<T> with mul);
implement!(binary Div for Ct<T> with div);

// Bitwise operators, forwarded to the wrapped type.
implement!(binary BitAnd for Ct<T> with bitand);
implement!(binary BitOr  for Ct<T> with bitor);
implement!(binary BitXor for Ct<T> with bitxor);

// Unary operators, forwarded to the wrapped type.
implement!(unary Not for Ct<T> with not);
implement!(unary Neg for Ct<T> with neg);

#[cfg(test)]
mod tests {
    use super::Ct;

    /// `Ct` equality is reflexive and reports distinct values as unequal.
    #[test]
    fn test_partial_eq() {
        let zero = Ct(0u32);
        let one = Ct(1u32);

        // Every value compares equal to itself.
        for value in [zero, one] {
            assert_eq!(value, value);
        }

        // Distinct values go through `ne` and must differ.
        assert!(zero != one);
    }
}