1use subtle::{Choice, ConstantTimeEq};
7
8#[inline]
13pub fn ct_eq(a: &[u8], b: &[u8]) -> bool {
14 if a.len() != b.len() {
15 return false;
16 }
17 a.ct_eq(b).into()
18}
19
20#[inline]
22pub fn ct_eq_32(a: &[u8; 32], b: &[u8; 32]) -> bool {
23 a.ct_eq(b).into()
24}
25
26#[inline]
28pub fn ct_eq_64(a: &[u8; 64], b: &[u8; 64]) -> bool {
29 a.ct_eq(b).into()
30}
31
32#[inline]
34pub fn ct_eq_slice_32(slice: &[u8], array: &[u8; 32]) -> bool {
35 if slice.len() != 32 {
36 return false;
37 }
38 slice.ct_eq(array.as_slice()).into()
39}
40
41#[inline]
45pub fn ct_select<T: Copy + Default>(condition: bool, a: T, b: T) -> T {
46 let choice = Choice::from(condition as u8);
47 if choice.unwrap_u8() == 1 { a } else { b }
49}
50
/// Owned byte buffer meant for secret material.
///
/// Trait impls elsewhere in this file compare values through
/// `ConstantTimeEq`, zero the buffer on drop, and redact the contents in
/// `Debug` output. Cloning produces an independent copy of the bytes.
#[derive(Clone)]
pub struct SecretBytes(Vec<u8>);

impl SecretBytes {
    /// Wraps an existing vector, taking ownership of it without copying.
    pub fn new(bytes: Vec<u8>) -> Self {
        SecretBytes(bytes)
    }

    /// Builds a `SecretBytes` by copying the contents of `slice`.
    pub fn from_slice(slice: &[u8]) -> Self {
        Self::new(Vec::from(slice))
    }

    /// Borrows the stored bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Number of stored bytes.
    pub fn len(&self) -> usize {
        self.as_bytes().len()
    }

    /// `true` when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
81
82impl ConstantTimeEq for SecretBytes {
83 fn ct_eq(&self, other: &Self) -> Choice {
84 self.0.ct_eq(&other.0)
85 }
86}
87
88impl PartialEq for SecretBytes {
89 fn eq(&self, other: &Self) -> bool {
90 self.ct_eq(other).into()
91 }
92}
93
94impl Eq for SecretBytes {}
95
96impl Drop for SecretBytes {
97 fn drop(&mut self) {
98 for byte in &mut self.0 {
100 unsafe {
101 std::ptr::write_volatile(byte, 0);
102 }
103 }
104 std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
105 }
106}
107
108impl std::fmt::Debug for SecretBytes {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 write!(f, "SecretBytes([REDACTED; {} bytes])", self.0.len())
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Slices with identical contents compare equal.
    #[test]
    fn test_ct_eq_equal() {
        let left = [1u8, 2, 3, 4, 5];
        let right = [1u8, 2, 3, 4, 5];
        let equal = ct_eq(&left, &right);
        assert!(equal);
    }

    /// A difference in the final byte is detected.
    #[test]
    fn test_ct_eq_not_equal() {
        let left = [1u8, 2, 3, 4, 5];
        let right = [1u8, 2, 3, 4, 6];
        let equal = ct_eq(&left, &right);
        assert!(!equal);
    }

    /// A proper prefix is not equal to the longer slice.
    #[test]
    fn test_ct_eq_different_length() {
        let longer = [1u8, 2, 3, 4, 5];
        let shorter = [1u8, 2, 3, 4];
        let equal = ct_eq(&longer, &shorter);
        assert!(!equal);
    }

    /// Fixed-size comparison accepts a match and rejects a mismatch.
    #[test]
    fn test_ct_eq_32() {
        let reference = [42u8; 32];
        let matching = [42u8; 32];
        let differing = [43u8; 32];
        assert!(ct_eq_32(&reference, &matching));
        assert!(!ct_eq_32(&reference, &differing));
    }

    /// `==` / `!=` on `SecretBytes` follow the wrapped contents.
    #[test]
    fn test_secret_bytes_equality() {
        let first = SecretBytes::new(vec![1, 2, 3, 4]);
        let twin = SecretBytes::new(vec![1, 2, 3, 4]);
        let other = SecretBytes::new(vec![1, 2, 3, 5]);

        assert_eq!(first, twin);
        assert_ne!(first, other);
    }

    /// Debug output carries the redaction marker and not the raw bytes.
    #[test]
    fn test_secret_bytes_debug() {
        let secret = SecretBytes::new(vec![1, 2, 3, 4]);
        let rendered = format!("{:?}", secret);
        assert!(rendered.contains("REDACTED"));
        assert!(!rendered.contains("1, 2, 3, 4"));
    }
}