essence/
snowflake.rs

1//! Snowflake generation and parsing.
2//!
3//! # Snowflake bit format
4//! Snowflakes are represented as unsigned 64-bit integers (`u64`). The bits
5//! (from left to right, 0-indexed, `inclusive..exclusive`) are as follows:
6//!
//! * Bits 0..46: Timestamp in milliseconds since `2022-12-25T00:00:00Z`. (See [`EPOCH_MILLIS`])
8//! * Bits 46..51: The model type represented as an enumeration. (See [`ModelType`])
9//! * Bits 51..56: The node or process ID that generated the snowflake.
10//! * Bits 56..64: The incrementing counter for the snowflake.
11//!
12//! ```text
13//! 1111111111111111111111111111111111111111111111_11111_11111_11111111
14//! milliseconds from 2022-12-25T00:00:00Z         ^     ^     ^
15//!                                                |     |     increment (0 to 255)
16//!                                                |     node number (0 to 31)
17//!                                                model number (0 to 31)
18//! ```
19
20use crate::models::ModelType;
21use regex::Regex;
22use std::{
23    sync::{
24        atomic::{AtomicU8, Ordering::Relaxed},
25        OnceLock,
26    },
27    time::{Duration, SystemTime, UNIX_EPOCH},
28};
29
/// Process-wide counter that supplies the 8-bit increment field (the lowest bits) of each
/// generated snowflake.
///
/// It is advanced with a relaxed `fetch_add`, which wraps back to 0 after reaching 255.
static INCREMENT: AtomicU8 = AtomicU8::new(0);
31
/// The snowflake epoch, `2022-12-25T00:00:00Z`, expressed in milliseconds since the Unix epoch.
///
/// Written as the Unix timestamp in seconds scaled up to milliseconds.
pub const EPOCH_MILLIS: u64 = 1_671_926_400 * 1_000;
34
35/// Returns the current time in milliseconds since the epoch.
36#[inline]
37#[must_use]
38pub fn epoch_time() -> u64 {
39    let now = SystemTime::now()
40        .duration_since(UNIX_EPOCH)
41        .expect("system time is before UNIX epoch")
42        .as_millis() as u64;
43
44    now.saturating_sub(EPOCH_MILLIS)
45}
46
47/// Generates a snowflake with the given model type and node ID.
48///
49/// # Safety
50/// This assumes that `node_id < 32`. If this is not the case, bits will flow and overwrite
51/// other fields, resulting in an invalid snowflake.
52#[inline]
53#[must_use]
54pub unsafe fn generate_snowflake_unchecked(model_type: ModelType, node_id: u8) -> u64 {
55    let increment = INCREMENT.fetch_add(1, Relaxed);
56
57    (epoch_time() << 18) | ((model_type as u64) << 13) | ((node_id as u64) << 8) | increment as u64
58}
59
60/// Generates a snowflake with the given model type and node ID.
61///
62/// # Panics
63/// * If `node_id >= 32`.
64#[inline]
65#[must_use]
66pub fn generate_snowflake(model_type: ModelType, node_id: u8) -> u64 {
67    assert!(node_id < 32, "node ID must be less than 32");
68
69    unsafe { generate_snowflake_unchecked(model_type, node_id) }
70}
71
72/// Returns the given snowflake with its model type altered to the given one.
73#[inline]
74#[must_use]
75pub const fn with_model_type(snowflake: u64, model_type: ModelType) -> u64 {
76    snowflake & !(0b11111 << 13) | (model_type as u64) << 13
77}
78
79/// Extract all snowflake IDs surrounded by <@!? and >, called mentions, from a string.
80#[must_use]
81pub fn extract_mentions(s: &str) -> Vec<u64> {
82    static REGEX: OnceLock<Regex> = OnceLock::new();
83
84    let regex = REGEX.get_or_init(|| Regex::new(r"<@!?(\d+)>").unwrap());
85    regex
86        .captures_iter(s)
87        .map(|c| c.get(1).unwrap().as_str().parse().unwrap())
88        .collect::<Vec<_>>()
89}
90
/// A thin wrapper around a raw snowflake that exposes its individual fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SnowflakeReader(u64);
94
95impl SnowflakeReader {
96    /// Creates a new snowflake reader from the given snowflake.
97    #[inline]
98    #[must_use]
99    pub const fn new(snowflake: u64) -> Self {
100        Self(snowflake)
101    }
102
103    /// Reads and returns the timestamp of the snowflake as a Unix timestamp in milliseconds.
104    #[inline]
105    #[must_use]
106    pub const fn timestamp_millis(&self) -> u64 {
107        self.0 >> 18
108    }
109
110    /// Reads and returns the timestamp of the snowflake as a Unix timestamp in seconds.
111    #[inline]
112    #[must_use]
113    pub const fn timestamp_secs(&self) -> u64 {
114        self.timestamp_millis() / 1000
115    }
116
117    /// Reads and returns the timestamp of the snowflake as a [`SystemTime`].
118    #[inline]
119    #[must_use]
120    pub fn timestamp(&self) -> SystemTime {
121        UNIX_EPOCH + Duration::from_millis(self.timestamp_millis())
122    }
123
124    /// Reads and returns the model type of the snowflake.
125    #[inline]
126    #[must_use]
127    pub const fn model_type(&self) -> ModelType {
128        ModelType::from_u8(((self.0 >> 13) & 0b11111) as u8)
129    }
130
131    /// Reads and returns the node ID of the snowflake.
132    #[inline]
133    #[must_use]
134    pub const fn node_id(&self) -> u8 {
135        ((self.0 >> 8) & 0b11111) as u8
136    }
137
138    /// Reads and returns the increment of the snowflake.
139    #[inline]
140    #[must_use]
141    pub const fn increment(&self) -> u8 {
142        (self.0 & 0b1111_1111) as u8
143    }
144}
145
146impl From<u64> for SnowflakeReader {
147    fn from(value: u64) -> Self {
148        Self(value)
149    }
150}
151
152impl From<i64> for SnowflakeReader {
153    fn from(value: i64) -> Self {
154        Self(value as u64)
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Two snowflakes generated back to back for the same model type and node must differ.
    #[test]
    fn test_generate_snowflake() {
        let first = generate_snowflake(ModelType::User, 0);
        let second = generate_snowflake(ModelType::User, 0);

        assert_ne!(first, second);
        println!("{} != {}", first, second);
    }

    /// The model type and node ID passed to the generator can be read back from the result.
    #[test]
    fn test_parse_snowflake() {
        let reader = SnowflakeReader::new(generate_snowflake(ModelType::Channel, 6));

        assert_eq!(reader.model_type(), ModelType::Channel);
        assert_eq!(reader.node_id(), 6);
    }

    /// Swapping the model type changes only the model type field.
    #[test]
    fn test_with_model_type() {
        let source = generate_snowflake(ModelType::User, 0);
        let converted = with_model_type(source, ModelType::Channel);

        let before = SnowflakeReader::new(source);
        let after = SnowflakeReader::new(converted);

        // Every field other than the model type must survive the conversion.
        assert_eq!(before.timestamp_millis(), after.timestamp_millis());
        assert_eq!(before.node_id(), after.node_id());
        assert_eq!(before.increment(), after.increment());

        assert_eq!(before.model_type(), ModelType::User);
        assert_eq!(after.model_type(), ModelType::Channel);
    }
}