use crate::{CtAssign, CtAssignSlice, CtEq, CtEqSlice, CtSelectUsingCtAssign};
use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not};
#[cfg(feature = "subtle")]
use crate::CtSelect;
// Branch-free unsigned `x <= y` test.
//
// Evaluates the comparison-predicate identity
// `x <=u y  <=>  msb((!x | y) & ((x ^ y) | !(y - x)))`
// (see Hacker's Delight, "Comparison Predicates") and shifts the most
// significant bit down by `$bits - 1`. For unsigned operands the result is
// therefore `1` when `x <= y` and `0` otherwise.
//
// Each operand is bound to a local exactly once, so arguments with side
// effects (or costly expressions) are not evaluated multiple times. `let`
// bindings inside a block expression are permitted in `const fn`.
macro_rules! bitle {
    ($x:expr, $y:expr, $bits:expr) => {{
        let x = $x;
        let y = $y;
        ((!x | y) & ((x ^ y) | !(y.wrapping_sub(x)))) >> ($bits - 1)
    }};
}
// Branch-free unsigned `x < y` test.
//
// Evaluates the comparison-predicate identity
// `x <u y  <=>  msb((!x & y) | ((!x | y) & (x - y)))`
// (see Hacker's Delight, "Comparison Predicates") and shifts the most
// significant bit down by `$bits - 1`. For unsigned operands the result is
// therefore `1` when `x < y` and `0` otherwise.
//
// Each operand is bound to a local exactly once, so arguments with side
// effects (or costly expressions) are not evaluated multiple times.
macro_rules! bitlt {
    ($x:expr, $y:expr, $bits:expr) => {{
        let x = $x;
        let y = $y;
        ((!x & y) | ((!x | y) & x.wrapping_sub(y))) >> ($bits - 1)
    }};
}
// Branch-free unsigned `value != 0` test.
//
// For any non-zero value, either `value` or its two's-complement negation has
// the most significant bit set; for zero, both are zero. Shifting that bit down
// by `$bits - 1` yields `1` for non-zero unsigned inputs and `0` otherwise.
//
// The operand is bound to a local exactly once, so an argument with side
// effects (or a costly expression) is not evaluated twice.
macro_rules! bitnz {
    ($value:expr, $bits:expr) => {{
        let value = $value;
        (value | value.wrapping_neg()) >> ($bits - 1)
    }};
}
#[derive(Copy, Clone, Debug)]
pub struct Choice(pub(crate) u8);
impl Choice {
pub const FALSE: Self = Self(0);
pub const TRUE: Self = Self(1);
#[inline]
#[must_use]
pub const fn and(self, rhs: Choice) -> Choice {
Self(self.0 & rhs.0)
}
#[inline]
#[must_use]
pub const fn or(self, rhs: Choice) -> Choice {
Self(self.0 | rhs.0)
}
#[inline]
#[must_use]
pub const fn xor(self, rhs: Choice) -> Choice {
Self(self.0 ^ rhs.0)
}
#[inline]
#[must_use]
pub const fn not(self) -> Choice {
Self(self.0 ^ 1)
}
#[inline]
#[must_use]
pub const fn eq(self, other: Self) -> Self {
Self::ne(self, other).not()
}
#[inline]
#[must_use]
pub const fn ne(self, other: Self) -> Self {
Self::xor(self, other)
}
#[inline]
#[must_use]
#[allow(clippy::cast_sign_loss)]
pub const fn from_i64_eq(x: i64, y: i64) -> Self {
Self::from_u64_nz(x as u64 ^ y as u64).not()
}
#[inline]
#[must_use]
pub const fn from_u8_eq(x: u8, y: u8) -> Self {
Self::from_u8_nz(x ^ y).not()
}
#[inline]
#[must_use]
pub const fn from_u8_le(x: u8, y: u8) -> Self {
Self::from_u8_lsb(bitle!(x, y, u8::BITS))
}
#[inline]
#[must_use]
pub const fn from_u8_lsb(value: u8) -> Self {
Self(value & 0x1)
}
#[inline]
#[must_use]
pub const fn from_u8_lt(x: u8, y: u8) -> Self {
Self::from_u8_lsb(bitlt!(x, y, u8::BITS))
}
#[inline]
#[must_use]
pub const fn from_u8_nz(value: u8) -> Self {
Self::from_u8_lsb(bitnz!(value, u8::BITS))
}
#[inline]
#[must_use]
pub const fn from_u16_eq(x: u16, y: u16) -> Self {
Self::from_u16_nz(x ^ y).not()
}
#[inline]
#[must_use]
pub const fn from_u16_le(x: u16, y: u16) -> Self {
Self::from_u16_lsb(bitle!(x, y, u16::BITS))
}
#[inline]
#[must_use]
pub const fn from_u16_lsb(value: u16) -> Self {
Self((value & 0x1) as u8)
}
#[inline]
#[must_use]
pub const fn from_u16_lt(x: u16, y: u16) -> Self {
Self::from_u16_lsb(bitlt!(x, y, u16::BITS))
}
#[inline]
#[must_use]
pub const fn from_u16_nz(value: u16) -> Self {
Self::from_u16_lsb(bitnz!(value, u16::BITS))
}
#[inline]
#[must_use]
pub const fn from_u32_eq(x: u32, y: u32) -> Self {
Self::from_u32_nz(x ^ y).not()
}
#[inline]
#[must_use]
pub const fn from_u32_le(x: u32, y: u32) -> Self {
Self::from_u32_lsb(bitle!(x, y, u32::BITS))
}
#[inline]
#[must_use]
pub const fn from_u32_lsb(value: u32) -> Self {
Self((value & 0x1) as u8)
}
#[inline]
#[must_use]
pub const fn from_u32_lt(x: u32, y: u32) -> Self {
Self::from_u32_lsb(bitlt!(x, y, u32::BITS))
}
#[inline]
#[must_use]
pub const fn from_u32_nz(value: u32) -> Self {
Self::from_u32_lsb(bitnz!(value, u32::BITS))
}
#[inline]
#[must_use]
pub const fn from_u64_eq(x: u64, y: u64) -> Self {
Self::from_u64_nz(x ^ y).not()
}
#[inline]
#[must_use]
pub const fn from_u64_le(x: u64, y: u64) -> Self {
Self::from_u64_lsb(bitle!(x, y, u64::BITS))
}
#[inline]
#[must_use]
pub const fn from_u64_lsb(value: u64) -> Self {
Self((value & 0x1) as u8)
}
#[inline]
#[must_use]
pub const fn from_u64_lt(x: u64, y: u64) -> Self {
Self::from_u64_lsb(bitlt!(x, y, u64::BITS))
}
#[inline]
#[must_use]
pub const fn from_u64_nz(value: u64) -> Self {
Self::from_u64_lsb(bitnz!(value, u64::BITS))
}
#[inline]
#[must_use]
pub const fn from_u128_eq(x: u128, y: u128) -> Self {
Self::from_u128_nz(x ^ y).not()
}
#[inline]
#[must_use]
pub const fn from_u128_le(x: u128, y: u128) -> Self {
Self::from_u128_lsb(bitle!(x, y, u128::BITS))
}
#[inline]
#[must_use]
pub const fn from_u128_lsb(value: u128) -> Self {
Self((value & 1) as u8)
}
#[inline]
#[must_use]
pub const fn from_u128_lt(x: u128, y: u128) -> Self {
Self::from_u128_lsb(bitlt!(x, y, u128::BITS))
}
#[inline]
#[must_use]
pub const fn from_u128_nz(value: u128) -> Self {
Self::from_u128_lsb(bitnz!(value, u128::BITS))
}
#[inline]
#[must_use]
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
pub const fn select_i64(self, a: i64, b: i64) -> i64 {
self.select_u64(a as u64, b as u64) as i64
}
#[inline]
#[must_use]
pub const fn select_u8(self, a: u8, b: u8) -> u8 {
a ^ (self.to_u8_mask() & (a ^ b))
}
#[inline]
#[must_use]
pub const fn select_u16(self, a: u16, b: u16) -> u16 {
a ^ (self.to_u16_mask() & (a ^ b))
}
#[inline]
#[must_use]
pub const fn select_u32(self, a: u32, b: u32) -> u32 {
a ^ (self.to_u32_mask() & (a ^ b))
}
#[inline]
#[must_use]
pub const fn select_u64(self, a: u64, b: u64) -> u64 {
a ^ (self.to_u64_mask() & (a ^ b))
}
#[inline]
#[must_use]
pub const fn select_u128(self, a: u128, b: u128) -> u128 {
a ^ (self.to_u128_mask() & (a ^ b))
}
#[must_use]
pub fn to_bool(self) -> bool {
self.to_u8() != 0
}
#[must_use]
pub fn to_u8(self) -> u8 {
core::hint::black_box(self.0)
}
#[must_use]
pub const fn to_bool_vartime(self) -> bool {
self.0 != 0
}
#[must_use]
pub const fn to_u8_vartime(self) -> u8 {
self.0
}
#[inline]
#[must_use]
pub const fn to_u8_mask(self) -> u8 {
self.0.wrapping_neg()
}
#[inline]
#[must_use]
pub const fn to_u16_mask(self) -> u16 {
(self.0 as u16).wrapping_neg()
}
#[inline]
#[must_use]
pub const fn to_u32_mask(self) -> u32 {
(self.0 as u32).wrapping_neg()
}
#[inline]
#[must_use]
pub const fn to_u64_mask(self) -> u64 {
(self.0 as u64).wrapping_neg()
}
#[inline]
#[must_use]
pub const fn to_u128_mask(self) -> u128 {
(self.0 as u128).wrapping_neg()
}
}
impl BitAnd for Choice {
type Output = Choice;
#[inline]
fn bitand(self, rhs: Choice) -> Choice {
self.and(rhs)
}
}
impl BitAndAssign for Choice {
#[inline]
fn bitand_assign(&mut self, rhs: Choice) {
*self = *self & rhs;
}
}
impl BitOr for Choice {
type Output = Choice;
#[inline]
fn bitor(self, rhs: Choice) -> Choice {
self.or(rhs)
}
}
impl BitOrAssign for Choice {
#[inline]
fn bitor_assign(&mut self, rhs: Choice) {
*self = *self | rhs;
}
}
impl BitXor for Choice {
type Output = Choice;
#[inline]
fn bitxor(self, rhs: Choice) -> Choice {
Choice(self.0 ^ rhs.0)
}
}
impl BitXorAssign for Choice {
#[inline]
fn bitxor_assign(&mut self, rhs: Choice) {
*self = *self ^ rhs;
}
}
impl CtAssign for Choice {
#[inline]
fn ct_assign(&mut self, other: &Self, choice: Choice) {
self.0.ct_assign(&other.0, choice);
}
}
impl CtAssignSlice for Choice {}
impl CtSelectUsingCtAssign for Choice {}
impl CtEq for Choice {
#[inline]
fn ct_eq(&self, other: &Self) -> Self {
self.0.ct_eq(&other.0)
}
}
impl CtEqSlice for Choice {}
/// Builds a [`Choice`] from the least significant bit of a `u8`.
///
/// Higher bits are discarded, so for example `2` converts to [`Choice::FALSE`].
impl From<u8> for Choice {
    fn from(value: u8) -> Self {
        Self::from_u8_lsb(value)
    }
}

