use crate::types::{IdStrategy, IdType};
use anyhow::{Result, anyhow};
use rat_logger::{error, info, warn};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
/// Generates and validates record IDs according to a configured [`IdStrategy`].
///
/// Cloning an `IdGenerator` is cheap. Clones share the same auto-increment counter and
/// Snowflake generator through `Arc`, so IDs produced by any clone come from one sequence.
#[derive(Clone, Debug)]
pub struct IdGenerator {
    strategy: IdStrategy,
    /// Present only when `strategy` is [`IdStrategy::Snowflake`].
    snowflake_generator: Option<Arc<SnowflakeGenerator>>,
    /// Next value handed out by the [`IdStrategy::AutoIncrement`] strategy (starts at 1).
    auto_increment_counter: Arc<AtomicU64>,
}

impl IdGenerator {
    /// Creates a generator for `strategy`.
    ///
    /// # Errors
    ///
    /// Returns an error if `strategy` is [`IdStrategy::Snowflake`] and the Snowflake
    /// generator rejects its machine or datacenter ID.
    pub fn new(strategy: IdStrategy) -> Result<Self> {
        let snowflake_generator = match &strategy {
            IdStrategy::Snowflake {
                machine_id,
                datacenter_id,
            } => Some(Arc::new(SnowflakeGenerator::new(
                *machine_id,
                *datacenter_id,
            )?)),
            _ => None,
        };
        Ok(Self {
            strategy,
            snowflake_generator,
            auto_increment_counter: Arc::new(AtomicU64::new(1)),
        })
    }

    /// Produces a new ID for the configured strategy.
    ///
    /// - `AutoIncrement`: the next counter value, as [`IdType::Number`].
    /// - `Uuid`: a random v4 UUID string.
    /// - `Snowflake`: the Snowflake ID rendered as a decimal string.
    /// - `ObjectId`: 24 hex digits — 8 for the Unix time in seconds, then 16 random.
    /// - `Custom(name)`: `"{name}_{unix_time_in_nanoseconds}"`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the auto-increment counter exceeds `i64::MAX`;
    /// - the system clock reads earlier than the Unix epoch;
    /// - the Snowflake generator fails (for example, the clock moved backwards).
    pub async fn generate(&self) -> Result<IdType> {
        match &self.strategy {
            IdStrategy::AutoIncrement => {
                let id = self.auto_increment_counter.fetch_add(1, Ordering::SeqCst);
                // A plain `as` cast would wrap to a negative ID that `validate_id` rejects.
                let id = i64::try_from(id)
                    .map_err(|_| anyhow!("Auto-increment counter exceeded i64::MAX"))?;
                Ok(IdType::Number(id))
            }
            IdStrategy::Uuid => Ok(IdType::String(Uuid::new_v4().to_string())),
            IdStrategy::Snowflake { .. } => {
                let generator = self
                    .snowflake_generator
                    .as_ref()
                    .ok_or_else(|| anyhow!("Snowflake generator not initialized"))?;
                let id = generator.generate().await?;
                Ok(IdType::String(id.to_string()))
            }
            IdStrategy::ObjectId => {
                // Propagate a pre-epoch clock as an error instead of panicking.
                let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH)?;
                // Truncated to 32 bits so it fills exactly the 8 leading hex digits.
                let timestamp = since_epoch.as_secs() as u32;
                let random_bytes: [u8; 8] = rand::random();
                let object_id =
                    format!("{:08x}{:016x}", timestamp, u64::from_be_bytes(random_bytes));
                Ok(IdType::String(object_id))
            }
            IdStrategy::Custom(generator_name) => {
                let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos();
                Ok(IdType::String(format!("{generator_name}_{timestamp}")))
            }
        }
    }

    /// Returns `true` if `id` has the shape this generator's strategy produces.
    ///
    /// This is a format check only; it does not verify that the ID was actually issued.
    pub fn validate_id(&self, id: &IdType) -> bool {
        match (&self.strategy, id) {
            (IdStrategy::AutoIncrement, IdType::Number(n)) => *n > 0,
            (IdStrategy::Uuid, IdType::String(s)) => Uuid::parse_str(s).is_ok(),
            (IdStrategy::Snowflake { .. }, IdType::String(s)) => s.parse::<u64>().is_ok(),
            (IdStrategy::ObjectId, IdType::String(s)) => {
                s.len() == 24 && s.chars().all(|c| c.is_ascii_hexdigit())
            }
            (IdStrategy::Custom(_), IdType::String(_)) => true,
            // Every other strategy/ID combination, e.g. a `Number` ID for a string strategy.
            _ => false,
        }
    }

    /// Returns the strategy this generator was created with.
    pub fn strategy(&self) -> &IdStrategy {
        &self.strategy
    }

    /// Sets the next value returned by the `AutoIncrement` strategy.
    ///
    /// The counter is shared with every clone of this generator. Other strategies never
    /// read it, so their output is unaffected.
    pub fn set_auto_increment_start(&self, start: u64) {
        self.auto_increment_counter.store(start, Ordering::SeqCst);
    }
}
#[derive(Debug)]
struct SnowflakeGenerator {
machine_id: u16,
datacenter_id: u8,
sequence: Arc<AtomicU64>,
last_timestamp: Arc<AtomicU64>,
}
impl SnowflakeGenerator {
fn new(machine_id: u16, datacenter_id: u8) -> Result<Self> {
if machine_id > 1023 {
return Err(anyhow!("Machine ID must be between 0 and 1023"));
}
if datacenter_id > 31 {
return Err(anyhow!("Datacenter ID must be between 0 and 31"));
}
Ok(Self {
machine_id,
datacenter_id,
sequence: Arc::new(AtomicU64::new(0)),
last_timestamp: Arc::new(AtomicU64::new(0)),
})
}
async fn generate(&self) -> Result<u64> {
let mut timestamp = self.current_timestamp();
let last_timestamp = self.last_timestamp.load(Ordering::SeqCst);
if timestamp < last_timestamp {
return Err(anyhow!("Clock moved backwards"));
}
let sequence = if timestamp == last_timestamp {
let seq = self.sequence.fetch_add(1, Ordering::SeqCst) & 0xFFF;
if seq == 0 {
timestamp = self.wait_next_millis(last_timestamp);
}
seq
} else {
self.sequence.store(0, Ordering::SeqCst);
0
};
self.last_timestamp.store(timestamp, Ordering::SeqCst);
let id = ((timestamp - 1288834974657) << 22)
| ((self.datacenter_id as u64) << 17)
| ((self.machine_id as u64) << 12)
| sequence;
Ok(id)
}
fn current_timestamp(&self) -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64
}
fn wait_next_millis(&self, last_timestamp: u64) -> u64 {
let mut timestamp = self.current_timestamp();
while timestamp <= last_timestamp {
timestamp = self.current_timestamp();
}
timestamp
}
}
/// Monotonically increasing integer ID source associated with a collection name.
///
/// The counter is held in memory only, in an `AtomicU64` starting at 1. No database access
/// happens in this type.
#[derive(Debug)]
pub struct MongoAutoIncrementGenerator {
    /// Name of the collection this counter is associated with; no method here reads it.
    collection_name: String,
    /// Value that the next call to `next_id` will return.
    counter: Arc<AtomicU64>,
}

