promocrypt_core/
counter.rs1use std::fs::OpenOptions;
10use std::io::{Read, Seek, SeekFrom, Write};
11use std::path::{Path, PathBuf};
12
13use fs2::FileExt;
14use serde::{Deserialize, Serialize};
15
16use crate::error::{PromocryptError, Result};
17
18#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(tag = "type", rename_all = "lowercase")]
21pub enum CounterMode {
22 File {
24 path: String,
26 },
27
28 InBin,
30
31 #[default]
33 Manual,
34
35 External,
38}
39
40impl CounterMode {
41 pub fn file(path: impl Into<String>) -> Self {
43 CounterMode::File { path: path.into() }
44 }
45
46 pub fn is_in_bin(&self) -> bool {
48 matches!(self, CounterMode::InBin)
49 }
50
51 pub fn is_external(&self) -> bool {
53 matches!(self, CounterMode::External | CounterMode::Manual)
54 }
55
56 pub fn file_path(&self) -> Option<&str> {
58 match self {
59 CounterMode::File { path } => Some(path),
60 _ => None,
61 }
62 }
63}
64
65pub struct CounterManager {
67 mode: CounterMode,
68 in_bin_value: u64,
70 in_bin_modified: bool,
72}
73
74impl CounterManager {
75 pub fn new(mode: CounterMode) -> Self {
77 Self {
78 mode,
79 in_bin_value: 0,
80 in_bin_modified: false,
81 }
82 }
83
84 pub fn with_in_bin_value(mode: CounterMode, initial_value: u64) -> Self {
86 Self {
87 mode,
88 in_bin_value: initial_value,
89 in_bin_modified: false,
90 }
91 }
92
93 pub fn mode(&self) -> &CounterMode {
95 &self.mode
96 }
97
98 pub fn get(&self) -> Result<u64> {
100 match &self.mode {
101 CounterMode::File { path } => read_counter_file(Path::new(path)),
102 CounterMode::InBin => Ok(self.in_bin_value),
103 CounterMode::Manual | CounterMode::External => Err(PromocryptError::InvalidArgument(
104 "Cannot get counter in manual/external mode - counter must be provided".to_string(),
105 )),
106 }
107 }
108
109 pub fn reserve(&mut self, count: u64) -> Result<u64> {
114 match &self.mode {
115 CounterMode::File { path } => {
116 let path = PathBuf::from(path);
117 reserve_counter_file(&path, count)
118 }
119 CounterMode::InBin => {
120 let start = self.in_bin_value;
121 self.in_bin_value = self
122 .in_bin_value
123 .checked_add(count)
124 .ok_or(PromocryptError::CounterOverflow)?;
125 self.in_bin_modified = true;
126 Ok(start)
127 }
128 CounterMode::Manual | CounterMode::External => Err(PromocryptError::InvalidArgument(
129 "Cannot reserve counter in manual/external mode".to_string(),
130 )),
131 }
132 }
133
134 pub fn set(&mut self, value: u64) -> Result<()> {
136 match &self.mode {
137 CounterMode::File { path } => write_counter_file(Path::new(path), value),
138 CounterMode::InBin => {
139 self.in_bin_value = value;
140 self.in_bin_modified = true;
141 Ok(())
142 }
143 CounterMode::Manual | CounterMode::External => Err(PromocryptError::InvalidArgument(
144 "Cannot set counter in manual/external mode".to_string(),
145 )),
146 }
147 }
148
149 pub fn is_in_bin_modified(&self) -> bool {
151 self.in_bin_modified && matches!(self.mode, CounterMode::InBin)
152 }
153
154 pub fn in_bin_value(&self) -> u64 {
156 self.in_bin_value
157 }
158
159 pub fn mark_saved(&mut self) {
161 self.in_bin_modified = false;
162 }
163
164 pub fn set_mode(&mut self, mode: CounterMode) {
166 self.mode = mode;
167 }
168}
169
170pub fn read_counter_file(path: &Path) -> Result<u64> {
174 if let Some(parent) = path.parent()
176 && !parent.exists()
177 {
178 std::fs::create_dir_all(parent)?;
179 }
180
181 let mut file = OpenOptions::new()
183 .read(true)
184 .write(true)
185 .create(true)
186 .truncate(false)
187 .open(path)?;
188
189 file.lock_shared()
191 .map_err(|_| PromocryptError::CounterLocked)?;
192
193 let mut buffer = [0u8; 8];
194 let bytes_read = file.read(&mut buffer)?;
195
196 file.unlock().ok();
198
199 if bytes_read == 0 {
200 return Ok(0);
202 }
203
204 if bytes_read != 8 {
205 return Err(PromocryptError::InvalidFileFormatDetails(
206 "Counter file corrupted".to_string(),
207 ));
208 }
209
210 Ok(u64::from_le_bytes(buffer))
211}
212
213pub fn write_counter_file(path: &Path, value: u64) -> Result<()> {
215 if let Some(parent) = path.parent()
217 && !parent.exists()
218 {
219 std::fs::create_dir_all(parent)?;
220 }
221
222 let mut file = OpenOptions::new()
223 .read(true)
224 .write(true)
225 .create(true)
226 .truncate(false)
227 .open(path)?;
228
229 file.lock_exclusive()
231 .map_err(|_| PromocryptError::CounterLocked)?;
232
233 file.seek(SeekFrom::Start(0))?;
234 file.write_all(&value.to_le_bytes())?;
235 file.set_len(8)?;
236 file.sync_all()?;
237
238 file.unlock().ok();
240
241 Ok(())
242}
243
244pub fn reserve_counter_file(path: &Path, count: u64) -> Result<u64> {
248 if let Some(parent) = path.parent()
250 && !parent.exists()
251 {
252 std::fs::create_dir_all(parent)?;
253 }
254
255 let mut file = OpenOptions::new()
256 .read(true)
257 .write(true)
258 .create(true)
259 .truncate(false)
260 .open(path)?;
261
262 file.lock_exclusive()
264 .map_err(|_| PromocryptError::CounterLocked)?;
265
266 let mut buffer = [0u8; 8];
268 let bytes_read = file.read(&mut buffer)?;
269
270 let current = if bytes_read == 0 {
271 0
272 } else if bytes_read == 8 {
273 u64::from_le_bytes(buffer)
274 } else {
275 file.unlock().ok();
276 return Err(PromocryptError::InvalidFileFormatDetails(
277 "Counter file corrupted".to_string(),
278 ));
279 };
280
281 let new_value = current
283 .checked_add(count)
284 .ok_or(PromocryptError::CounterOverflow)?;
285
286 file.seek(SeekFrom::Start(0))?;
288 file.write_all(&new_value.to_le_bytes())?;
289 file.set_len(8)?;
290 file.sync_all()?;
291
292 file.unlock().ok();
294
295 Ok(current)
296}
297
298pub fn increment_counter_file(path: &Path, amount: u64) -> Result<u64> {
302 let start = reserve_counter_file(path, amount)?;
303 Ok(start + amount)
304}
305
/// Encodes `value` as 8 little-endian bytes, the same layout the counter file uses.
pub fn counter_to_bytes(value: u64) -> [u8; 8] {
    u64::to_le_bytes(value)
}
310
311pub fn counter_from_bytes(bytes: &[u8]) -> Result<u64> {
313 if bytes.len() < 8 {
314 return Err(PromocryptError::InvalidFileFormatDetails(
315 "Counter data too short".to_string(),
316 ));
317 }
318
319 let mut arr = [0u8; 8];
320 arr.copy_from_slice(&bytes[..8]);
321 Ok(u64::from_le_bytes(arr))
322}
323
#[cfg(test)]
mod tests {
    //! Unit tests for counter modes, the in-memory manager and the file-backed helpers,
    //! including the corrupted-file and truncation paths.

    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_counter_mode_default() {
        let mode = CounterMode::default();
        assert!(matches!(mode, CounterMode::Manual));
    }