/// Extracts the `0`/`1` byte through [`Choice::to_u8`].
impl From<Choice> for u8 {
    fn from(choice: Choice) -> u8 {
        Choice::to_u8(choice)
    }
}

/// Extracts the truth value through [`Choice::to_bool`].
impl From<Choice> for bool {
    fn from(choice: Choice) -> bool {
        Choice::to_bool(choice)
    }
}
impl Not for Choice {
type Output = Choice;
#[inline]
fn not(self) -> Choice {
self.not()
}
}
/// Wraps the raw byte returned by `subtle::Choice::unwrap_u8`.
#[cfg(feature = "subtle")]
impl From<subtle::Choice> for Choice {
    #[inline]
    fn from(choice: subtle::Choice) -> Choice {
        Self(choice.unwrap_u8())
    }
}

/// Hands the inner byte to `subtle::Choice`'s `From<u8>` conversion.
#[cfg(feature = "subtle")]
impl From<Choice> for subtle::Choice {
    #[inline]
    fn from(choice: Choice) -> subtle::Choice {
        let byte: u8 = choice.0;
        byte.into()
    }
}

/// Implements `subtle`'s selection in terms of this crate's [`CtSelect`].
#[cfg(feature = "subtle")]
impl subtle::ConditionallySelectable for Choice {
    #[inline]
    fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
        let choice = Choice::from(choice);
        CtSelect::ct_select(a, b, choice)
    }
}

