promocrypt_core/
counter.rs

1//! Counter management for promotional code generation.
2//!
3//! Supports multiple counter storage modes:
4//! - File: Separate `.counter` file with OS locking
5//! - InBin: Counter stored in the .bin mutable section
6//! - External: Consumer-managed counter (e.g., database column, PostgreSQL)
7
8use std::fs::OpenOptions;
9use std::io::{Read, Seek, SeekFrom, Write};
10use std::path::{Path, PathBuf};
11
12use fs2::FileExt;
13use serde::{Deserialize, Serialize};
14
15use crate::error::{PromocryptError, Result};
16
17/// Counter storage mode.
18#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
19#[serde(tag = "type", rename_all = "lowercase")]
20pub enum CounterMode {
21    /// Store counter in a separate file.
22    File {
23        /// Path to the counter file.
24        path: String,
25    },
26
27    /// Store counter in the .bin file's mutable section.
28    InBin,
29
30    /// External counter management (e.g., database column, PostgreSQL).
31    /// The consumer is responsible for counter storage and atomicity.
32    /// Use `generate_at()` or `generate_batch_at()` with explicit counter values.
33    #[default]
34    #[serde(alias = "manual")] // Backwards compatibility with old .bin files
35    External,
36}
37
38impl CounterMode {
39    /// Create a file-based counter mode.
40    pub fn file(path: impl Into<String>) -> Self {
41        CounterMode::File { path: path.into() }
42    }
43
44    /// Check if this mode stores counter in-bin.
45    pub fn is_in_bin(&self) -> bool {
46        matches!(self, CounterMode::InBin)
47    }
48
49    /// Check if this mode requires external counter management.
50    pub fn is_external(&self) -> bool {
51        matches!(self, CounterMode::External)
52    }
53
54    /// Get the counter file path, if applicable.
55    pub fn file_path(&self) -> Option<&str> {
56        match self {
57            CounterMode::File { path } => Some(path),
58            _ => None,
59        }
60    }
61}
62
63/// Counter manager for different storage modes.
64pub struct CounterManager {
65    mode: CounterMode,
66    /// In-memory counter for InBin mode
67    in_bin_value: u64,
68    /// Flag indicating if in-bin value has been modified
69    in_bin_modified: bool,
70}
71
72impl CounterManager {
73    /// Create a new counter manager.
74    pub fn new(mode: CounterMode) -> Self {
75        Self {
76            mode,
77            in_bin_value: 0,
78            in_bin_modified: false,
79        }
80    }
81
82    /// Create a counter manager with an initial in-bin value.
83    pub fn with_in_bin_value(mode: CounterMode, initial_value: u64) -> Self {
84        Self {
85            mode,
86            in_bin_value: initial_value,
87            in_bin_modified: false,
88        }
89    }
90
91    /// Get the counter mode.
92    pub fn mode(&self) -> &CounterMode {
93        &self.mode
94    }
95
96    /// Get the current counter value.
97    pub fn get(&self) -> Result<u64> {
98        match &self.mode {
99            CounterMode::File { path } => read_counter_file(Path::new(path)),
100            CounterMode::InBin => Ok(self.in_bin_value),
101            CounterMode::External => Err(PromocryptError::InvalidArgument(
102                "Cannot get counter in external mode - use generate_at() with explicit counter"
103                    .to_string(),
104            )),
105        }
106    }
107
108    /// Increment counter and return the starting value for a batch.
109    ///
110    /// For a batch of `count` codes, returns the starting counter value
111    /// and increments the stored counter by `count`.
112    pub fn reserve(&mut self, count: u64) -> Result<u64> {
113        match &self.mode {
114            CounterMode::File { path } => {
115                let path = PathBuf::from(path);
116                reserve_counter_file(&path, count)
117            }
118            CounterMode::InBin => {
119                let start = self.in_bin_value;
120                self.in_bin_value = self
121                    .in_bin_value
122                    .checked_add(count)
123                    .ok_or(PromocryptError::CounterOverflow)?;
124                self.in_bin_modified = true;
125                Ok(start)
126            }
127            CounterMode::External => Err(PromocryptError::InvalidArgument(
128                "Cannot reserve counter in external mode - counter is managed externally"
129                    .to_string(),
130            )),
131        }
132    }
133
134    /// Set the counter value (for initialization or reset).
135    pub fn set(&mut self, value: u64) -> Result<()> {
136        match &self.mode {
137            CounterMode::File { path } => write_counter_file(Path::new(path), value),
138            CounterMode::InBin => {
139                self.in_bin_value = value;
140                self.in_bin_modified = true;
141                Ok(())
142            }
143            CounterMode::External => Err(PromocryptError::InvalidArgument(
144                "Cannot set counter in external mode - counter is managed externally".to_string(),
145            )),
146        }
147    }
148
149    /// Check if the in-bin counter has been modified.
150    pub fn is_in_bin_modified(&self) -> bool {
151        self.in_bin_modified && matches!(self.mode, CounterMode::InBin)
152    }
153
154    /// Get the in-bin counter value.
155    pub fn in_bin_value(&self) -> u64 {
156        self.in_bin_value
157    }
158
159    /// Mark the in-bin counter as saved.
160    pub fn mark_saved(&mut self) {
161        self.in_bin_modified = false;
162    }
163
164    /// Update the counter mode.
165    pub fn set_mode(&mut self, mode: CounterMode) {
166        self.mode = mode;
167    }
168}
169
170/// Read counter value from a file.
171///
172/// Creates the file with value 0 if it doesn't exist.
173pub fn read_counter_file(path: &Path) -> Result<u64> {
174    // Create parent directory if needed
175    if let Some(parent) = path.parent()
176        && !parent.exists()
177    {
178        std::fs::create_dir_all(parent)?;
179    }
180
181    // Open or create file
182    let mut file = OpenOptions::new()
183        .read(true)
184        .write(true)
185        .create(true)
186        .truncate(false)
187        .open(path)?;
188
189    // Lock for reading
190    file.lock_shared()
191        .map_err(|_| PromocryptError::CounterLocked)?;
192
193    let mut buffer = [0u8; 8];
194    let bytes_read = file.read(&mut buffer)?;
195
196    // Unlock
197    file.unlock().ok();
198
199    if bytes_read == 0 {
200        // New file, initialize to 0
201        return Ok(0);
202    }
203
204    if bytes_read != 8 {
205        return Err(PromocryptError::InvalidFileFormatDetails(
206            "Counter file corrupted".to_string(),
207        ));
208    }
209
210    Ok(u64::from_le_bytes(buffer))
211}
212
213/// Write counter value to a file.
214pub fn write_counter_file(path: &Path, value: u64) -> Result<()> {
215    // Create parent directory if needed
216    if let Some(parent) = path.parent()
217        && !parent.exists()
218    {
219        std::fs::create_dir_all(parent)?;
220    }
221
222    let mut file = OpenOptions::new()
223        .read(true)
224        .write(true)
225        .create(true)
226        .truncate(false)
227        .open(path)?;
228
229    // Lock for writing
230    file.lock_exclusive()
231        .map_err(|_| PromocryptError::CounterLocked)?;
232
233    file.seek(SeekFrom::Start(0))?;
234    file.write_all(&value.to_le_bytes())?;
235    file.set_len(8)?;
236    file.sync_all()?;
237
238    // Unlock
239    file.unlock().ok();
240
241    Ok(())
242}
243
244/// Atomically reserve a range of counter values.
245///
246/// Returns the starting counter value and increments the file by `count`.
247pub fn reserve_counter_file(path: &Path, count: u64) -> Result<u64> {
248    // Create parent directory if needed
249    if let Some(parent) = path.parent()
250        && !parent.exists()
251    {
252        std::fs::create_dir_all(parent)?;
253    }
254
255    let mut file = OpenOptions::new()
256        .read(true)
257        .write(true)
258        .create(true)
259        .truncate(false)
260        .open(path)?;
261
262    // Lock for exclusive access
263    file.lock_exclusive()
264        .map_err(|_| PromocryptError::CounterLocked)?;
265
266    // Read current value
267    let mut buffer = [0u8; 8];
268    let bytes_read = file.read(&mut buffer)?;
269
270    let current = if bytes_read == 0 {
271        0
272    } else if bytes_read == 8 {
273        u64::from_le_bytes(buffer)
274    } else {
275        file.unlock().ok();
276        return Err(PromocryptError::InvalidFileFormatDetails(
277            "Counter file corrupted".to_string(),
278        ));
279    };
280
281    // Calculate new value
282    let new_value = current
283        .checked_add(count)
284        .ok_or(PromocryptError::CounterOverflow)?;
285
286    // Write new value
287    file.seek(SeekFrom::Start(0))?;
288    file.write_all(&new_value.to_le_bytes())?;
289    file.set_len(8)?;
290    file.sync_all()?;
291
292    // Unlock
293    file.unlock().ok();
294
295    Ok(current)
296}
297
298/// Increment counter by a specific amount.
299///
300/// Returns the new counter value.
301pub fn increment_counter_file(path: &Path, amount: u64) -> Result<u64> {
302    let start = reserve_counter_file(path, amount)?;
303    Ok(start + amount)
304}
305
/// Encode `value` as 8 little-endian bytes, the layout used for in-bin storage.
pub fn counter_to_bytes(value: u64) -> [u8; 8] {
    u64::to_le_bytes(value)
}
310
311/// Deserialize counter value from bytes.
312pub fn counter_from_bytes(bytes: &[u8]) -> Result<u64> {
313    if bytes.len() < 8 {
314        return Err(PromocryptError::InvalidFileFormatDetails(
315            "Counter data too short".to_string(),
316        ));
317    }
318
319    let mut arr = [0u8; 8];
320    arr.copy_from_slice(&bytes[..8]);
321    Ok(u64::from_le_bytes(arr))
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_counter_mode_default() {
        let mode = CounterMode::default();
        assert!(matches!(mode, CounterMode::External));
    }

    #[test]
    fn test_counter_mode_file() {
        let mode = CounterMode::file("/tmp/test.counter");
        assert!(matches!(mode, CounterMode::File { .. }));
        assert_eq!(mode.file_path(), Some("/tmp/test.counter"));
    }

    #[test]
    fn test_counter_manager_external() {
        let manager = CounterManager::new(CounterMode::External);
        assert!(manager.get().is_err());
    }

    // Every mutating operation must be rejected in external mode.
    #[test]
    fn test_counter_manager_external_rejects_mutation() {
        let mut manager = CounterManager::new(CounterMode::External);
        assert!(matches!(
            manager.reserve(1),
            Err(PromocryptError::InvalidArgument(_))
        ));
        assert!(matches!(
            manager.set(5),
            Err(PromocryptError::InvalidArgument(_))
        ));
    }

    #[test]
    fn test_counter_manager_in_bin() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, 100);

        assert_eq!(manager.get().unwrap(), 100);

        let start = manager.reserve(10).unwrap();
        assert_eq!(start, 100);
        assert_eq!(manager.get().unwrap(), 110);
        assert!(manager.is_in_bin_modified());
    }

    #[test]
    fn test_counter_manager_in_bin_set_and_mark_saved() {
        let mut manager = CounterManager::new(CounterMode::InBin);
        assert!(!manager.is_in_bin_modified());

        manager.set(7).unwrap();
        assert_eq!(manager.in_bin_value(), 7);
        assert!(manager.is_in_bin_modified());

        manager.mark_saved();
        assert!(!manager.is_in_bin_modified());
    }

    #[test]
    fn test_counter_file_read_write() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        // Initially should be 0 (new file)
        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 0);

        // Write a value
        write_counter_file(&path, 12345).unwrap();

        // Read it back
        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 12345);
    }

    #[test]
    fn test_counter_file_reserve() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        // Reserve 100
        let start1 = reserve_counter_file(&path, 100).unwrap();
        assert_eq!(start1, 0);

        // Reserve 50 more
        let start2 = reserve_counter_file(&path, 50).unwrap();
        assert_eq!(start2, 100);

        // Verify final value
        let final_value = read_counter_file(&path).unwrap();
        assert_eq!(final_value, 150);
    }

    // A file with 1..=7 bytes is corrupted; reserving must not modify it.
    #[test]
    fn test_counter_file_corrupted_is_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        std::fs::write(&path, [1u8, 2, 3]).unwrap();

        assert!(matches!(
            read_counter_file(&path),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        assert!(matches!(
            reserve_counter_file(&path, 1),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        assert_eq!(std::fs::read(&path).unwrap(), vec![1u8, 2, 3]);
    }

    // An overflowing reservation must leave the value unchanged and the file
    // usable (lock released) for later reservations.
    #[test]
    fn test_counter_file_reserve_overflow_leaves_file_usable() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        write_counter_file(&path, u64::MAX - 1).unwrap();

        assert!(matches!(
            reserve_counter_file(&path, 10),
            Err(PromocryptError::CounterOverflow)
        ));
        assert_eq!(read_counter_file(&path).unwrap(), u64::MAX - 1);

        assert_eq!(reserve_counter_file(&path, 1).unwrap(), u64::MAX - 1);
        assert_eq!(read_counter_file(&path).unwrap(), u64::MAX);
    }

    #[test]
    fn test_counter_file_increment() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let new_value = increment_counter_file(&path, 10).unwrap();
        assert_eq!(new_value, 10);

        let new_value = increment_counter_file(&path, 5).unwrap();
        assert_eq!(new_value, 15);
    }

    #[test]
    fn test_counter_bytes_roundtrip() {
        let value = 0x123456789ABCDEF0u64;
        let bytes = counter_to_bytes(value);
        let recovered = counter_from_bytes(&bytes).unwrap();
        assert_eq!(value, recovered);
    }

    #[test]
    fn test_counter_from_bytes_length_handling() {
        assert!(matches!(
            counter_from_bytes(&[0u8; 7]),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));

        // Bytes beyond the first eight are ignored.
        let mut bytes = counter_to_bytes(42).to_vec();
        bytes.push(0xFF);
        assert_eq!(counter_from_bytes(&bytes).unwrap(), 42);
    }

    #[test]
    fn test_counter_overflow() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, u64::MAX - 5);

        // This should fail due to overflow
        let result = manager.reserve(10);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PromocryptError::CounterOverflow
        ));
    }

    #[test]
    fn test_counter_manager_file_mode() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        let path_str = path.to_str().unwrap().to_string();

        let mut manager = CounterManager::new(CounterMode::File { path: path_str });

        // Initial value
        let value = manager.get().unwrap();
        assert_eq!(value, 0);

        // Reserve
        let start = manager.reserve(100).unwrap();
        assert_eq!(start, 0);

        // Get again
        let value = manager.get().unwrap();
        assert_eq!(value, 100);
    }

    #[test]
    fn test_counter_mode_is_external() {
        assert!(
            !CounterMode::File {
                path: "test".to_string()
            }
            .is_external()
        );
        assert!(!CounterMode::InBin.is_external());
        assert!(CounterMode::External.is_external());
    }

    #[test]
    fn test_counter_mode_is_in_bin() {
        assert!(!CounterMode::file("test").is_in_bin());
        assert!(CounterMode::InBin.is_in_bin());
        assert!(!CounterMode::External.is_in_bin());
    }

    #[test]
    fn test_counter_mode_serde() {
        let modes = vec![
            CounterMode::File {
                path: "/tmp/test.counter".to_string(),
            },
            CounterMode::InBin,
            CounterMode::External,
        ];

        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let recovered: CounterMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, recovered);
        }
    }

    #[test]
    fn test_counter_mode_serde_backwards_compat() {
        // Old "manual" should deserialize to External for backwards compatibility
        let json = r#"{"type":"manual"}"#;
        let mode: CounterMode = serde_json::from_str(json).unwrap();
        assert!(matches!(mode, CounterMode::External));

        // New "external" also works
        let json = r#"{"type":"external"}"#;
        let mode: CounterMode = serde_json::from_str(json).unwrap();
        assert!(matches!(mode, CounterMode::External));

        // Serialization uses "external"
        let external_json = serde_json::to_string(&CounterMode::External).unwrap();
        assert!(external_json.contains("external"));
    }
}