1use core::fmt;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::time::{Duration, SystemTime, UNIX_EPOCH};
39
/// Default epoch: 2026-01-01T00:00:00Z, expressed in milliseconds since the Unix epoch.
///
/// Timestamps embedded in generated IDs are millisecond offsets from this instant.
pub const DEFAULT_EPOCH_MS: u64 = 1_767_225_600_000;

/// Number of low bits holding the per-millisecond sequence counter.
pub const SEQUENCE_BITS: u32 = 12;
/// Number of bits holding the worker id, directly above the sequence field.
pub const WORKER_BITS: u32 = 10;
/// Number of bits holding the timestamp offset, directly above the worker field.
pub const TIMESTAMP_BITS: u32 = 41;

// The three fields must fit below bit 63: this keeps the top bit of every ID clear
// (so an ID is also a non-negative `i64`) and keeps every shift below in range.
// Checked at compile time so that editing a `*_BITS` constant cannot silently break it.
const _: () = assert!(
    SEQUENCE_BITS + WORKER_BITS + TIMESTAMP_BITS <= 63,
    "snowflake ID fields must fit in 63 bits"
);

const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1;
const WORKER_MASK: u64 = (1 << WORKER_BITS) - 1;
const TIMESTAMP_MASK: u64 = (1 << TIMESTAMP_BITS) - 1;

const WORKER_SHIFT: u32 = SEQUENCE_BITS;
const TIMESTAMP_SHIFT: u32 = SEQUENCE_BITS + WORKER_BITS;
56
57const STATE_SEQ_BITS: u32 = 13;
61const STATE_SEQ_MASK: u64 = (1 << STATE_SEQ_BITS) - 1;
62const STATE_SEQ_EXHAUSTED: u64 = SEQUENCE_MASK + 1; #[derive(Debug)]
79pub struct Snowflake {
80 worker_id: u16,
81 epoch_ms: u64,
82 state: AtomicU64,
83}
84
85impl Snowflake {
86 pub const fn new(worker_id: u16) -> Self {
100 Self::with_epoch(worker_id, DEFAULT_EPOCH_MS)
101 }
102
103 pub const fn with_epoch(worker_id: u16, epoch_ms: u64) -> Self {
116 Self {
117 worker_id: (worker_id as u64 & WORKER_MASK) as u16,
118 epoch_ms,
119 state: AtomicU64::new(0),
120 }
121 }
122
123 pub const fn worker_id(&self) -> u16 {
125 self.worker_id
126 }
127
128 pub const fn epoch_ms(&self) -> u64 {
130 self.epoch_ms
131 }
132
133 pub fn try_next_id(&self) -> Result<u64, ClockSkew> {
150 loop {
151 let cur = self.state.load(Ordering::Acquire);
152 let last_ms = cur >> STATE_SEQ_BITS;
153 let next_seq = cur & STATE_SEQ_MASK;
154
155 let now = current_offset_ms(self.epoch_ms);
156 if now < last_ms {
157 return Err(ClockSkew {
158 last_ms,
159 now_ms: now,
160 });
161 }
162
163 let (use_ms, assigned, new_next_seq) = if now == last_ms {
164 if next_seq >= STATE_SEQ_EXHAUSTED {
165 sleep_until_after(self.epoch_ms, last_ms);
166 continue;
167 }
168 (last_ms, next_seq, next_seq + 1)
169 } else {
170 (now, 0u64, 1u64)
171 };
172
173 let new_state = (use_ms << STATE_SEQ_BITS) | new_next_seq;
174 if self
175 .state
176 .compare_exchange(cur, new_state, Ordering::AcqRel, Ordering::Acquire)
177 .is_ok()
178 {
179 let id = (use_ms << TIMESTAMP_SHIFT)
180 | ((self.worker_id as u64) << WORKER_SHIFT)
181 | assigned;
182 return Ok(id);
183 }
184 }
185 }
186
187 pub fn next_id(&self) -> u64 {
209 match self.try_next_id() {
210 Ok(id) => id,
211 Err(e) => panic!("snowflake: clock moved backward ({e})"),
212 }
213 }
214
215 pub const fn parts(id: u64) -> (u64, u16, u16) {
236 let timestamp_offset = (id >> TIMESTAMP_SHIFT) & TIMESTAMP_MASK;
237 let worker = ((id >> WORKER_SHIFT) & WORKER_MASK) as u16;
238 let sequence = (id & SEQUENCE_MASK) as u16;
239 (timestamp_offset, worker, sequence)
240 }
241}
242
/// Error produced when the wall clock reads an earlier millisecond than the one stamped
/// on the most recently issued ID.
///
/// Both fields are millisecond offsets from the generator's epoch.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ClockSkew {
    /// Timestamp offset (ms) recorded for the last issued ID.
    pub last_ms: u64,
    /// Timestamp offset (ms) read from the clock when the skew was detected.
    pub now_ms: u64,
}

impl fmt::Display for ClockSkew {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { last_ms, now_ms } = *self;
        out.write_fmt(format_args!(
            "clock moved backward: last issued at offset {} ms, now at offset {} ms",
            last_ms, now_ms
        ))
    }
}