/// Implements `subtle`'s equality in terms of this crate's [`CtEq`].
#[cfg(feature = "subtle")]
impl subtle::ConstantTimeEq for Choice {
    #[inline]
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let equal: Choice = CtEq::ct_eq(self, other);
        equal.into()
    }
}
#[cfg(test)]
mod tests {
    use super::Choice;
    use crate::{CtEq, CtSelect};

    /// Asserts that `actual` holds the truth value `expected`.
    #[track_caller]
    fn check(actual: Choice, expected: bool) {
        assert_eq!(actual.to_bool(), expected);
    }

    #[test]
    fn ct_eq() {
        let a = Choice::TRUE;
        let b = Choice::TRUE;
        let c = Choice::FALSE;
        assert!(a.ct_eq(&b).to_bool());
        assert!(!a.ct_eq(&c).to_bool());
        assert!(!b.ct_eq(&c).to_bool());
        assert!(!a.ct_ne(&b).to_bool());
        assert!(a.ct_ne(&c).to_bool());
        assert!(b.ct_ne(&c).to_bool());
    }

    #[test]
    fn ct_select() {
        let a = Choice::FALSE;
        let b = Choice::TRUE;
        assert_eq!(a.ct_select(&b, Choice::FALSE).to_bool(), a.to_bool());
        assert_eq!(a.ct_select(&b, Choice::TRUE).to_bool(), b.to_bool());
    }

