1#[macro_use]
156extern crate serde;
157
158use chrono::{DateTime, TimeZone, Utc};
159use pnet::datalink::interfaces;
160use std::fmt::{Debug, Formatter};
161use std::net::{IpAddr, Ipv4Addr};
162use std::sync::Arc;
163use std::time::Duration;
164use parking_lot::Mutex;
165
/// Number of bits of an ID that hold the elapsed time (counted in `FLAKE_TIME_UNIT` ticks).
const BIT_LEN_TIME: i64 = 39;

/// Number of bits of an ID that hold the per-tick sequence number.
const BIT_LEN_SEQUENCE: i64 = 8;

/// Number of bits of an ID that hold the machine id: whatever remains of the
/// 63 usable bits once time and sequence are accounted for.
const BIT_LEN_MACHINE_ID: i64 = 63 - (BIT_LEN_TIME + BIT_LEN_SEQUENCE);

/// Length of one time tick in nanoseconds (10 ms).
const FLAKE_TIME_UNIT: i64 = 10 * 1_000_000;
177
178#[derive(Debug)]
182pub enum Error {
183 StartTimeAheadOfCurrentTime(DateTime<Utc>),
185
186 MachineIdFailed(Box<dyn std::error::Error + 'static + Send + Sync>),
188
189 InvalidMachineID(u16),
191
192 TimeOverflow,
194
195 NoPrivateIPv4Address,
197}
198
199unsafe impl Send for Error {}
200unsafe impl Sync for Error {}
201
202impl std::fmt::Display for Error {
203 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
204 match self {
205 Error::StartTimeAheadOfCurrentTime(time) => {
206 write!(f, "start_time {} is ahead of current time", time)
207 }
208 Error::MachineIdFailed(e) => write!(f, "cannot get a machine id: {}", e),
209 Error::InvalidMachineID(id) => write!(f, "invalid machine id: {}", id),
210 Error::TimeOverflow => write!(f, "over the sonyflake time limit"),
211 Error::NoPrivateIPv4Address => write!(f, "no private IPv4 address"),
212 }
213 }
214}
215
216impl std::error::Error for Error {}
217
/// A user-supplied source of the 16-bit machine id embedded in every ID.
pub trait MachineID {
    /// Returns the machine id, or an error if it cannot be determined.
    ///
    /// Takes `&mut self` so an implementation may keep and update internal state.
    fn machine_id(&mut self) -> Result<u16, Box<dyn std::error::Error + Send + Sync + 'static>>;
}
226
/// A user-supplied validator for machine ids.
pub trait MachineIDChecker {
    /// Returns `true` when `id` is acceptable and `false` when it must be rejected.
    fn check_machine_id(&self, id: u16) -> bool;
}
234
235pub struct Settings {
239 start_time: Option<DateTime<Utc>>,
240 machine_id: Option<Box<dyn MachineID>>,
241 check_machine_id: Option<Box<dyn MachineIDChecker>>,
242}
243
244impl Default for Settings {
245 fn default() -> Self {
246 Settings::new()
247 }
248}
249
250impl Settings {
251 pub fn new() -> Self {
256 Self {
257 start_time: None,
258 machine_id: None,
259 check_machine_id: None,
260 }
261 }
262
263 fn get_start_time(&self) -> Result<i64, Error> {
264 return if let Some(start_time) = self.start_time {
265 if start_time > Utc::now() {
266 return Err(Error::StartTimeAheadOfCurrentTime(start_time));
267 }
268 Ok(to_sonyflake_time(start_time))
269 } else {
270 Ok(to_sonyflake_time(default_start_time()))
271 }
272 }
273
274 fn get_and_check_machine_id(self) -> Result<u16, Error> {
275 return if let Some(mut machine_id) = self.machine_id {
276 match machine_id.machine_id() {
277 Ok(machine_id) => {
278 if let Some(checker) = self.check_machine_id {
279 if !checker.check_machine_id(machine_id) {
280 return Err(Error::InvalidMachineID(machine_id));
281 }
282 }
283 Ok(machine_id)
284 },
285 Err(e) => Err(Error::MachineIdFailed(e)),
286 }
287 } else {
288 match lower_16_bit_private_ip() {
289 Ok(machine_id) => {
290 if let Some(checker) = self.check_machine_id {
291 if !checker.check_machine_id(machine_id) {
292 return Err(Error::InvalidMachineID(machine_id));
293 }
294 }
295 Ok(machine_id)
296 },
297 Err(e) => Err(e),
298 }
299 };
300 }
301
302 pub fn set_start_time(mut self, start_time: DateTime<Utc>) -> Self {
305 self.start_time = Some(start_time);
306 self
307 }
308
309 pub fn set_machine_id(mut self, machine_id: Box<dyn MachineID>) -> Self {
312 self.machine_id = Some(machine_id);
313 self
314 }
315
316 pub fn set_check_machine_id(mut self, check_machine_id: Box<dyn MachineIDChecker>) -> Self {
319 self.check_machine_id = Some(check_machine_id);
320 self
321 }
322
323 pub fn into_sonyflake(self) -> Result<SonyFlake, Error> {
324 SonyFlake::new(self)
325 }
326
327 pub fn into_infallible_sonyflake(self) -> Result<InfallibleSonyFlake, Error> {
328 InfallibleSonyFlake::new(self)
329 }
330}
331
332#[derive(Debug)]
334pub struct SonyFlake {
335 start_time: i64,
336 machine_id: u16,
337 inner: Arc<Mutex<Inner>>,
338}
339
340impl SonyFlake {
341 pub fn new(st: Settings) -> Result<Self, Error> {
346 let sequence = 1 << (BIT_LEN_SEQUENCE - 1);
347
348 let start_time = st.get_start_time()?;
349
350 let machine_id = st.get_and_check_machine_id()?;
351
352 Ok(SonyFlake {
353 start_time,
354 machine_id,
355 inner: Arc::new(Mutex::new(Inner {
356 sequence,
357 elapsed_time: 0,
358 })),
359 })
360 }
361
362 pub fn next_id(&mut self) -> Result<u64, Error> {
365 let mask_sequence = (1 << BIT_LEN_SEQUENCE) - 1;
366
367 let mut inner = self.inner.lock();
368
369 let current = current_elapsed_time(self.start_time);
370
371 if inner.elapsed_time < current {
372 inner.elapsed_time = current;
373 inner.sequence = 0;
374 } else {
375 inner.sequence = (inner.sequence + 1) & mask_sequence;
377 if inner.sequence == 0 {
378 inner.elapsed_time += 1;
379 let overtime = inner.elapsed_time - current;
380 std::thread::sleep(sleep_time(overtime));
381 }
382 }
383
384 if inner.elapsed_time >= 1 << BIT_LEN_TIME {
385 return Err(Error::TimeOverflow);
386 }
387
388 Ok(to_id(inner.elapsed_time, inner.sequence, self.machine_id))
389 }
390}
391
392impl Clone for SonyFlake {
394 fn clone(&self) -> Self {
395 Self {
396 start_time: self.start_time,
397 machine_id: self.machine_id,
398 inner: self.inner.clone(),
399 }
400 }
401}
402
403#[derive(Debug)]
406pub struct InfallibleSonyFlake {
407 start_time: i64,
408 machine_id: u16,
409 inner: Arc<Mutex<Inner>>,
410}
411
412impl InfallibleSonyFlake {
413 pub fn new(st: Settings) -> Result<Self, Error> {
418 let sequence = 1 << (BIT_LEN_SEQUENCE - 1);
419
420 let start_time = st.get_start_time()?;
421
422 let machine_id = st.get_and_check_machine_id()?;
423
424 Ok(Self {
425 start_time,
426 machine_id,
427 inner: Arc::new(Mutex::new(Inner {
428 sequence,
429 elapsed_time: 0,
430 })),
431 })
432 }
433
434 pub fn next_id(&mut self) -> u64 {
437 let mask_sequence = (1 << BIT_LEN_SEQUENCE) - 1;
438
439 let mut inner = self.inner.lock();
440
441 let current = current_elapsed_time(self.start_time);
442
443 if inner.elapsed_time < current {
444 inner.elapsed_time = current;
445 inner.sequence = 0;
446 } else {
447 inner.sequence = (inner.sequence + 1) & mask_sequence;
449 if inner.sequence == 0 {
450 inner.elapsed_time += 1;
451 let overtime = inner.elapsed_time - current;
452 std::thread::sleep(sleep_time(overtime));
453 }
454 }
455
456 if inner.elapsed_time >= 1 << BIT_LEN_TIME {
457 let now = Utc::now();
458 self.start_time = to_sonyflake_time(now, );
460 inner.elapsed_time = 0;
461 inner.sequence = 0;
462 return to_id(inner.elapsed_time, inner.sequence, self.machine_id);
463 }
464
465 to_id(inner.elapsed_time, inner.sequence, self.machine_id)
466 }
467}
468
469impl Clone for InfallibleSonyFlake {
471 fn clone(&self) -> Self {
472 Self {
473 start_time: self.start_time,
474 machine_id: self.machine_id,
475 inner: self.inner.clone(),
476 }
477 }
478}
479
480fn private_ipv4() -> Option<Ipv4Addr> {
481 interfaces()
482 .iter()
483 .filter(|interface| interface.is_up() && !interface.is_loopback())
484 .map(|interface| {
485 interface
486 .ips
487 .iter()
488 .map(|ip_addr| ip_addr.ip()) .find(|ip_addr| match ip_addr {
490 IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
491 IpAddr::V6(_) => false,
492 })
493 .and_then(|ip_addr| match ip_addr {
494 IpAddr::V4(ipv4) => Some(ipv4), _ => None,
496 })
497 })
498 .find(|ip| ip.is_some())
499 .flatten()
500}
501
/// Reports whether `ip` lies in one of the private ranges
/// `10.0.0.0/8`, `172.16.0.0/12` or `192.168.0.0/16`.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    // The standard library check covers exactly these three ranges.
    ip.is_private()
}
508
509fn lower_16_bit_private_ip() -> Result<u16, Error> {
510 match private_ipv4() {
511 Some(ip) => {
512 let octets = ip.octets();
513 Ok(((octets[2] as u16) << 8) + (octets[3] as u16))
514 }
515 None => Err(Error::NoPrivateIPv4Address),
516 }
517}
518
/// Mutable generator state guarded by the generator's mutex.
#[derive(Debug)]
struct Inner {
    /// Time component of the most recently generated ID, in ticks since the epoch.
    elapsed_time: i64,
    /// Sequence number of the most recently generated ID within `elapsed_time`.
    sequence: u16,
}
524
525fn to_id(elapsed_time: i64, seq: u16, machine_id: u16) -> u64 {
526 (elapsed_time as u64) << (BIT_LEN_SEQUENCE + BIT_LEN_MACHINE_ID)
527 | (seq as u64) << BIT_LEN_MACHINE_ID
528 | (machine_id as u64)
529}
530
531fn to_sonyflake_time(time: DateTime<Utc>) -> i64 {
532 time.timestamp_nanos() / FLAKE_TIME_UNIT
533}
534
535fn current_elapsed_time(start_time: i64) -> i64 {
536 to_sonyflake_time(Utc::now()) - start_time
537}
538
539fn sleep_time(overtime: i64) -> Duration {
540 Duration::from_millis(overtime as u64 * 10)
541 - Duration::from_nanos((Utc::now().timestamp_nanos() % FLAKE_TIME_UNIT) as u64)
542}
543
544#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
546pub struct IDParts {
547 id: u64,
548 msb: u64,
549 time: u64,
550 sequence: u64,
551 machine_id: u64,
552}
553
554impl IDParts {
555 pub fn decompose(id: u64) -> Self {
557 decompose(id)
558 }
559
560 pub fn get_id(&self) -> u64 {
562 self.id
563 }
564
565 pub fn get_msb(&self) -> u64 {
567 self.msb
568 }
569
570 pub fn get_time(&self) -> u64 {
572 self.time
573 }
574
575 pub fn get_sequence(&self) -> u64 {
577 self.sequence
578 }
579
580 pub fn get_machine_id(&self) -> u64 {
582 self.machine_id
583 }
584}
585
586pub fn decompose(id: u64) -> IDParts {
588 let mask_seq = ((1 << BIT_LEN_SEQUENCE) - 1 as u64) << BIT_LEN_MACHINE_ID;
589 let mask_machine_id = (1 << BIT_LEN_MACHINE_ID) - 1 as u64;
590
591 let msb = id >> 63;
592 let time = id >> (BIT_LEN_SEQUENCE + BIT_LEN_MACHINE_ID);
593
594 let seq = (id & mask_seq) >> BIT_LEN_MACHINE_ID;
595 let machine_id = id & mask_machine_id;
596 IDParts {
597 id,
598 msb,
599 time,
600 sequence: seq,
601 machine_id,
602 }
603}
604
605fn default_start_time() -> DateTime<Utc> {
606 Utc.ymd(2021, 8, 6).and_hms_nano(0, 0, 0, 0)
607}
608
#[cfg(test)]
mod tests {
    use crate::{
        lower_16_bit_private_ip, to_sonyflake_time, Error as FlakeError, IDParts,
        InfallibleSonyFlake, MachineID, MachineIDChecker, Settings, SonyFlake, BIT_LEN_SEQUENCE,
        BIT_LEN_TIME,
    };
    use chrono::Utc;
    use std::collections::HashSet;
    use std::error::Error;
    use std::thread::JoinHandle;
    use std::time::Duration;

