promocrypt_core/
counter.rs

1//! Counter management for promotional code generation.
2//!
3//! Supports multiple counter storage modes:
4//! - File: Separate `.counter` file with OS locking
5//! - InBin: Counter stored in the .bin mutable section
6//! - Manual: User-provided counter values
7//! - External: Consumer-managed counter (e.g., database column)
8
9use std::fs::OpenOptions;
10use std::io::{Read, Seek, SeekFrom, Write};
11use std::path::{Path, PathBuf};
12
13use fs2::FileExt;
14use serde::{Deserialize, Serialize};
15
16use crate::error::{PromocryptError, Result};
17
18/// Counter storage mode.
19#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(tag = "type", rename_all = "lowercase")]
21pub enum CounterMode {
22    /// Store counter in a separate file.
23    File {
24        /// Path to the counter file.
25        path: String,
26    },
27
28    /// Store counter in the .bin file's mutable section.
29    InBin,
30
31    /// No storage - counter provided manually by caller.
32    #[default]
33    Manual,
34
35    /// External counter management (e.g., database column).
36    /// The consumer is responsible for counter storage and atomicity.
37    External,
38}
39
40impl CounterMode {
41    /// Create a file-based counter mode.
42    pub fn file(path: impl Into<String>) -> Self {
43        CounterMode::File { path: path.into() }
44    }
45
46    /// Check if this mode stores counter in-bin.
47    pub fn is_in_bin(&self) -> bool {
48        matches!(self, CounterMode::InBin)
49    }
50
51    /// Check if this mode requires external counter management.
52    pub fn is_external(&self) -> bool {
53        matches!(self, CounterMode::External | CounterMode::Manual)
54    }
55
56    /// Get the counter file path, if applicable.
57    pub fn file_path(&self) -> Option<&str> {
58        match self {
59            CounterMode::File { path } => Some(path),
60            _ => None,
61        }
62    }
63}
64
65/// Counter manager for different storage modes.
66pub struct CounterManager {
67    mode: CounterMode,
68    /// In-memory counter for InBin mode
69    in_bin_value: u64,
70    /// Flag indicating if in-bin value has been modified
71    in_bin_modified: bool,
72}
73
74impl CounterManager {
75    /// Create a new counter manager.
76    pub fn new(mode: CounterMode) -> Self {
77        Self {
78            mode,
79            in_bin_value: 0,
80            in_bin_modified: false,
81        }
82    }
83
84    /// Create a counter manager with an initial in-bin value.
85    pub fn with_in_bin_value(mode: CounterMode, initial_value: u64) -> Self {
86        Self {
87            mode,
88            in_bin_value: initial_value,
89            in_bin_modified: false,
90        }
91    }
92
93    /// Get the counter mode.
94    pub fn mode(&self) -> &CounterMode {
95        &self.mode
96    }
97
98    /// Get the current counter value.
99    pub fn get(&self) -> Result<u64> {
100        match &self.mode {
101            CounterMode::File { path } => read_counter_file(Path::new(path)),
102            CounterMode::InBin => Ok(self.in_bin_value),
103            CounterMode::Manual | CounterMode::External => Err(PromocryptError::InvalidArgument(
104                "Cannot get counter in manual/external mode - counter must be provided".to_string(),
105            )),
106        }
107    }
108
109    /// Increment counter and return the starting value for a batch.
110    ///
111    /// For a batch of `count` codes, returns the starting counter value
112    /// and increments the stored counter by `count`.
113    pub fn reserve(&mut self, count: u64) -> Result<u64> {
114        match &self.mode {
115            CounterMode::File { path } => {
116                let path = PathBuf::from(path);
117                reserve_counter_file(&path, count)
118            }
119            CounterMode::InBin => {
120                let start = self.in_bin_value;
121                self.in_bin_value = self
122                    .in_bin_value
123                    .checked_add(count)
124                    .ok_or(PromocryptError::CounterOverflow)?;
125                self.in_bin_modified = true;
126                Ok(start)
127            }
128            CounterMode::Manual | CounterMode::External => Err(PromocryptError::InvalidArgument(
129                "Cannot reserve counter in manual/external mode".to_string(),
130            )),
131        }
132    }
133
134    /// Set the counter value (for initialization or reset).
135    pub fn set(&mut self, value: u64) -> Result<()> {
136        match &self.mode {
137            CounterMode::File { path } => write_counter_file(Path::new(path), value),
138            CounterMode::InBin => {
139                self.in_bin_value = value;
140                self.in_bin_modified = true;
141                Ok(())
142            }
143            CounterMode::Manual | CounterMode::External => Err(PromocryptError::InvalidArgument(
144                "Cannot set counter in manual/external mode".to_string(),
145            )),
146        }
147    }
148
149    /// Check if the in-bin counter has been modified.
150    pub fn is_in_bin_modified(&self) -> bool {
151        self.in_bin_modified && matches!(self.mode, CounterMode::InBin)
152    }
153
154    /// Get the in-bin counter value.
155    pub fn in_bin_value(&self) -> u64 {
156        self.in_bin_value
157    }
158
159    /// Mark the in-bin counter as saved.
160    pub fn mark_saved(&mut self) {
161        self.in_bin_modified = false;
162    }
163
164    /// Update the counter mode.
165    pub fn set_mode(&mut self, mode: CounterMode) {
166        self.mode = mode;
167    }
168}
169
170/// Read counter value from a file.
171///
172/// Creates the file with value 0 if it doesn't exist.
173pub fn read_counter_file(path: &Path) -> Result<u64> {
174    // Create parent directory if needed
175    if let Some(parent) = path.parent()
176        && !parent.exists()
177    {
178        std::fs::create_dir_all(parent)?;
179    }
180
181    // Open or create file
182    let mut file = OpenOptions::new()
183        .read(true)
184        .write(true)
185        .create(true)
186        .truncate(false)
187        .open(path)?;
188
189    // Lock for reading
190    file.lock_shared()
191        .map_err(|_| PromocryptError::CounterLocked)?;
192
193    let mut buffer = [0u8; 8];
194    let bytes_read = file.read(&mut buffer)?;
195
196    // Unlock
197    file.unlock().ok();
198
199    if bytes_read == 0 {
200        // New file, initialize to 0
201        return Ok(0);
202    }
203
204    if bytes_read != 8 {
205        return Err(PromocryptError::InvalidFileFormatDetails(
206            "Counter file corrupted".to_string(),
207        ));
208    }
209
210    Ok(u64::from_le_bytes(buffer))
211}
212
213/// Write counter value to a file.
214pub fn write_counter_file(path: &Path, value: u64) -> Result<()> {
215    // Create parent directory if needed
216    if let Some(parent) = path.parent()
217        && !parent.exists()
218    {
219        std::fs::create_dir_all(parent)?;
220    }
221
222    let mut file = OpenOptions::new()
223        .read(true)
224        .write(true)
225        .create(true)
226        .truncate(false)
227        .open(path)?;
228
229    // Lock for writing
230    file.lock_exclusive()
231        .map_err(|_| PromocryptError::CounterLocked)?;
232
233    file.seek(SeekFrom::Start(0))?;
234    file.write_all(&value.to_le_bytes())?;
235    file.set_len(8)?;
236    file.sync_all()?;
237
238    // Unlock
239    file.unlock().ok();
240
241    Ok(())
242}
243
244/// Atomically reserve a range of counter values.
245///
246/// Returns the starting counter value and increments the file by `count`.
247pub fn reserve_counter_file(path: &Path, count: u64) -> Result<u64> {
248    // Create parent directory if needed
249    if let Some(parent) = path.parent()
250        && !parent.exists()
251    {
252        std::fs::create_dir_all(parent)?;
253    }
254
255    let mut file = OpenOptions::new()
256        .read(true)
257        .write(true)
258        .create(true)
259        .truncate(false)
260        .open(path)?;
261
262    // Lock for exclusive access
263    file.lock_exclusive()
264        .map_err(|_| PromocryptError::CounterLocked)?;
265
266    // Read current value
267    let mut buffer = [0u8; 8];
268    let bytes_read = file.read(&mut buffer)?;
269
270    let current = if bytes_read == 0 {
271        0
272    } else if bytes_read == 8 {
273        u64::from_le_bytes(buffer)
274    } else {
275        file.unlock().ok();
276        return Err(PromocryptError::InvalidFileFormatDetails(
277            "Counter file corrupted".to_string(),
278        ));
279    };
280
281    // Calculate new value
282    let new_value = current
283        .checked_add(count)
284        .ok_or(PromocryptError::CounterOverflow)?;
285
286    // Write new value
287    file.seek(SeekFrom::Start(0))?;
288    file.write_all(&new_value.to_le_bytes())?;
289    file.set_len(8)?;
290    file.sync_all()?;
291
292    // Unlock
293    file.unlock().ok();
294
295    Ok(current)
296}
297
298/// Increment counter by a specific amount.
299///
300/// Returns the new counter value.
301pub fn increment_counter_file(path: &Path, amount: u64) -> Result<u64> {
302    let start = reserve_counter_file(path, amount)?;
303    Ok(start + amount)
304}
305
/// Encodes `value` as 8 little-endian bytes, the layout used for in-bin storage.
pub fn counter_to_bytes(value: u64) -> [u8; 8] {
    u64::to_le_bytes(value)
}
310
311/// Deserialize counter value from bytes.
312pub fn counter_from_bytes(bytes: &[u8]) -> Result<u64> {
313    if bytes.len() < 8 {
314        return Err(PromocryptError::InvalidFileFormatDetails(
315            "Counter data too short".to_string(),
316        ));
317    }
318
319    let mut arr = [0u8; 8];
320    arr.copy_from_slice(&bytes[..8]);
321    Ok(u64::from_le_bytes(arr))
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_counter_mode_default() {
        let mode = CounterMode::default();
        assert!(matches!(mode, CounterMode::Manual));
    }

    #[test]
    fn test_counter_mode_file() {
        let mode = CounterMode::file("/tmp/test.counter");
        assert!(matches!(mode, CounterMode::File { .. }));
        assert_eq!(mode.file_path(), Some("/tmp/test.counter"));
    }

    #[test]
    fn test_counter_manager_manual() {
        let manager = CounterManager::new(CounterMode::Manual);
        assert!(manager.get().is_err());
    }

    #[test]
    fn test_counter_manager_in_bin() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, 100);

        assert_eq!(manager.get().unwrap(), 100);

        let start = manager.reserve(10).unwrap();
        assert_eq!(start, 100);
        assert_eq!(manager.get().unwrap(), 110);
        assert!(manager.is_in_bin_modified());
    }

    #[test]
    fn test_counter_manager_in_bin_set_and_mark_saved() {
        let mut manager = CounterManager::new(CounterMode::InBin);
        assert!(!manager.is_in_bin_modified());

        manager.set(42).unwrap();
        assert_eq!(manager.get().unwrap(), 42);
        assert_eq!(manager.in_bin_value(), 42);
        assert!(manager.is_in_bin_modified());

        manager.mark_saved();
        assert!(!manager.is_in_bin_modified());
        assert_eq!(manager.get().unwrap(), 42);
    }

    #[test]
    fn test_counter_manager_manual_and_external_reject_access() {
        for mode in [CounterMode::Manual, CounterMode::External] {
            let mut manager = CounterManager::new(mode);
            assert!(manager.get().is_err());
            assert!(manager.reserve(1).is_err());
            assert!(manager.set(1).is_err());
        }
    }

    #[test]
    fn test_counter_file_read_write() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        // Initially should be 0 (new file)
        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 0);

        // Write a value
        write_counter_file(&path, 12345).unwrap();

        // Read it back
        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 12345);
    }

    #[test]
    fn test_counter_file_creates_parent_dirs() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir
            .path()
            .join("nested")
            .join("dir")
            .join("test.counter");

        write_counter_file(&path, 7).unwrap();
        assert_eq!(read_counter_file(&path).unwrap(), 7);
    }

    #[test]
    fn test_counter_file_corrupted() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        std::fs::write(&path, [1u8, 2, 3]).unwrap();

        assert!(matches!(
            read_counter_file(&path),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        assert!(matches!(
            reserve_counter_file(&path, 1),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        // A rejected reservation must not rewrite the file.
        assert_eq!(std::fs::read(&path).unwrap(), vec![1u8, 2, 3]);
    }

    #[test]
    fn test_counter_file_reserve() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        // Reserve 100
        let start1 = reserve_counter_file(&path, 100).unwrap();
        assert_eq!(start1, 0);

        // Reserve 50 more
        let start2 = reserve_counter_file(&path, 50).unwrap();
        assert_eq!(start2, 100);

        // Verify final value
        let final_value = read_counter_file(&path).unwrap();
        assert_eq!(final_value, 150);
    }

    #[test]
    fn test_counter_file_reserve_overflow_leaves_value() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        write_counter_file(&path, u64::MAX - 1).unwrap();

        let result = reserve_counter_file(&path, 5);
        assert!(matches!(result, Err(PromocryptError::CounterOverflow)));
        assert_eq!(read_counter_file(&path).unwrap(), u64::MAX - 1);
    }

    #[test]
    fn test_counter_file_increment() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let new_value = increment_counter_file(&path, 10).unwrap();
        assert_eq!(new_value, 10);

        let new_value = increment_counter_file(&path, 5).unwrap();
        assert_eq!(new_value, 15);
    }

    #[test]
    fn test_counter_bytes_roundtrip() {
        let value = 0x123456789ABCDEF0u64;
        let bytes = counter_to_bytes(value);
        let recovered = counter_from_bytes(&bytes).unwrap();
        assert_eq!(value, recovered);
    }

    #[test]
    fn test_counter_from_bytes_too_short() {
        assert!(matches!(
            counter_from_bytes(&[1, 2, 3]),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        assert!(counter_from_bytes(&[]).is_err());
    }

    #[test]
    fn test_counter_overflow() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, u64::MAX - 5);

        // This should fail due to overflow
        let result = manager.reserve(10);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PromocryptError::CounterOverflow
        ));
        // The cached value must be untouched after a failed reservation.
        assert_eq!(manager.get().unwrap(), u64::MAX - 5);
    }

    #[test]
    fn test_counter_manager_file_mode() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        let path_str = path.to_str().unwrap().to_string();

        let mut manager = CounterManager::new(CounterMode::File { path: path_str });

        // Initial value
        let value = manager.get().unwrap();
        assert_eq!(value, 0);

        // Reserve
        let start = manager.reserve(100).unwrap();
        assert_eq!(start, 0);

        // Get again
        let value = manager.get().unwrap();
        assert_eq!(value, 100);
    }

    #[test]
    fn test_counter_mode_is_external() {
        assert!(
            !CounterMode::File {
                path: "test".to_string()
            }
            .is_external()
        );
        assert!(!CounterMode::InBin.is_external());
        assert!(CounterMode::Manual.is_external());
        assert!(CounterMode::External.is_external());
    }

    #[test]
    fn test_counter_mode_serde() {
        let modes = vec![
            CounterMode::File {
                path: "/tmp/test.counter".to_string(),
            },
            CounterMode::InBin,
            CounterMode::Manual,
            CounterMode::External,
        ];

        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let recovered: CounterMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, recovered);
        }
    }
}