    #[test]
    fn test_counter_mode_file() {
        let mode = CounterMode::file("/tmp/test.counter");
        assert!(matches!(mode, CounterMode::File { .. }));
        assert_eq!(mode.file_path(), Some("/tmp/test.counter"));
    }

    #[test]
    fn test_counter_manager_manual() {
        let manager = CounterManager::new(CounterMode::Manual);
        assert!(manager.get().is_err());
    }

    #[test]
    fn test_counter_manager_external_rejects_mutation() {
        let mut manager = CounterManager::new(CounterMode::External);
        assert!(manager.reserve(1).is_err());
        assert!(manager.set(5).is_err());
        assert!(!manager.is_in_bin_modified());
    }

    #[test]
    fn test_counter_manager_in_bin() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, 100);

        assert_eq!(manager.get().unwrap(), 100);

        let start = manager.reserve(10).unwrap();
        assert_eq!(start, 100);
        assert_eq!(manager.get().unwrap(), 110);
        assert!(manager.is_in_bin_modified());

        manager.mark_saved();
        assert!(!manager.is_in_bin_modified());
    }

    #[test]
    fn test_counter_file_read_write() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 0);

        write_counter_file(&path, 12345).unwrap();

        let value = read_counter_file(&path).unwrap();
        assert_eq!(value, 12345);
    }

    #[test]
    fn test_counter_file_reserve() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let start1 = reserve_counter_file(&path, 100).unwrap();
        assert_eq!(start1, 0);

        let start2 = reserve_counter_file(&path, 50).unwrap();
        assert_eq!(start2, 100);

        let final_value = read_counter_file(&path).unwrap();
        assert_eq!(final_value, 150);
    }

    #[test]
    fn test_counter_file_increment() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");

        let new_value = increment_counter_file(&path, 10).unwrap();
        assert_eq!(new_value, 10);

        let new_value = increment_counter_file(&path, 5).unwrap();
        assert_eq!(new_value, 15);
    }

    #[test]
    fn test_counter_file_corrupted_is_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        // Three bytes is neither a fresh (empty) file nor a complete 8-byte counter.
        std::fs::write(&path, [1u8, 2, 3]).unwrap();

        assert!(matches!(
            read_counter_file(&path),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        assert!(matches!(
            reserve_counter_file(&path, 1),
            Err(PromocryptError::InvalidFileFormatDetails(_))
        ));
        // A failed reservation must not modify the file.
        assert_eq!(std::fs::read(&path).unwrap(), vec![1u8, 2, 3]);
    }

    #[test]
    fn test_write_counter_file_truncates_to_eight_bytes() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        std::fs::write(&path, [0xFFu8; 16]).unwrap();

        write_counter_file(&path, 7).unwrap();

        assert_eq!(std::fs::read(&path).unwrap(), 7u64.to_le_bytes().to_vec());
        assert_eq!(read_counter_file(&path).unwrap(), 7);
    }

    #[test]
    fn test_counter_bytes_roundtrip() {
        let value = 0x123456789ABCDEF0u64;
        let bytes = counter_to_bytes(value);
        let recovered = counter_from_bytes(&bytes).unwrap();
        assert_eq!(value, recovered);
    }

    #[test]
    fn test_counter_from_bytes_length_handling() {
        assert!(counter_from_bytes(&[0u8; 7]).is_err());
        assert!(counter_from_bytes(&[]).is_err());
        // Bytes past the eighth are ignored.
        assert_eq!(
            counter_from_bytes(&[1, 0, 0, 0, 0, 0, 0, 0, 9]).unwrap(),
            1
        );
    }

    #[test]
    fn test_counter_overflow() {
        let mut manager = CounterManager::with_in_bin_value(CounterMode::InBin, u64::MAX - 5);

        let result = manager.reserve(10);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PromocryptError::CounterOverflow
        ));
    }

    #[test]
    fn test_counter_manager_file_mode() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("test.counter");
        let path_str = path.to_str().unwrap().to_string();

        let mut manager = CounterManager::new(CounterMode::File { path: path_str });

        let value = manager.get().unwrap();
        assert_eq!(value, 0);

        let start = manager.reserve(100).unwrap();
        assert_eq!(start, 0);

        let value = manager.get().unwrap();
        assert_eq!(value, 100);
    }

    #[test]
    fn test_counter_mode_is_external() {
        assert!(
            !CounterMode::File {
                path: "test".to_string()
            }
            .is_external()
        );
        assert!(!CounterMode::InBin.is_external());
        assert!(CounterMode::Manual.is_external());
        assert!(CounterMode::External.is_external());
    }

    #[test]
    fn test_counter_mode_serde() {
        let modes = vec![
            CounterMode::File {
                path: "/tmp/test.counter".to_string(),
            },
            CounterMode::InBin,
            CounterMode::Manual,
            CounterMode::External,
        ];

        for mode in modes {
            let json = serde_json::to_string(&mode).unwrap();
            let recovered: CounterMode = serde_json::from_str(&json).unwrap();
            assert_eq!(mode, recovered);
        }
    }
}