    /// How long the `*_once` tests sleep, in 10 ms ticks.
    const SLEEP_TICKS: u64 = 50;

    /// Number of worker threads in the concurrency tests.
    const WORKERS: usize = 100;

    /// Number of IDs each worker generates in the concurrency tests.
    const IDS_PER_WORKER: usize = 1000;

    /// Checks the parts of the first ID generated after sleeping `SLEEP_TICKS` ticks.
    fn check_first_id_after_sleep(id: u64) {
        let parts = IDParts::decompose(id);
        assert_eq!(parts.get_msb(), 0);
        assert_eq!(parts.get_sequence(), 0);
        // Only a lower bound is asserted: the sleep may overshoot under load,
        // but it never returns early.
        assert!(
            parts.get_time() >= SLEEP_TICKS,
            "time part {} is less than the {} ticks slept",
            parts.get_time(),
            SLEEP_TICKS
        );
        assert_eq!(
            parts.get_machine_id(),
            u64::from(lower_16_bit_private_ip().unwrap())
        );
    }

    /// Generates IDs for ten seconds and checks ordering, time, sequence and machine id.
    fn check_ids_for_10_sec(start_time: i64, mut next_id: impl FnMut() -> u64) {
        let machine_id = u64::from(lower_16_bit_private_ip().unwrap());

        let mut num_id: u64 = 0;
        let mut last_id: u64 = 0;
        let mut max_seq: u64 = 0;

        let initial = to_sonyflake_time(Utc::now());
        let mut current = initial;

        // 1000 ticks of 10 ms each.
        while current - initial < 1000 {
            let id = next_id();
            let parts = IDParts::decompose(id);
            num_id += 1;

            assert!(id > last_id, "id {} is not greater than {}", id, last_id);
            last_id = id;

            current = to_sonyflake_time(Utc::now());

            assert_eq!(parts.get_msb(), 0);
            // An ID must never carry a time that is still in the future.
            let overtime = start_time + (parts.get_time() as i64) - current;
            assert!(overtime <= 0, "unexpected overtime: {}", overtime);

            max_seq = max_seq.max(parts.get_sequence());

            assert_eq!(parts.get_machine_id(), machine_id);
        }

        assert_eq!(max_seq, (1 << BIT_LEN_SEQUENCE) - 1);
        println!("number of id: {}", num_id);
    }

