1use crate::error::{CaptchaError, Result};
2use crate::image::{NoiseOptions, encode_image, watermark_with_noise};
3use crate::sprite::{SpriteFormat, SpriteTarget, create_sprite};
4use crate::utils::get_timestamp;
5
6use base64::{Engine as _, prelude::BASE64_STANDARD};
7use hmac::{Hmac, Mac};
8use image::{DynamicImage, Limits};
9use sha2::Sha256;
10use subtle::ConstantTimeEq;
11use uuid::Uuid;
12
/// HMAC-SHA256, used to sign challenge IDs in `build_challenge_id` and to
/// recompute the expected MAC in `verify`.
type HmacSha256 = Hmac<Sha256>;
14
/// A freshly generated CAPTCHA challenge, as returned by [`generate`].
///
/// `T` is the caller-chosen sprite representation; it is built with
/// `SpriteTarget::from_bytes` from the encoded image bytes and their MIME type.
pub struct CaptchaChallenge<T> {
    /// The watermarked, encoded sprite in the caller's chosen representation.
    pub sprite: T,
    /// The RGB image that was encoded into `sprite` (test / `test-utils` builds only).
    #[cfg(any(test, feature = "test-utils"))]
    pub sprite_dbg: DynamicImage,
    /// Signed ID of the form `"{nonce}:{timestamp}:{code}"`; hand it back to
    /// [`verify`] together with the user's selected index.
    pub challenge_id: String,
    /// Creation time as returned by `get_timestamp`; the same value is embedded in
    /// `challenge_id`.
    pub timestamp: u64,
    /// The correct answer index chosen by `create_sprite`. Only present in test /
    /// `test-utils` builds, so production callers never see the answer.
    #[cfg(any(test, feature = "test-utils"))]
    pub correct_number: u8,
}
24
/// Options controlling how [`generate`] builds and encodes the challenge sprite.
#[derive(Clone)]
pub struct GenerationOptions {
    /// Cell size consumed by `create_sprite` (which receives these options);
    /// presumably the pixel size of one grid cell — confirm in `create_sprite`.
    pub cell_size: u32,
    /// Output format handed to `encode_image` for the final sprite.
    pub sprite_format: SpriteFormat,
    /// Optional `image` crate limits. Not read in this module; presumably applied by
    /// `create_sprite` when decoding the base image — TODO confirm.
    pub limits: Option<Limits>,
}
31
32impl Default for GenerationOptions {
33 fn default() -> Self {
34 Self {
35 cell_size: 150,
36 sprite_format: SpriteFormat::default(),
37 limits: None,
38 }
39 }
40}
41
42pub fn generate<T: SpriteTarget>(
43 base_buf: &[u8],
44 secret: &[u8],
45 opts: &GenerationOptions,
46 noise: NoiseOptions,
47) -> Result<CaptchaChallenge<T>> {
48 let (mut sprite, correct_number) = create_sprite(base_buf, opts)?;
49 watermark_with_noise(&mut sprite, noise);
50
51 let rgb = sprite.to_rgb8();
52 let dyn_rgb = DynamicImage::ImageRgb8(rgb);
53
54 let (sprite_buf, mime) =
55 encode_image(&dyn_rgb, &opts.sprite_format).map_err(CaptchaError::Encode)?;
56
57 let sprite = T::from_bytes(sprite_buf, mime);
58
59 let (challenge_id, timestamp) = build_challenge_id(correct_number, secret)?;
60
61 #[cfg(any(test, feature = "test-utils"))]
62 let challenge = CaptchaChallenge {
63 sprite,
64 sprite_dbg: dyn_rgb,
65 challenge_id,
66 timestamp,
67 correct_number,
68 };
69 #[cfg(not(any(test, feature = "test-utils")))]
70 let challenge = CaptchaChallenge {
71 sprite,
72 challenge_id,
73 timestamp,
74 };
75
76 Ok(challenge)
77}
78
79fn build_challenge_id(correct_number: u8, secret: &[u8]) -> Result<(String, u64)> {
80 let timestamp = get_timestamp();
81 let nonce = Uuid::new_v4().to_string();
82
83 let mut mac = HmacSha256::new_from_slice(secret)
84 .map_err(|e| CaptchaError::Internal(format!("create HMAC: {e}")))?;
85 mac.update(nonce.as_bytes());
86 mac.update(&[correct_number]);
87 mac.update(×tamp.to_be_bytes());
88
89 let code = BASE64_STANDARD.encode(mac.finalize().into_bytes());
90
91 Ok((format!("{nonce}:{timestamp}:{code}"), timestamp))
92}
93
94pub fn verify(secret: &[u8], challenge_id: &str, selected_index: u8, ttl: u64) -> bool {
95 let parts: Vec<&str> = challenge_id.split(':').collect();
96 if parts.len() != 3 {
97 return false;
98 }
99
100 let nonce = parts[0];
101 let timestamp: u64 = match parts[1].parse() {
102 Ok(t) => t,
103 Err(_) => return false,
104 };
105 let expected_code_b64 = parts[2];
106
107 let now = get_timestamp();
108 if now > timestamp.saturating_add(ttl) {
109 return false;
110 }
111
112 let mut mac = match HmacSha256::new_from_slice(secret) {
113 Ok(m) => m,
114 Err(_) => return false,
115 };
116 mac.update(nonce.as_bytes());
117 mac.update(&[selected_index]);
118 mac.update(×tamp.to_be_bytes());
119
120 let computed = mac.finalize().into_bytes();
121
122 let expected = match BASE64_STANDARD.decode(expected_code_b64) {
123 Ok(bytes) => bytes,
124 Err(_) => return false,
125 };
126
127 if expected.len() != computed.len() {
128 return false;
129 }
130
131 computed[..].ct_eq(expected.as_slice()).into()
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SpriteUri;
    use base64::engine::general_purpose;
    use std::collections::HashSet;
    use std::time::Instant;

    /// TTL used whenever a challenge is expected to still be valid.
    const CHALLENGE_TTL: u64 = 60;
    const SECRET: &[u8] = b"secret-key";
    /// Number of candidate answers tried per challenge. Assumes `create_sprite` always
    /// returns an answer below this (presumably a 3x3 grid) — update if the layout changes.
    const CELL_COUNT: u8 = 9;

    fn load_sample_image() -> Vec<u8> {
        include_bytes!("../assets/sample1.jpg").to_vec()
    }

    /// Generates a JPEG-sprite challenge from the bundled sample image.
    fn generate_challenge() -> CaptchaChallenge<SpriteUri> {
        let base = load_sample_image();
        let opts = GenerationOptions {
            cell_size: 150,
            sprite_format: SpriteFormat::Jpeg { quality: 70 },
            limits: None,
        };
        generate::<SpriteUri>(&base, SECRET, &opts, NoiseOptions::default())
            .expect("Failed to generate challenge")
    }

    /// Splits a challenge ID into `(nonce, timestamp, code)`, panicking if it is malformed.
    fn split_challenge_id(id: &str) -> (&str, u64, &str) {
        let parts: Vec<&str> = id.split(':').collect();
        assert_eq!(parts.len(), 3, "challenge ID should have three parts: {id}");
        let timestamp = parts[1]
            .parse()
            .expect("challenge timestamp should be a u64");
        (parts[0], timestamp, parts[2])
    }

    /// Builds a challenge ID signed with `key`. It uses the same MAC input layout as
    /// `build_challenge_id`: nonce, then the 1-byte index, then the big-endian timestamp.
    fn sign_challenge_id(key: &[u8], nonce: &str, index: u8, timestamp: u64) -> String {
        let mut mac =
            HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any length");
        mac.update(nonce.as_bytes());
        mac.update(&[index]);
        mac.update(&timestamp.to_be_bytes());
        let code = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
        format!("{nonce}:{timestamp}:{code}")
    }

    #[test]
    fn test_signing_helper_matches_generate() {
        // Control for the tests below that forge/backdate IDs: the helper must
        // reproduce a real ID exactly, otherwise those tests prove nothing.
        let challenge = generate_challenge();
        let (nonce, timestamp, _) = split_challenge_id(&challenge.challenge_id);

        assert_eq!(timestamp, challenge.timestamp);
        assert_eq!(
            sign_challenge_id(SECRET, nonce, challenge.correct_number, timestamp),
            challenge.challenge_id,
            "test helper no longer mirrors build_challenge_id's MAC layout"
        );
    }

    #[test]
    fn test_generate_and_verify() {
        let challenge = generate_challenge();
        let result = verify(
            SECRET,
            &challenge.challenge_id,
            challenge.correct_number,
            CHALLENGE_TTL,
        );

        assert!(result, "Challenge verification failed for correct index");
    }

    #[test]
    fn test_verification_should_fail_for_wrong_guess() {
        let challenge = generate_challenge();

        let wrong = (challenge.correct_number + 1) % CELL_COUNT;
        let valid = verify(SECRET, &challenge.challenge_id, wrong, CHALLENGE_TTL);

        assert!(!valid, "Verification should fail for wrong index");
    }

    #[test]
    fn test_challenge_correct_index_should_be_random() {
        let mut seen_indices = HashSet::new();

        // Stop as soon as two different answers have been seen; that already proves
        // the answer is not fixed.
        for _ in 0..100 {
            seen_indices.insert(generate_challenge().correct_number);
            if seen_indices.len() > 1 {
                break;
            }
        }

        assert!(
            seen_indices.len() > 1,
            "Correct index never changes. Challenge randomization failed"
        );
    }

    #[test]
    fn test_challenge_should_expire_after_ttl() {
        let challenge = generate_challenge();
        let (nonce, issued_at, _) = split_challenge_id(&challenge.challenge_id);

        // Backdate a correctly signed ID instead of sleeping. With a timestamp 5 units
        // before issue time and a TTL of 1, `now > timestamp + ttl` must hold whatever
        // unit `get_timestamp` uses.
        let stale_id = sign_challenge_id(
            SECRET,
            nonce,
            challenge.correct_number,
            issued_at.saturating_sub(5),
        );

        assert!(
            verify(SECRET, &stale_id, challenge.correct_number, u64::MAX),
            "Backdated ID should verify when it cannot expire (MAC must be valid)"
        );
        assert!(
            !verify(SECRET, &stale_id, challenge.correct_number, 1),
            "Expired challenge passed verification"
        );
    }

    #[test]
    fn test_verification_should_not_leak_answer() {
        // Best-effort side-channel smoke test. Use the fastest of many runs per guess so
        // scheduler noise and cold caches do not dominate the comparison.
        const RUNS_PER_GUESS: usize = 200;

        let challenge = generate_challenge();

        let fastest: Vec<u128> = (0..CELL_COUNT)
            .map(|guess| {
                (0..RUNS_PER_GUESS)
                    .map(|_| {
                        let start = Instant::now();
                        let _ = verify(SECRET, &challenge.challenge_id, guess, CHALLENGE_TTL);
                        start.elapsed().as_nanos()
                    })
                    .min()
                    .expect("RUNS_PER_GUESS is non-zero")
            })
            .collect();

        let min = *fastest.iter().min().unwrap();
        let max = *fastest.iter().max().unwrap();
        let delta = max - min;

        println!("Timing min={min}ns, max={max}ns, delta={delta}ns");

        assert!(
            delta < 50_000,
            "Timing delta too large ({delta}ns), possible side channel",
        );
    }

    #[test]
    fn test_no_false_positives_over_many_challenges() {
        let mut false_positives = 0;

        for _ in 0..60 {
            let challenge = generate_challenge();

            assert!(
                verify(
                    SECRET,
                    &challenge.challenge_id,
                    challenge.correct_number,
                    CHALLENGE_TTL
                ),
                "Correct answer was rejected"
            );

            false_positives += (0..CELL_COUNT)
                .filter(|&guess| guess != challenge.correct_number)
                .filter(|&guess| verify(SECRET, &challenge.challenge_id, guess, CHALLENGE_TTL))
                .count();
        }

        assert_eq!(
            false_positives, 0,
            "Detected {false_positives} false positives — verification accepted wrong answers",
        );
    }

    #[test]
    fn test_uniqueness_hmac() {
        // Compare the full MAC. A short suffix would include the constant base64
        // padding. No sleep is needed: the nonce alone must make each MAC unique.
        let codes: HashSet<String> = (0..60)
            .map(|_| {
                let challenge = generate_challenge();
                let (_, _, code) = split_challenge_id(&challenge.challenge_id);
                code.to_owned()
            })
            .collect();

        assert_eq!(codes.len(), 60, "HMACs are not unique across challenges");
    }

    #[test]
    fn test_challenge_id_should_be_unforgeable() {
        let challenge = generate_challenge();
        let (nonce, timestamp, _) = split_challenge_id(&challenge.challenge_id);
        let forged_index = (challenge.correct_number + 1) % CELL_COUNT;

        // The attacker knows the ID layout but not the server secret.
        let forged_challenge = sign_challenge_id(b"BAD_SECRET", nonce, forged_index, timestamp);
        let valid = verify(SECRET, &forged_challenge, forged_index, CHALLENGE_TTL);

        assert!(
            !valid,
            "Forged challenge ID was accepted. HMAC security failure"
        );
    }
}