promocrypt_core/
counter.rs1use std::fs::OpenOptions;
9use std::io::{Read, Seek, SeekFrom, Write};
10use std::path::{Path, PathBuf};
11
12use fs2::FileExt;
13use serde::{Deserialize, Serialize};
14
15use crate::error::{PromocryptError, Result};
16
17#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
19#[serde(tag = "type", rename_all = "lowercase")]
20pub enum CounterMode {
21 File {
23 path: String,
25 },
26
27 InBin,
29
30 #[default]
34 #[serde(alias = "manual")] External,
36}
37
38impl CounterMode {
39 pub fn file(path: impl Into<String>) -> Self {
41 CounterMode::File { path: path.into() }
42 }
43
44 pub fn is_in_bin(&self) -> bool {
46 matches!(self, CounterMode::InBin)
47 }
48
49 pub fn is_external(&self) -> bool {
51 matches!(self, CounterMode::External)
52 }
53
54 pub fn file_path(&self) -> Option<&str> {
56 match self {
57 CounterMode::File { path } => Some(path),
58 _ => None,
59 }
60 }
61}
62
63pub struct CounterManager {
65 mode: CounterMode,
66 in_bin_value: u64,
68 in_bin_modified: bool,
70}
71
72impl CounterManager {
73 pub fn new(mode: CounterMode) -> Self {
75 Self {
76 mode,
77 in_bin_value: 0,
78 in_bin_modified: false,
79 }
80 }
81
82 pub fn with_in_bin_value(mode: CounterMode, initial_value: u64) -> Self {
84 Self {
85 mode,
86 in_bin_value: initial_value,
87 in_bin_modified: false,
88 }
89 }
90
91 pub fn mode(&self) -> &CounterMode {
93 &self.mode
94 }
95
96 pub fn get(&self) -> Result<u64> {
98 match &self.mode {
99 CounterMode::File { path } => read_counter_file(Path::new(path)),
100 CounterMode::InBin => Ok(self.in_bin_value),
101 CounterMode::External => Err(PromocryptError::InvalidArgument(
102 "Cannot get counter in external mode - use generate_at() with explicit counter"
103 .to_string(),
104 )),
105 }
106 }
107
108 pub fn reserve(&mut self, count: u64) -> Result<u64> {
113 match &self.mode {
114 CounterMode::File { path } => {
115 let path = PathBuf::from(path);
116 reserve_counter_file(&path, count)
117 }
118 CounterMode::InBin => {
119 let start = self.in_bin_value;
120 self.in_bin_value = self
121 .in_bin_value
122 .checked_add(count)
123 .ok_or(PromocryptError::CounterOverflow)?;
124 self.in_bin_modified = true;
125 Ok(start)
126 }
127 CounterMode::External => Err(PromocryptError::InvalidArgument(
128 "Cannot reserve counter in external mode - counter is managed externally"
129 .to_string(),
130 )),
131 }
132 }
133
134 pub fn set(&mut self, value: u64) -> Result<()> {
136 match &self.mode {
137 CounterMode::File { path } => write_counter_file(Path::new(path), value),
138 CounterMode::InBin => {
139 self.in_bin_value = value;
140 self.in_bin_modified = true;
141 Ok(())
142 }
143 CounterMode::External => Err(PromocryptError::InvalidArgument(
144 "Cannot set counter in external mode - counter is managed externally".to_string(),
145 )),
146 }
147 }
148
149 pub fn is_in_bin_modified(&self) -> bool {
151 self.in_bin_modified && matches!(self.mode, CounterMode::InBin)
152 }
153
154 pub fn in_bin_value(&self) -> u64 {
156 self.in_bin_value
157 }
158
159 pub fn mark_saved(&mut self) {
161 self.in_bin_modified = false;
162 }
163
164 pub fn set_mode(&mut self, mode: CounterMode) {
166 self.mode = mode;
167 }
168}
169
170pub fn read_counter_file(path: &Path) -> Result<u64> {
174 if let Some(parent) = path.parent()
176 && !parent.exists()
177 {
178 std::fs::create_dir_all(parent)?;
179 }
180
181 let mut file = OpenOptions::new()
183 .read(true)
184 .write(true)
185 .create(true)
186 .truncate(false)
187 .open(path)?;
188
189 file.lock_shared()
191 .map_err(|_| PromocryptError::CounterLocked)?;
192
193 let mut buffer = [0u8; 8];
194 let bytes_read = file.read(&mut buffer)?;
195
196 file.unlock().ok();
198
199 if bytes_read == 0 {
200 return Ok(0);
202 }
203
204 if bytes_read != 8 {
205 return Err(PromocryptError::InvalidFileFormatDetails(
206 "Counter file corrupted".to_string(),
207 ));
208 }
209
210 Ok(u64::from_le_bytes(buffer))
211}
212
213pub fn write_counter_file(path: &Path, value: u64) -> Result<()> {
215 if let Some(parent) = path.parent()
217 && !parent.exists()
218 {
219 std::fs::create_dir_all(parent)?;
220 }
221
222 let mut file = OpenOptions::new()
223 .read(true)
224 .write(true)
225 .create(true)
226 .truncate(false)
227 .open(path)?;
228
229 file.lock_exclusive()
231 .map_err(|_| PromocryptError::CounterLocked)?;
232
233 file.seek(SeekFrom::Start(0))?;
234 file.write_all(&value.to_le_bytes())?;
235 file.set_len(8)?;
236 file.sync_all()?;
237
238 file.unlock().ok();
240
241 Ok(())
242}
243
244pub fn reserve_counter_file(path: &Path, count: u64) -> Result<u64> {
248 if let Some(parent) = path.parent()
250 && !parent.exists()
251 {
252 std::fs::create_dir_all(parent)?;
253 }
254
255 let mut file = OpenOptions::new()
256 .read(true)
257 .write(true)
258 .create(true)
259 .truncate(false)
260 .open(path)?;
261
262 file.lock_exclusive()
264 .map_err(|_| PromocryptError::CounterLocked)?;
265
266 let mut buffer = [0u8; 8];
268 let bytes_read = file.read(&mut buffer)?;
269
270 let current = if bytes_read == 0 {
271 0
272 } else if bytes_read == 8 {
273 u64::from_le_bytes(buffer)
274 } else {
275 file.unlock().ok();
276 return Err(PromocryptError::InvalidFileFormatDetails(
277 "Counter file corrupted".to_string(),
278 ));
279 };
280
281 let new_value = current
283 .checked_add(count)
284 .ok_or(PromocryptError::CounterOverflow)?;
285
286 file.seek(SeekFrom::Start(0))?;
288 file.write_all(&new_value.to_le_bytes())?;
289 file.set_len(8)?;
290 file.sync_all()?;
291
292 file.unlock().ok();
294
295 Ok(current)
296}
297
298pub fn increment_counter_file(path: &Path, amount: u64) -> Result<u64> {
302 let start = reserve_counter_file(path, amount)?;
303 Ok(start + amount)
304}
305
/// Encodes `value` as the 8-byte little-endian counter representation.
pub fn counter_to_bytes(value: u64) -> [u8; 8] {
    u64::to_le_bytes(value)
}
310
311pub fn counter_from_bytes(bytes: &[u8]) -> Result<u64> {
313 if bytes.len() < 8 {
314 return Err(PromocryptError::InvalidFileFormatDetails(
315 "Counter data too short".to_string(),
316 ));
317 }
318
319 let mut arr = [0u8; 8];
320 arr.copy_from_slice(&bytes[..8]);
321 Ok(u64::from_le_bytes(arr))
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_counter_mode_default() {
        let mode = CounterMode::default();
        assert!(matches!(mode, CounterMode::External));
    }

    #[test]
    fn test_counter_mode_file() {
        let mode = CounterMode::file("/tmp/test.counter");
        assert!(matches!(mode, CounterMode::File { .. }));
        assert_eq!(mode.file_path(), Some("/tmp/test.counter"));
    }

    #[test]
    fn test_counter_manager_external() {
        let manager = CounterManager::new(CounterMode::External);
        assert!(manager.get().is_err());
    }

    #[test]
    fn test_counter_manager_in_bin() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, 100);

        assert_eq!(manager.get().unwrap(), 100);

        let start = manager.reserve(10).unwrap();
        assert_eq!(start, 100);
        assert_eq!(manager.get().unwrap(), 110);
        assert!(manager.is_in_bin_modified());
    }

    #[test]
    fn test_counter_file_read_write() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 0);

        write_counter_file(&path, 12345).unwrap();

        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 12345);
    }

    #[test]
    fn test_counter_file_reserve() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let start1 = reserve_counter_file(&path, 100).unwrap();
        assert_eq!(start1, 0);

        let start2 = reserve_counter_file(&path, 50).unwrap();
        assert_eq!(start2, 100);

        let final_value = read_counter_file(&path).unwrap();
        assert_eq!(final_value, 150);
    }

    #[test]
    fn test_counter_file_increment() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let new_value = increment_counter_file(&path, 10).unwrap();
        assert_eq!(new_value, 10);

        let new_value = increment_counter_file(&path, 5).unwrap();
        assert_eq!(new_value, 15);
    }

    #[test]
    fn test_counter_file_corrupted() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        std::fs::write(&path, [1u8, 2, 3]).unwrap();

        assert!(matches!(
            read_counter_file(&path),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        assert!(matches!(
            reserve_counter_file(&path, 1),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        // A rejected reservation must not rewrite the corrupted contents.
        assert_eq!(std::fs::read(&path).unwrap(), vec![1u8, 2, 3]);
    }

    #[test]
    fn test_counter_file_reserve_overflow() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        write_counter_file(&path, u64::MAX - 1).unwrap();

        assert!(matches!(
            reserve_counter_file(&path, 2),
            Err(PromocryptError::CounterOverflow)
        ));
        // The stored value is unchanged after a failed reservation.
        assert_eq!(read_counter_file(&path).unwrap(), u64::MAX - 1);
    }

    #[test]
    fn test_counter_bytes_roundtrip() {
        let value = 0x123456789ABCDEF0u64;
        let bytes = counter_to_bytes(value);
        let recovered = counter_from_bytes(&bytes).unwrap();
        assert_eq!(value, recovered);
    }

    #[test]
    fn test_counter_from_bytes_too_short() {
        assert!(counter_from_bytes(&[0u8; 7]).is_err());
        // Trailing bytes beyond the first eight are ignored.
        assert_eq!(
            counter_from_bytes(&[1, 0, 0, 0, 0, 0, 0, 0, 0xFF]).unwrap(),
            1
        );
    }

    #[test]
    fn test_counter_overflow() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, u64::MAX - 5);

        let result = manager.reserve(10);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PromocryptError::CounterOverflow
        ));
    }

    #[test]
    fn test_counter_manager_file_mode() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        let path_str = path.to_str().unwrap().to_string();

        let mut manager = CounterManager::new(CounterMode::File { path: path_str });

        let value = manager.get().unwrap();
        assert_eq!(value, 0);

        let start = manager.reserve(100).unwrap();
        assert_eq!(start, 0);

        let value = manager.get().unwrap();
        assert_eq!(value, 100);
    }

    #[test]
    fn test_counter_mode_is_external() {
        assert!(
            !CounterMode::File {
                path: "test".to_string()
            }
            .is_external()
        );
        assert!(!CounterMode::InBin.is_external());
        assert!(CounterMode::External.is_external());
    }

    #[test]
    fn test_counter_mode_serde() {
        let modes = vec![
            CounterMode::File {
                path: "/tmp/test.counter".to_string(),
            },
            CounterMode::InBin,
            CounterMode::External,
        ];

        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let recovered: CounterMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, recovered);
        }
    }

    #[test]
    fn test_counter_mode_serde_backwards_compat() {
        let json = r#"{"type":"manual"}"#;
        let mode: CounterMode = serde_json::from_str(json).unwrap();
        assert!(matches!(mode, CounterMode::External));

        let json = r#"{"type":"external"}"#;
        let mode: CounterMode = serde_json::from_str(json).unwrap();
        assert!(matches!(mode, CounterMode::External));

        let external_json = serde_json::to_string(&CounterMode::External).unwrap();
        assert!(external_json.contains("external"));
    }
}