    /// Generates IDs from `WORKERS` clones of `generator` in parallel and checks
    /// that no ID is produced twice.
    fn check_unique_across_threads<G>(generator: G, next_id: fn(&mut G) -> u64)
    where
        G: Clone + Send + 'static,
    {
        let (tx, rx) = std::sync::mpsc::channel::<u64>();

        let workers: Vec<JoinHandle<()>> = (0..WORKERS)
            .map(|_| {
                let mut thread_generator = generator.clone();
                let thread_tx = tx.clone();
                std::thread::spawn(move || {
                    for _ in 0..IDS_PER_WORKER {
                        thread_tx.send(next_id(&mut thread_generator)).unwrap();
                    }
                })
            })
            .collect();

        // Drop the original sender so the receive loop ends once every worker
        // is done, even if a worker panics, instead of blocking forever.
        drop(tx);

        let mut ids = HashSet::with_capacity(WORKERS * IDS_PER_WORKER);
        for id in rx {
            assert!(ids.insert(id), "duplicate id: {}", id);
        }

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        assert_eq!(ids.len(), WORKERS * IDS_PER_WORKER);
    }

    #[test]
    fn test_sonyflake_once() {
        let now = Utc::now();
        let mut f = Settings::new().set_start_time(now).into_sonyflake().unwrap();

        std::thread::sleep(Duration::from_millis(SLEEP_TICKS * 10));

        check_first_id_after_sleep(f.next_id().unwrap());
    }