    #[test]
    fn and() {
        assert_eq!((Choice::FALSE & Choice::FALSE).to_u8(), 0);
        assert_eq!((Choice::TRUE & Choice::FALSE).to_u8(), 0);
        assert_eq!((Choice::FALSE & Choice::TRUE).to_u8(), 0);
        assert_eq!((Choice::TRUE & Choice::TRUE).to_u8(), 1);
    }

    #[test]
    fn or() {
        assert_eq!((Choice::FALSE | Choice::FALSE).to_u8(), 0);
        assert_eq!((Choice::TRUE | Choice::FALSE).to_u8(), 1);
        assert_eq!((Choice::FALSE | Choice::TRUE).to_u8(), 1);
        assert_eq!((Choice::TRUE | Choice::TRUE).to_u8(), 1);
    }

    #[test]
    fn xor() {
        assert_eq!((Choice::FALSE ^ Choice::FALSE).to_u8(), 0);
        assert_eq!((Choice::TRUE ^ Choice::FALSE).to_u8(), 1);
        assert_eq!((Choice::FALSE ^ Choice::TRUE).to_u8(), 1);
        assert_eq!((Choice::TRUE ^ Choice::TRUE).to_u8(), 0);
    }

    #[test]
    fn not() {
        assert_eq!(Choice::FALSE.not().to_u8(), 1);
        assert_eq!(Choice::TRUE.not().to_u8(), 0);
    }