impl MongoAutoIncrementGenerator {
    /// Value the counter starts from, and returns to on `reset`.
    const INITIAL_VALUE: u64 = 1;

    /// Creates a counter for `collection_name`, starting at 1.
    pub fn new(collection_name: String) -> Self {
        let counter = Arc::new(AtomicU64::new(Self::INITIAL_VALUE));
        Self {
            collection_name,
            counter,
        }
    }

    /// Returns the current counter value and advances the counter by one.
    ///
    /// The value is cast to `i64` with `as`; it never fails.
    pub async fn next_id(&self) -> Result<i64> {
        let issued = self.counter.fetch_add(1, Ordering::SeqCst);
        Ok(issued as i64)
    }

    /// Makes `start` the value returned by the next `next_id` call.
    pub fn set_start_value(&self, start: u64) {
        self.counter.store(start, Ordering::SeqCst);
    }

    /// Returns the value the next `next_id` call would hand out, without advancing.
    pub fn current_value(&self) -> u64 {
        self.counter.load(Ordering::SeqCst)
    }

    /// Restores the counter to its initial value of 1.
    pub fn reset(&self) {
        self.set_start_value(Self::INITIAL_VALUE);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Returns the string payload of `id`, panicking on non-string IDs.
    fn expect_string(id: &IdType) -> &str {
        match id {
            IdType::String(s) => s.as_str(),
            _ => panic!("Expected string IDs"),
        }
    }

    #[tokio::test]
    async fn test_uuid_generator() {
        let generator = IdGenerator::new(IdStrategy::Uuid).unwrap();
        let id1 = generator.generate().await.unwrap();
        let id2 = generator.generate().await.unwrap();
        match (&id1, &id2) {
            (IdType::String(s1), IdType::String(s2)) => {
                assert_ne!(s1, s2);
                assert!(generator.validate_id(&id1));
                assert!(generator.validate_id(&id2));
            }
            _ => panic!("Expected string IDs"),
        }
    }

    #[tokio::test]
    async fn test_auto_increment_generator() {
        let generator = IdGenerator::new(IdStrategy::AutoIncrement).unwrap();
        let id1 = generator.generate().await.unwrap();
        let id2 = generator.generate().await.unwrap();
        match (&id1, &id2) {
            (IdType::Number(n1), IdType::Number(n2)) => {
                assert_eq!(*n1, 1);
                assert_eq!(*n2, 2);
                assert!(generator.validate_id(&id1));
                assert!(generator.validate_id(&id2));
            }
            _ => panic!("Expected number IDs"),
        }
    }

    #[tokio::test]
    async fn test_auto_increment_start_is_shared_by_clones() {
        let generator = IdGenerator::new(IdStrategy::AutoIncrement).unwrap();
        let clone = generator.clone();
        generator.set_auto_increment_start(42);
        let first = clone.generate().await.unwrap();
        let second = generator.generate().await.unwrap();
        match (first, second) {
            (IdType::Number(a), IdType::Number(b)) => {
                assert_eq!(a, 42);
                assert_eq!(b, 43);
            }
            _ => panic!("Expected number IDs"),
        }
    }

    #[tokio::test]
    async fn test_snowflake_generator() {
        let generator = IdGenerator::new(IdStrategy::Snowflake {
            machine_id: 1,
            datacenter_id: 1,
        })
        .unwrap();
        let id1 = generator.generate().await.unwrap();
        let id2 = generator.generate().await.unwrap();
        match (&id1, &id2) {
            (IdType::String(s1), IdType::String(s2)) => {
                assert_ne!(s1, s2);
                assert!(generator.validate_id(&id1));
                assert!(generator.validate_id(&id2));
            }
            _ => panic!("Expected string IDs"),
        }
    }

    #[tokio::test]
    async fn test_snowflake_ids_are_unique() {
        let generator = IdGenerator::new(IdStrategy::Snowflake {
            machine_id: 3,
            datacenter_id: 2,
        })
        .unwrap();
        let mut seen = HashSet::new();
        // Enough IDs to span several milliseconds and to exercise the same-millisecond
        // sequence path, not only the first ID of each millisecond.
        for _ in 0..10_000 {
            let id = generator.generate().await.unwrap();
            assert!(generator.validate_id(&id));
            let value: u64 = expect_string(&id).parse().unwrap();
            assert!(seen.insert(value), "duplicate snowflake id {value}");
        }
    }

    #[tokio::test]
    async fn test_object_id_generator() {
        let generator = IdGenerator::new(IdStrategy::ObjectId).unwrap();
        let id1 = generator.generate().await.unwrap();
        let id2 = generator.generate().await.unwrap();
        match (&id1, &id2) {
            (IdType::String(s1), IdType::String(s2)) => {
                assert_ne!(s1, s2);
                assert_eq!(s1.len(), 24);
                assert_eq!(s2.len(), 24);
                assert!(generator.validate_id(&id1));
                assert!(generator.validate_id(&id2));
            }
            _ => panic!("Expected string IDs"),
        }
    }

    #[test]
    fn test_validate_id_rejects_malformed_or_mismatched_ids() {
        let uuid_generator = IdGenerator::new(IdStrategy::Uuid).unwrap();
        assert!(!uuid_generator.validate_id(&IdType::Number(1)));
        assert!(!uuid_generator.validate_id(&IdType::String("not-a-uuid".to_string())));

        let auto_generator = IdGenerator::new(IdStrategy::AutoIncrement).unwrap();
        assert!(!auto_generator.validate_id(&IdType::Number(0)));
        assert!(!auto_generator.validate_id(&IdType::Number(-5)));
        assert!(!auto_generator.validate_id(&IdType::String("1".to_string())));

        let object_id_generator = IdGenerator::new(IdStrategy::ObjectId).unwrap();
        assert!(!object_id_generator.validate_id(&IdType::String("z".repeat(24))));
        assert!(!object_id_generator.validate_id(&IdType::String("a".repeat(23))));
        assert!(!object_id_generator.validate_id(&IdType::Number(1)));
    }

    #[tokio::test]
    async fn test_mongo_auto_increment_generator() {
        let generator = MongoAutoIncrementGenerator::new("test_collection".to_string());
        let id1 = generator.next_id().await.unwrap();
        let id2 = generator.next_id().await.unwrap();
        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
        generator.set_start_value(100);
        let id3 = generator.next_id().await.unwrap();
        assert_eq!(id3, 100);
        assert_eq!(generator.current_value(), 101);
        generator.reset();
        assert_eq!(generator.current_value(), 1);
        assert_eq!(generator.next_id().await.unwrap(), 1);
    }
}