/// No underlying cause is carried, so the default `source()` (returning `None`) is kept.
impl std::error::Error for ClockSkew {}
266
267fn current_offset_ms(epoch_ms: u64) -> u64 {
268 let now = SystemTime::now()
269 .duration_since(UNIX_EPOCH)
270 .map(|d| d.as_millis() as u64)
271 .unwrap_or(0);
272 now.saturating_sub(epoch_ms) & TIMESTAMP_MASK
273}
274
275fn sleep_until_after(epoch_ms: u64, last_ms: u64) {
276 loop {
277 if current_offset_ms(epoch_ms) > last_ms {
278 return;
279 }
280 std::thread::sleep(Duration::from_micros(100));
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::thread;
    use std::time::Instant;

    // Test generators are named `generator`, not `gen`: `gen` is a reserved keyword in
    // the Rust 2024 edition.

    #[test]
    fn next_id_produces_value() {
        let generator = Snowflake::new(1);
        assert!(generator.next_id() > 0);
    }

    #[test]
    fn worker_id_clamped() {
        let generator = Snowflake::new(0xffff);
        assert_eq!(generator.worker_id(), 0x3ff);
        let id = generator.next_id();
        assert_eq!(Snowflake::parts(id).1, 0x3ff);
    }

    #[test]
    fn worker_field_extracts() {
        let generator = Snowflake::new(42);
        let id = generator.next_id();
        let (_, worker, _) = Snowflake::parts(id);
        assert_eq!(worker, 42);
    }

    #[test]
    fn monotonic_in_burst() {
        let generator = Snowflake::new(1);
        let mut prev = generator.next_id();
        for _ in 0..10_000 {
            let cur = generator.next_id();
            assert!(cur > prev, "expected {cur} > {prev}");
            prev = cur;
        }
    }

    #[test]
    fn all_unique_in_burst() {
        let generator = Snowflake::new(1);
        let mut set = HashSet::new();
        for _ in 0..50_000 {
            let id = generator.next_id();
            assert!(set.insert(id));
        }
    }

    #[test]
    fn parts_round_trip() {
        let generator = Snowflake::with_epoch(7, DEFAULT_EPOCH_MS);
        let id = generator.next_id();
        let (ts, worker, seq) = Snowflake::parts(id);
        assert_eq!(worker, 7);
        let reassembled =
            (ts << TIMESTAMP_SHIFT) | ((worker as u64) << WORKER_SHIFT) | seq as u64;
        assert_eq!(reassembled, id);
    }

    #[test]
    fn sequence_resets_each_ms() {
        let generator = Snowflake::new(1);
        let _ = generator.next_id();
        thread::sleep(Duration::from_millis(3));
        let id_after_sleep = generator.next_id();
        let (_, _, seq) = Snowflake::parts(id_after_sleep);
        assert_eq!(seq, 0, "first ID of a fresh ms must have sequence 0");
    }

    #[test]
    fn sequence_exhaustion_blocks_until_next_ms() {
        let generator = Snowflake::new(1);
        // Pretend every sequence value of the current millisecond was already issued.
        let now = current_offset_ms(generator.epoch_ms);
        let exhausted_state = (now << STATE_SEQ_BITS) | STATE_SEQ_EXHAUSTED;
        generator.state.store(exhausted_state, Ordering::Release);

        // `Instant` is monotonic, so this measurement cannot fail (unlike
        // `SystemTime::duration_since`) if the wall clock is adjusted mid-test.
        let start = Instant::now();
        let id = generator.next_id();
        let elapsed = start.elapsed();

        let (ts, _, seq) = Snowflake::parts(id);
        assert!(ts > now, "new ID must be in a later millisecond");
        assert_eq!(seq, 0);
        assert!(
            elapsed < Duration::from_millis(50),
            "block should be roughly one ms, got {elapsed:?}"
        );
    }

    #[test]
    fn clock_skew_reported_via_result() {
        let generator = Snowflake::new(1);
        // Make the generator believe it already issued an ID 5 s in the future.
        let future_ms = current_offset_ms(generator.epoch_ms) + 5_000;
        generator
            .state
            .store(future_ms << STATE_SEQ_BITS, Ordering::Release);

        match generator.try_next_id() {
            Err(ClockSkew { last_ms, now_ms }) => {
                assert_eq!(last_ms, future_ms);
                assert!(now_ms < last_ms);
            }
            Ok(id) => panic!("expected ClockSkew, got id {id}"),
        }
    }

    #[test]
    #[should_panic(expected = "clock moved backward")]
    fn next_id_panics_on_clock_skew() {
        let generator = Snowflake::new(1);
        let future_ms = current_offset_ms(generator.epoch_ms) + 5_000;
        generator
            .state
            .store(future_ms << STATE_SEQ_BITS, Ordering::Release);
        let _ = generator.next_id();
    }

    #[test]
    fn multi_thread_all_unique() {
        let generator = Arc::new(Snowflake::new(3));
        let mut handles = Vec::new();
        for _ in 0..8 {
            let shared = Arc::clone(&generator);
            handles.push(thread::spawn(move || {
                let mut local = Vec::with_capacity(2000);
                for _ in 0..2000 {
                    local.push(shared.next_id());
                }
                local
            }));
        }
        let mut all = HashSet::new();
        for h in handles {
            for id in h.join().unwrap() {
                assert!(all.insert(id), "duplicate id under thread contention");
            }
        }
        assert_eq!(all.len(), 8 * 2000);
    }

    #[test]
    fn custom_epoch_round_trip() {
        let epoch = 1_700_000_000_000_u64;
        let generator = Snowflake::with_epoch(9, epoch);
        let id = generator.next_id();
        let (ts_offset, worker, _) = Snowflake::parts(id);
        assert_eq!(worker, 9);
        assert_eq!(generator.epoch_ms(), epoch);
        let wall = ts_offset + epoch;
        assert!(wall > epoch);
    }

    #[test]
    fn parts_extracts_each_field() {
        let ts: u64 = 12_345;
        let worker: u64 = 700;
        let seq: u64 = 4000;
        let id = (ts << TIMESTAMP_SHIFT) | (worker << WORKER_SHIFT) | seq;
        let (got_ts, got_w, got_s) = Snowflake::parts(id);
        assert_eq!(got_ts, ts);
        assert_eq!(got_w as u64, worker);
        assert_eq!(got_s as u64, seq);
    }
}