    #[test]
    fn operator_traits() {
        check(!Choice::TRUE, false);
        check(!Choice::FALSE, true);

        let mut c = Choice::TRUE;
        c &= Choice::FALSE;
        check(c, false);
        c |= Choice::TRUE;
        check(c, true);
        c ^= Choice::TRUE;
        check(c, false);
    }

    #[test]
    fn eq_ne() {
        check(Choice::TRUE.eq(Choice::TRUE), true);
        check(Choice::FALSE.eq(Choice::FALSE), true);
        check(Choice::TRUE.eq(Choice::FALSE), false);
        check(Choice::TRUE.ne(Choice::FALSE), true);
        check(Choice::FALSE.ne(Choice::FALSE), false);
    }

    #[test]
    fn from_i64_eq() {
        check(Choice::from_i64_eq(0, 1), false);
        check(Choice::from_i64_eq(1, 1), true);
        check(Choice::from_i64_eq(-1, -1), true);
        check(Choice::from_i64_eq(-1, 1), false);
        check(Choice::from_i64_eq(i64::MIN, i64::MIN), true);
        check(Choice::from_i64_eq(i64::MIN, i64::MAX), false);
    }

    #[test]
    fn select_i64() {
        let a: i64 = 1;
        let b: i64 = 2;
        assert_eq!(Choice::TRUE.select_i64(a, b), b);
        assert_eq!(Choice::FALSE.select_i64(a, b), a);
        assert_eq!(Choice::TRUE.select_i64(-5, i64::MIN), i64::MIN);
        assert_eq!(Choice::FALSE.select_i64(-5, i64::MIN), -5);
    }

    #[test]
    fn conversions() {
        check(Choice::from(0u8), false);
        check(Choice::from(1u8), true);
        // Only the least significant bit survives the conversion.
        check(Choice::from(2u8), false);
        check(Choice::from(3u8), true);
        assert_eq!(u8::from(Choice::TRUE), 1);
        assert_eq!(u8::from(Choice::FALSE), 0);
        assert!(bool::from(Choice::TRUE));
        assert!(!bool::from(Choice::FALSE));
    }

    #[test]
    fn to_bool() {
        assert!(!Choice::FALSE.to_bool());
        assert!(Choice::TRUE.to_bool());
    }

    #[test]
    fn to_u8() {
        assert_eq!(Choice::FALSE.to_u8(), 0);
        assert_eq!(Choice::TRUE.to_u8(), 1);
    }

    #[test]
    fn vartime() {
        assert!(!Choice::FALSE.to_bool_vartime());
        assert!(Choice::TRUE.to_bool_vartime());
        assert_eq!(Choice::FALSE.to_u8_vartime(), 0);
        assert_eq!(Choice::TRUE.to_u8_vartime(), 1);
    }

