1#![warn(missing_docs)]
9
10use serde::{Deserialize, Serialize};
11use std::{collections::HashSet, sync::Arc, time::Duration};
12use wae_types::{WaeError, WaeResult};
13
14pub use feature_flag::{FeatureFlag, FeatureFlagManager, FlagDefinition, Strategy, evaluate};
15pub use id_generator::{IdGenerator, SnowflakeGenerator, UuidGenerator, UuidVersion};
16pub use lock::{DistributedLock, InMemoryLock, InMemoryLockManager, LockOptions};
17
18pub mod feature_flag {
20 use super::*;
21 use std::collections::HashMap;
22
23 #[derive(Debug, Clone, Default, Serialize, Deserialize)]
25 pub enum Strategy {
26 On,
28 #[default]
30 Off,
31 Percentage(u32),
33 UserList(Vec<String>),
35 }
36
37 #[derive(Debug, Clone, Serialize, Deserialize)]
39 pub struct FlagDefinition {
40 pub name: String,
42 pub description: String,
44 pub strategy: Strategy,
46 pub enabled: bool,
48 }
49
50 impl FlagDefinition {
51 pub fn new(name: impl Into<String>) -> Self {
53 Self { name: name.into(), description: String::new(), strategy: Strategy::default(), enabled: false }
54 }
55
56 pub fn with_description(mut self, description: impl Into<String>) -> Self {
58 self.description = description.into();
59 self
60 }
61
62 pub fn with_strategy(mut self, strategy: Strategy) -> Self {
64 self.strategy = strategy;
65 self
66 }
67
68 pub fn with_enabled(mut self, enabled: bool) -> Self {
70 self.enabled = enabled;
71 self
72 }
73 }
74
75 #[allow(async_fn_in_trait)]
77 pub trait FeatureFlag: Send + Sync {
78 async fn is_enabled(&self, key: &str) -> bool;
80
81 async fn is_enabled_for_user(&self, key: &str, user_id: &str) -> bool;
83
84 async fn get_variant(&self, key: &str) -> Option<String>;
86 }
87
88 pub struct FeatureFlagManager {
90 flags: parking_lot::RwLock<HashMap<String, FlagDefinition>>,
91 }
92
93 impl FeatureFlagManager {
94 pub fn new() -> Self {
96 Self { flags: parking_lot::RwLock::new(HashMap::new()) }
97 }
98
99 pub fn register(&self, flag: FlagDefinition) {
101 let mut flags = self.flags.write();
102 flags.insert(flag.name.clone(), flag);
103 }
104
105 pub fn unregister(&self, name: &str) -> bool {
107 let mut flags = self.flags.write();
108 flags.remove(name).is_some()
109 }
110
111 pub fn get(&self, name: &str) -> Option<FlagDefinition> {
113 let flags = self.flags.read();
114 flags.get(name).cloned()
115 }
116
117 pub fn update(&self, name: &str, enabled: bool) -> bool {
119 let mut flags = self.flags.write();
120 if let Some(flag) = flags.get_mut(name) {
121 flag.enabled = enabled;
122 return true;
123 }
124 false
125 }
126
127 pub fn list(&self) -> Vec<FlagDefinition> {
129 let flags = self.flags.read();
130 flags.values().cloned().collect()
131 }
132 }
133
134 impl Default for FeatureFlagManager {
135 fn default() -> Self {
136 Self::new()
137 }
138 }
139
140 impl FeatureFlag for FeatureFlagManager {
141 async fn is_enabled(&self, key: &str) -> bool {
142 let flags = self.flags.read();
143 if let Some(flag) = flags.get(key) {
144 return flag.enabled && matches!(flag.strategy, Strategy::On);
145 }
146 false
147 }
148
149 async fn is_enabled_for_user(&self, key: &str, user_id: &str) -> bool {
150 let flags = self.flags.read();
151 if let Some(flag) = flags.get(key) {
152 return evaluate(&flag.strategy, user_id);
153 }
154 false
155 }
156
157 async fn get_variant(&self, _key: &str) -> Option<String> {
158 None
159 }
160 }
161
162 pub fn evaluate(strategy: &Strategy, user_id: &str) -> bool {
164 match strategy {
165 Strategy::On => true,
166 Strategy::Off => false,
167 Strategy::Percentage(pct) => {
168 let hash = calculate_hash(user_id);
169 let bucket = hash % 100;
170 bucket < *pct as u64
171 }
172 Strategy::UserList(users) => users.contains(&user_id.to_string()),
173 }
174 }
175
176 fn calculate_hash(s: &str) -> u64 {
177 let mut hash: u64 = 0;
178 for c in s.chars() {
179 hash = hash.wrapping_mul(31).wrapping_add(c as u64);
180 }
181 hash
182 }
183}
184
185pub mod id_generator {
187 use parking_lot::Mutex;
188 use std::time::{SystemTime, UNIX_EPOCH};
189
190 #[allow(async_fn_in_trait)]
192 pub trait IdGenerator: Send + Sync {
193 async fn generate(&self) -> String;
195
196 async fn generate_batch(&self, count: usize) -> Vec<String>;
198 }
199
200 pub struct SnowflakeGenerator {
202 worker_id: u64,
203 datacenter_id: u64,
204 sequence: Mutex<u64>,
205 last_timestamp: Mutex<u64>,
206 }
207
208 impl SnowflakeGenerator {
209 const EPOCH: u64 = 1704067200000;
210 const WORKER_ID_BITS: u64 = 5;
211 const DATACENTER_ID_BITS: u64 = 5;
212 const SEQUENCE_BITS: u64 = 12;
213 const MAX_WORKER_ID: u64 = (1 << Self::WORKER_ID_BITS) - 1;
214 const MAX_DATACENTER_ID: u64 = (1 << Self::DATACENTER_ID_BITS) - 1;
215 const SEQUENCE_MASK: u64 = (1 << Self::SEQUENCE_BITS) - 1;
216 const WORKER_ID_SHIFT: u64 = Self::SEQUENCE_BITS;
217 const DATACENTER_ID_SHIFT: u64 = Self::SEQUENCE_BITS + Self::WORKER_ID_BITS;
218 const TIMESTAMP_SHIFT: u64 = Self::SEQUENCE_BITS + Self::WORKER_ID_BITS + Self::DATACENTER_ID_BITS;
219
220 pub fn new(worker_id: u64, datacenter_id: u64) -> Result<Self, String> {
222 if worker_id > Self::MAX_WORKER_ID {
223 return Err(format!("Worker ID must be between 0 and {}", Self::MAX_WORKER_ID));
224 }
225 if datacenter_id > Self::MAX_DATACENTER_ID {
226 return Err(format!("Datacenter ID must be between 0 and {}", Self::MAX_DATACENTER_ID));
227 }
228 Ok(Self { worker_id, datacenter_id, sequence: Mutex::new(0), last_timestamp: Mutex::new(0) })
229 }
230
231 fn current_timestamp() -> u64 {
232 SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64
233 }
234
235 fn til_next_millis(last_timestamp: u64) -> u64 {
236 let mut timestamp = Self::current_timestamp();
237 while timestamp <= last_timestamp {
238 timestamp = Self::current_timestamp();
239 }
240 timestamp
241 }
242
243 fn generate_id(&self) -> u64 {
244 let mut sequence = self.sequence.lock();
245 let mut last_timestamp = self.last_timestamp.lock();
246
247 let timestamp = Self::current_timestamp();
248
249 if timestamp < *last_timestamp {
250 panic!("Clock moved backwards!");
251 }
252
253 if timestamp == *last_timestamp {
254 *sequence = (*sequence + 1) & Self::SEQUENCE_MASK;
255 if *sequence == 0 {
256 *last_timestamp = Self::til_next_millis(*last_timestamp);
257 }
258 }
259 else {
260 *sequence = 0;
261 }
262
263 *last_timestamp = timestamp;
264
265 ((timestamp - Self::EPOCH) << Self::TIMESTAMP_SHIFT)
266 | (self.datacenter_id << Self::DATACENTER_ID_SHIFT)
267 | (self.worker_id << Self::WORKER_ID_SHIFT)
268 | *sequence
269 }
270 }
271
272 impl IdGenerator for SnowflakeGenerator {
273 async fn generate(&self) -> String {
274 self.generate_id().to_string()
275 }
276
277 async fn generate_batch(&self, count: usize) -> Vec<String> {
278 (0..count).map(|_| self.generate_id().to_string()).collect()
279 }
280 }
281
282 #[derive(Debug, Clone, Copy, Default)]
284 pub enum UuidVersion {
285 #[default]
287 V4,
288 V7,
290 }
291
292 pub struct UuidGenerator {
294 version: UuidVersion,
295 }
296
297 impl UuidGenerator {
298 pub fn new(version: UuidVersion) -> Self {
300 Self { version }
301 }
302
303 pub fn v4() -> Self {
305 Self::new(UuidVersion::V4)
306 }
307
308 pub fn v7() -> Self {
310 Self::new(UuidVersion::V7)
311 }
312 }
313
314 impl Default for UuidGenerator {
315 fn default() -> Self {
316 Self::v4()
317 }
318 }
319
320 impl IdGenerator for UuidGenerator {
321 async fn generate(&self) -> String {
322 match self.version {
323 UuidVersion::V4 => uuid::Uuid::new_v4().to_string(),
324 UuidVersion::V7 => {
325 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64;
326 let random_bytes: [u8; 10] = {
327 let mut bytes = [0u8; 10];
328 for byte in &mut bytes {
329 *byte = rand_byte();
330 }
331 bytes
332 };
333
334 let mut uuid_bytes = [0u8; 16];
335 uuid_bytes[0..6].copy_from_slice(&now.to_be_bytes()[2..8]);
336 uuid_bytes[6..16].copy_from_slice(&random_bytes);
337
338 uuid_bytes[6] = (uuid_bytes[6] & 0x0F) | 0x70;
339 uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80;
340
341 uuid::Uuid::from_bytes(uuid_bytes).to_string()
342 }
343 }
344 }
345
346 async fn generate_batch(&self, count: usize) -> Vec<String> {
347 let mut result = Vec::with_capacity(count);
348 for _ in 0..count {
349 result.push(self.generate().await);
350 }
351 result
352 }
353 }
354
355 fn rand_byte() -> u8 {
356 use std::{
357 collections::hash_map::RandomState,
358 hash::{BuildHasher, Hasher},
359 };
360 let state = RandomState::new();
361 let mut hasher = state.build_hasher();
362 hasher.write_u64(SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64);
363 (hasher.finish() & 0xFF) as u8
364 }
365}
366
367pub mod lock {
369 use super::*;
370
371 #[derive(Debug, Clone)]
373 pub struct LockOptions {
374 pub ttl: Duration,
376 pub wait_timeout: Duration,
378 }
379
380 impl Default for LockOptions {
381 fn default() -> Self {
382 Self { ttl: Duration::from_secs(30), wait_timeout: Duration::from_secs(10) }
383 }
384 }
385
386 impl LockOptions {
387 pub fn new() -> Self {
389 Self::default()
390 }
391
392 pub fn with_ttl(mut self, ttl: Duration) -> Self {
394 self.ttl = ttl;
395 self
396 }
397
398 pub fn with_wait_timeout(mut self, timeout: Duration) -> Self {
400 self.wait_timeout = timeout;
401 self
402 }
403 }
404
405 #[allow(async_fn_in_trait)]
407 pub trait DistributedLock: Send + Sync {
408 async fn lock(&self) -> WaeResult<()>;
410
411 async fn try_lock(&self) -> WaeResult<bool>;
413
414 async fn unlock(&self) -> WaeResult<()>;
416
417 async fn lock_with_timeout(&self, timeout: Duration) -> WaeResult<()>;
419
420 fn key(&self) -> &str;
422
423 async fn is_locked(&self) -> bool;
425 }
426
427 pub struct InMemoryLockManager {
429 locks: parking_lot::RwLock<HashSet<String>>,
430 }
431
432 impl InMemoryLockManager {
433 pub fn new() -> Self {
435 Self { locks: parking_lot::RwLock::new(HashSet::new()) }
436 }
437
438 pub fn create_lock(&self, key: impl Into<String>) -> InMemoryLock {
440 InMemoryLock::new(key, Arc::new(self.clone()))
441 }
442
443 async fn acquire_lock(&self, key: &str, _ttl: Duration) -> WaeResult<bool> {
444 let mut locks = self.locks.write();
445 if locks.contains(key) {
446 return Ok(false);
447 }
448 locks.insert(key.to_string());
449 Ok(true)
450 }
451
452 async fn release_lock(&self, key: &str) -> WaeResult<()> {
453 let mut locks = self.locks.write();
454 if locks.remove(key) {
455 return Ok(());
456 }
457 Err(WaeError::not_found("Lock", key))
458 }
459
460 async fn is_locked(&self, key: &str) -> bool {
461 self.locks.read().contains(key)
462 }
463 }
464
465 impl Default for InMemoryLockManager {
466 fn default() -> Self {
467 Self::new()
468 }
469 }
470
471 impl Clone for InMemoryLockManager {
472 fn clone(&self) -> Self {
473 Self { locks: parking_lot::RwLock::new(self.locks.read().clone()) }
474 }
475 }
476
477 pub struct InMemoryLock {
479 key: String,
480 manager: Arc<InMemoryLockManager>,
481 }
482
483 impl InMemoryLock {
484 pub fn new(key: impl Into<String>, manager: Arc<InMemoryLockManager>) -> Self {
486 Self { key: key.into(), manager }
487 }
488 }
489
490 impl DistributedLock for InMemoryLock {
491 async fn lock(&self) -> WaeResult<()> {
492 self.lock_with_timeout(Duration::from_secs(30)).await
493 }
494
495 async fn try_lock(&self) -> WaeResult<bool> {
496 self.manager.acquire_lock(&self.key, Duration::from_secs(30)).await
497 }
498
499 async fn unlock(&self) -> WaeResult<()> {
500 self.manager.release_lock(&self.key).await
501 }
502
503 async fn lock_with_timeout(&self, timeout: Duration) -> WaeResult<()> {
504 let start = std::time::Instant::now();
505 loop {
506 if self.manager.acquire_lock(&self.key, Duration::from_secs(30)).await? {
507 return Ok(());
508 }
509 if start.elapsed() >= timeout {
510 return Err(WaeError::operation_timeout(format!("Lock key: {}", self.key), timeout.as_millis() as u64));
511 }
512 tokio::time::sleep(Duration::from_millis(50)).await;
513 }
514 }
515
516 fn key(&self) -> &str {
517 &self.key
518 }
519
520 async fn is_locked(&self) -> bool {
521 self.manager.is_locked(&self.key).await
522 }
523 }
524}