    #[test]
    fn test_infallible_sonyflake_once() {
        let now = Utc::now();
        let mut f = Settings::new()
            .set_start_time(now)
            .into_infallible_sonyflake()
            .unwrap();

        std::thread::sleep(Duration::from_millis(SLEEP_TICKS * 10));

        check_first_id_after_sleep(f.next_id());
    }

    #[test]
    fn test_sonyflake_for_10_sec() {
        let now = Utc::now();
        let mut f = SonyFlake::new(Settings::new().set_start_time(now)).unwrap();

        check_ids_for_10_sec(to_sonyflake_time(now), || f.next_id().unwrap());
    }

    #[test]
    fn test_infallible_sonyflake_for_10_sec() {
        let now = Utc::now();
        let mut f = InfallibleSonyFlake::new(Settings::new().set_start_time(now)).unwrap();

        check_ids_for_10_sec(to_sonyflake_time(now), || f.next_id());
    }

    /// Machine id source that succeeds on odd-numbered calls and fails on even ones.
    struct CustomMachineID {
        counter: u64,
        id: u16,
    }

    impl MachineID for CustomMachineID {
        fn machine_id(&mut self) -> Result<u16, Box<dyn Error + Send + Sync + 'static>> {
            self.counter += 1;
            if self.counter % 2 != 0 {
                Ok(self.id)
            } else {
                Err(Box::new("NaN".parse::<u32>().unwrap_err()))
            }
        }
    }

    /// Checker that accepts only odd machine ids.
    struct CustomMachineIDChecker;

    impl MachineIDChecker for CustomMachineIDChecker {
        fn check_machine_id(&self, id: u16) -> bool {
            id % 2 != 0
        }
    }

    #[test]
    fn test_sonyflake_custom_machine_id_and_checker() {
        let mut sf = Settings::new()
            .set_machine_id(Box::new(CustomMachineID { counter: 0, id: 1 }))
            .set_check_machine_id(Box::new(CustomMachineIDChecker {}))
            .into_sonyflake()
            .unwrap();
        let id = sf.next_id().unwrap();
        let parts = IDParts::decompose(id);
        assert_eq!(parts.get_machine_id(), 1);

        let err = Settings::new()
            .set_machine_id(Box::new(CustomMachineID { counter: 0, id: 2 }))
            .set_check_machine_id(Box::new(CustomMachineIDChecker {}))
            .into_sonyflake()
            .unwrap_err();

        assert!(
            matches!(err, FlakeError::InvalidMachineID(2)),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_infallible_sonyflake_custom_machine_id_and_checker() {
        let mut sf = Settings::new()
            .set_machine_id(Box::new(CustomMachineID { counter: 0, id: 1 }))
            .set_check_machine_id(Box::new(CustomMachineIDChecker {}))
            .into_infallible_sonyflake()
            .unwrap();
        let id = sf.next_id();
        let parts = IDParts::decompose(id);
        assert_eq!(parts.get_machine_id(), 1);

        let err = Settings::new()
            .set_machine_id(Box::new(CustomMachineID { counter: 0, id: 2 }))
            .set_check_machine_id(Box::new(CustomMachineIDChecker {}))
            .into_infallible_sonyflake()
            .unwrap_err();

        assert!(
            matches!(err, FlakeError::InvalidMachineID(2)),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_machine_id_failure_is_reported() {
        // Starting at an odd counter makes the very first call fail.
        let err = Settings::new()
            .set_machine_id(Box::new(CustomMachineID { counter: 1, id: 1 }))
            .into_sonyflake()
            .unwrap_err();

        assert!(
            matches!(err, FlakeError::MachineIdFailed(_)),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_fallible() {
        let now = Utc::now();
        let mut sf = Settings::new().set_start_time(now).into_sonyflake().unwrap();
        sf.inner.lock().elapsed_time = 1 << BIT_LEN_TIME;

        // Assert on the specific error rather than on "some panic happened".
        match sf.next_id() {
            Err(FlakeError::TimeOverflow) => {}
            other => panic!("expected TimeOverflow, got {:?}", other),
        }
    }

    #[test]
    fn test_infallible() {
        let now = Utc::now();
        let mut sf = Settings::new()
            .set_start_time(now)
            .into_infallible_sonyflake()
            .unwrap();
        // Put the time component at its limit so the next call must reset the epoch.
        sf.inner.lock().elapsed_time = 1 << BIT_LEN_TIME;

        let first = sf.next_id();
        let parts = IDParts::decompose(first);
        assert_eq!(parts.get_msb(), 0);
        assert_eq!(parts.get_time(), 0);
        assert_eq!(parts.get_sequence(), 0);

        // After the reset, IDs keep increasing.
        let mut last = first;
        for _ in 0..3 {
            let id = sf.next_id();
            assert!(id > last, "id {} is not greater than {}", id, last);
            last = id;
        }
    }

    #[test]
    fn test_sonyflake_concurrency() {
        let now = Utc::now();
        let sf = Settings::new().set_start_time(now).into_sonyflake().unwrap();

        check_unique_across_threads(sf, |generator: &mut SonyFlake| {
            generator.next_id().unwrap()
        });
    }

    #[test]
    fn test_infallible_sonyflake_concurrency() {
        let now = Utc::now();
        let sf = Settings::new()
            .set_start_time(now)
            .into_infallible_sonyflake()
            .unwrap();

        check_unique_across_threads(sf, |generator: &mut InfallibleSonyFlake| {
            generator.next_id()
        });
    }

    #[test]
    fn test_error_send_sync() {
        // Moving the result into another thread only compiles if `Error` is `Send`.
        let res = SonyFlake::new(Settings::new());
        std::thread::spawn(move || {
            let _ = res.is_ok();
        })
        .join()
        .unwrap();
    }
}