    // Generates one test module per unsigned width. Besides the small-value
    // cases, every predicate is exercised at `MAX` and around the most
    // significant bit (`HIGH`): that bit is the one the branch-free
    // `bitle!` / `bitlt!` / `bitnz!` helpers extract, and small inputs alone
    // never set it in the operands.
    macro_rules! uint_tests {
        (
            $module:ident,
            $ty:ty,
            $eq:ident,
            $le:ident,
            $lt:ident,
            $lsb:ident,
            $nz:ident,
            $select:ident,
            $mask:ident
        ) => {
            mod $module {
                use super::{check, Choice};

                const MAX: $ty = <$ty>::MAX;
                const HIGH: $ty = 1 << (<$ty>::BITS - 1);

                #[test]
                fn eq() {
                    check(Choice::$eq(0, 1), false);
                    check(Choice::$eq(1, 1), true);
                    check(Choice::$eq(MAX, MAX), true);
                    check(Choice::$eq(MAX, 0), false);
                    check(Choice::$eq(HIGH, 0), false);
                }

                #[test]
                fn le() {
                    check(Choice::$le(0, 0), true);
                    check(Choice::$le(1, 0), false);
                    check(Choice::$le(1, 1), true);
                    check(Choice::$le(1, 2), true);
                    check(Choice::$le(0, MAX), true);
                    check(Choice::$le(MAX, 0), false);
                    check(Choice::$le(MAX, MAX), true);
                    check(Choice::$le(MAX - 1, MAX), true);
                    check(Choice::$le(MAX, MAX - 1), false);
                    check(Choice::$le(HIGH - 1, HIGH), true);
                    check(Choice::$le(HIGH, HIGH - 1), false);
                }

                #[test]
                fn lt() {
                    check(Choice::$lt(0, 0), false);
                    check(Choice::$lt(1, 0), false);
                    check(Choice::$lt(1, 1), false);
                    check(Choice::$lt(1, 2), true);
                    check(Choice::$lt(0, MAX), true);
                    check(Choice::$lt(MAX, 0), false);
                    check(Choice::$lt(MAX, MAX), false);
                    check(Choice::$lt(MAX - 1, MAX), true);
                    check(Choice::$lt(MAX, MAX - 1), false);
                    check(Choice::$lt(HIGH - 1, HIGH), true);
                    check(Choice::$lt(HIGH, HIGH - 1), false);
                }

                #[test]
                fn lsb() {
                    check(Choice::$lsb(0), false);
                    check(Choice::$lsb(1), true);
                    check(Choice::$lsb(2), false);
                    check(Choice::$lsb(3), true);
                    check(Choice::$lsb(HIGH), false);
                    check(Choice::$lsb(MAX), true);
                }

                #[test]
                fn nz() {
                    check(Choice::$nz(0), false);
                    check(Choice::$nz(1), true);
                    check(Choice::$nz(2), true);
                    check(Choice::$nz(HIGH), true);
                    check(Choice::$nz(MAX), true);
                }

                #[test]
                fn select() {
                    let a: $ty = 1;
                    let b: $ty = 2;
                    assert_eq!(Choice::TRUE.$select(a, b), b);
                    assert_eq!(Choice::FALSE.$select(a, b), a);
                    assert_eq!(Choice::TRUE.$select(0, MAX), MAX);
                    assert_eq!(Choice::FALSE.$select(0, MAX), 0);
                }

                #[test]
                fn mask() {
                    assert_eq!(Choice::FALSE.$mask(), 0);
                    assert_eq!(Choice::TRUE.$mask(), MAX);
                }
            }
        };
    }

    uint_tests!(
        u8_tests,
        u8,
        from_u8_eq,
        from_u8_le,
        from_u8_lt,
        from_u8_lsb,
        from_u8_nz,
        select_u8,
        to_u8_mask
    );
    uint_tests!(
        u16_tests,
        u16,
        from_u16_eq,
        from_u16_le,
        from_u16_lt,
        from_u16_lsb,
        from_u16_nz,
        select_u16,
        to_u16_mask
    );
    uint_tests!(
        u32_tests,
        u32,
        from_u32_eq,
        from_u32_le,
        from_u32_lt,
        from_u32_lsb,
        from_u32_nz,
        select_u32,
        to_u32_mask
    );
    uint_tests!(
        u64_tests,
        u64,
        from_u64_eq,
        from_u64_le,
        from_u64_lt,
        from_u64_lsb,
        from_u64_nz,
        select_u64,
        to_u64_mask
    );
    uint_tests!(
        u128_tests,
        u128,
        from_u128_eq,
        from_u128_le,
        from_u128_lt,
        from_u128_lsb,
        from_u128_nz,
        select_u128,
        to_u128_mask
    );
}