cronback_lib/
model.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3
4use derive_more::{Display, From, Into};
5use sea_orm::TryGetable;
6use serde::{Deserialize, Deserializer, Serialize};
7use thiserror::Error;
8use ulid::Ulid;
9
/// Number of shards ids are spread across. `generate_raw_id` reduces the hash
/// of a freshly generated id modulo this value to pick the id's shard.
const SHARD_COUNT: u64 = 1031;
11
/// Errors produced while validating a model id.
#[derive(Debug, Error)]
pub enum ModelIdError {
    /// The id failed validation; the payload is the offending id rendered as
    /// a string.
    #[error("Malformed Id: {0}")]
    InvalidId(String),
}
17
18// A shorthand for a Result that returns a ModelIdError
19impl From<ModelIdError> for tonic::Status {
20    fn from(value: ModelIdError) -> Self {
21        tonic::Status::invalid_argument(value.to_string())
22    }
23}
24
/// Numeric shard identifier.
///
/// A thin newtype over `u64`; serde reads and writes it transparently as the
/// wrapped number.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Eq,
    PartialEq,
    Serialize,
    Deserialize,
    PartialOrd,
    Ord,
    Display,
    From,
    Into,
)]
#[serde(transparent)]
pub struct Shard(pub u64);
42
43impl Shard {
44    pub fn encoded(&self) -> String {
45        format!("{:04}", self.0)
46    }
47}
48
49pub trait ModelId: Sized + std::fmt::Display + From<String> {
50    fn has_valid_prefix(&self) -> bool;
51    fn value(&self) -> &str;
52    fn validated(self) -> Result<ValidShardedId<Self>, ModelIdError> {
53        ValidShardedId::try_from(self)
54    }
55}
56
/// Wrapper around an id of type `T`.
///
/// Values are created either through `try_from` / `ModelId::validated`, which
/// check the id first, or through `from_string_unsafe`, which does not.
#[derive(Ord, PartialOrd, Debug, Clone, PartialEq, Eq, Display, Serialize)]
pub struct ValidShardedId<T>(T);
59
60impl<T> ValidShardedId<T>
61where
62    T: ModelId + From<String>,
63{
64    pub fn try_from(s: T) -> Result<Self, ModelIdError> {
65        // validate Id
66        if s.has_valid_prefix() {
67            // Can also validate the rest of properties of the Id format.
68            // Including a future HMAC signature
69            Ok(Self(s))
70        } else {
71            Err(ModelIdError::InvalidId(s.to_string()))
72        }
73    }
74
75    // Should be used with caution, as it bypasses validation
76    pub fn from_string_unsafe(s: String) -> Self {
77        Self(T::from(s))
78    }
79
80    /// Returns the shard associated with this Id
81    pub fn shard(&self) -> Shard {
82        // extract the shard from the id
83        let (_, after) = self.value().split_once('_').expect("Id is malformed");
84        let shard: u64 = after[..4].parse().expect("Id is malformed");
85        Shard::from(shard)
86    }
87
88    // The timestamp section of the underlying Id
89    pub fn timestamp_ms(&self) -> Option<u64> {
90        // extract the shard from the id
91        let (_, after) = self.value().split_once('_').expect("Id is malformed");
92        let ulid = &after[4..];
93        let ulid = Ulid::from_string(ulid).ok()?;
94        Some(ulid.timestamp_ms())
95    }
96
97    pub fn inner(&self) -> &T {
98        &self.0
99    }
100
101    pub fn into_inner(self) -> T {
102        self.0
103    }
104}
105
106impl<'de, T> Deserialize<'de> for ValidShardedId<T>
107where
108    T: ModelId + From<String>,
109{
110    fn deserialize<D>(deserializer: D) -> Result<ValidShardedId<T>, D::Error>
111    where
112        D: Deserializer<'de>,
113    {
114        let s = String::deserialize(deserializer)?;
115        let id = T::from(s);
116        id.validated().map_err(serde::de::Error::custom)
117    }
118}
119
120impl<T: ModelId> std::ops::Deref for ValidShardedId<T> {
121    type Target = T;
122
123    fn deref(&self) -> &Self::Target {
124        &self.0
125    }
126}
127
128impl<T: ModelId> From<ValidShardedId<T>> for sea_query::Value {
129    fn from(id: ValidShardedId<T>) -> ::sea_query::Value {
130        ::sea_query::Value::String(Some(Box::new(id.value().to_owned())))
131    }
132}
133
134impl<T: ModelId + TryGetable> sea_orm::TryGetable for ValidShardedId<T> {
135    fn try_get_by<I: ::sea_orm::ColIdx>(
136        res: &::sea_orm::QueryResult,
137        index: I,
138    ) -> Result<Self, sea_orm::TryGetError> {
139        let val = T::try_get_by::<_>(res, index)?;
140
141        val.validated().map_err(|e| {
142            sea_orm::TryGetError::DbErr(sea_orm::DbErr::TryIntoErr {
143                from: "String",
144                into: "ValidShardedId",
145                source: Box::new(e),
146            })
147        })
148    }
149}
150
151impl<T: ModelId> sea_query::ValueType for ValidShardedId<T> {
152    fn try_from(
153        v: ::sea_query::Value,
154    ) -> Result<Self, ::sea_query::ValueTypeErr> {
155        match v {
156            | ::sea_query::Value::String(Some(x)) => {
157                let val: T = (*x).into();
158                val.validated().map_err(|_| sea_query::ValueTypeErr)
159            }
160            | _ => Err(sea_query::ValueTypeErr),
161        }
162    }
163
164    fn type_name() -> String {
165        stringify!($name).to_owned()
166    }
167
168    fn array_type() -> sea_orm::sea_query::ArrayType {
169        sea_orm::sea_query::ArrayType::String
170    }
171
172    fn column_type() -> sea_query::ColumnType {
173        sea_query::ColumnType::String(None)
174    }
175}
176
177impl<T: ModelId> sea_query::Nullable for ValidShardedId<T> {
178    fn null() -> ::sea_query::Value {
179        ::sea_query::Value::String(None)
180    }
181}
182
183impl<T: ModelId> sea_orm::TryFromU64 for ValidShardedId<T> {
184    fn try_from_u64(_: u64) -> Result<Self, ::sea_orm::DbErr> {
185        Err(::sea_orm::DbErr::ConvertFromU64(stringify!(T)))
186    }
187}
188
/// Marker trait for top-level ids, i.e. ids that do not follow the sharding
/// scheme of another id.
///
/// `generate_model_id` only accepts owners whose id type implements this
/// trait.
pub trait RootId: ModelId {}
192
193pub(crate) fn generate_model_id<T, B>(
194    model_prefix: T,
195    owner: &ValidShardedId<B>,
196) -> String
197where
198    T: AsRef<str>,
199    B: RootId,
200{
201    format!(
202        "{}_{}{}",
203        model_prefix.as_ref(),
204        owner.shard().encoded(),
205        Ulid::new().to_string()
206    )
207}
208
209pub(crate) fn generate_raw_id<T>(model_prefix: T) -> String
210where
211    T: AsRef<str>,
212{
213    // Raw ids are special, but we still prefix the string with the shard
214    // identifier even if it's self referential, for consistency that is.
215    let new_id = Ulid::new().to_string();
216    let mut hasher = DefaultHasher::new();
217    new_id.hash(&mut hasher);
218    let shard = Shard::from(hasher.finish() % SHARD_COUNT);
219
220    format!("{}_{}{}", model_prefix.as_ref(), shard.encoded(), new_id)
221}
222
/// Define a new model id NewType
///
/// Expands to a `String`-backed tuple struct named `$name` plus:
/// - its `ModelId` impl (an id is considered valid when it starts with
///   `<prefix>_`),
/// - `TryFrom<$name>` for `ValidShardedId<$name>`,
/// - the sea-query / sea-orm impls needed to use the id as a string column,
/// - conversions from a `ValidShardedId<$name>` back to `$name` and to
///   `String`,
/// - when `@proto = <type>` is given, conversions to and from that type
///   through its `value` field.
#[rustfmt::skip]
macro_rules! define_model_id_base {
    (
        @prefix = $prefix:literal,
        $(@proto = $proto:ty,)?
        $(#[$m:meta])*
        $type_vis:vis struct $name:ident;
    ) => {
        $(#[$m])*
        #[derive(
            Debug,
            Hash,
            Clone,
            Default,
            ::serde::Serialize,
            ::serde::Deserialize,
            Eq,
            PartialEq,
            PartialOrd,
            Ord,
            ::derive_more::Display,
            ::derive_more::From,
            ::derive_more::Into,
        )]
        #[serde(transparent)]
        $type_vis struct $name(String);

        impl $crate::model::ModelId for $name {
            fn has_valid_prefix(&self) -> bool {
                self.0.starts_with(concat!($prefix, "_"))
            }
            fn value(&self) -> &str {
                &self.0
            }
        }

        impl TryFrom<$name> for $crate::model::ValidShardedId<$name> {
            type Error = $crate::model::ModelIdError;
            fn try_from(id: $name) -> Result<Self, Self::Error> {
                $crate::model::ModelId::validated(id)
            }
        }

        impl From<$name> for ::sea_query::Value {
            fn from(id: $name) -> ::sea_query::Value {
                // `id` is owned, so the inner string is moved rather than cloned.
                ::sea_query::Value::String(Some(Box::new(id.0)))
            }
        }

        impl ::sea_orm::TryGetable for $name {
            fn try_get_by<I: ::sea_orm::ColIdx>(
                res: &::sea_orm::QueryResult,
                index: I
            ) -> Result<Self, ::sea_orm::TryGetError> {
                // Read as an optional string so a NULL column is reported as
                // `TryGetError::Null` naming the column index.
                let val = res.try_get_by::<Option<String>, _>(index)?;
                match val {
                    Some(v) => Ok(v.into()),
                    None => Err(::sea_orm::TryGetError::Null(format!("{index:?}"))),
                }
            }
        }

        impl ::sea_query::ValueType for $name {
            fn try_from(v: ::sea_query::Value) -> Result<Self, ::sea_query::ValueTypeErr> {
                match v {
                    ::sea_query::Value::String(Some(x)) => Ok((*x).into()),
                    _ => Err(::sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).to_owned()
            }

            fn array_type() -> ::sea_orm::sea_query::ArrayType {
                ::sea_orm::sea_query::ArrayType::String
            }

            fn column_type() -> ::sea_query::ColumnType {
                ::sea_query::ColumnType::String(None)
            }
        }

        impl ::sea_query::Nullable for $name {
            fn null() -> ::sea_query::Value {
                ::sea_query::Value::String(None)
            }
        }

        impl ::sea_orm::TryFromU64 for $name {
            fn try_from_u64(_: u64) -> Result<Self, ::sea_orm::DbErr> {
                Err(::sea_orm::DbErr::ConvertFromU64(stringify!($name)))
            }
        }

        $(
            // Proto newtype conversions.
            impl From<$crate::model::ValidShardedId<$name>> for $proto {
                fn from(value: $crate::model::ValidShardedId<$name>) -> Self {
                    Self {
                        value: value.to_string(),
                    }
                }
            }

            // NOTE: this conversion does not validate the incoming value.
            impl From<$proto> for $crate::model::ValidShardedId<$name> {
                fn from(value: $proto) -> Self {
                    $crate::model::ValidShardedId::from_string_unsafe(
                        value.value,
                    )
                }
            }

            impl From<$proto> for $name {
                fn from(value: $proto) -> Self {
                    Self(value.value)
                }
            }

            impl From<$name> for $proto {
                fn from(value: $name) -> Self {
                    Self {
                        value: value.to_string(),
                    }
                }
            }
        )?

        // TODO: Remove after we migrate all proto id types from Strings to new types
        impl From<$crate::model::ValidShardedId<$name>> for std::string::String {
            fn from(value: $crate::model::ValidShardedId<$name>) -> Self {
                value.to_string()
            }
        }

        // Unfortunately we can't implement this generically!
        impl From<$crate::model::ValidShardedId<$name>> for $name {
            fn from(value: $crate::model::ValidShardedId<$name>) -> Self {
                value.into_inner()
            }
        }

    };
}
368
/// Defines a model id type together with its `generate` constructor.
///
/// With `@no_owner` the type is a root id: it implements `RootId` and
/// `generate()` takes no arguments. Without it, `generate(owner)` takes the
/// owning root id and reuses its shard, and an inherent `from(String)`
/// constructor is added as well.
#[rustfmt::skip]
macro_rules! define_model_id {
    // Root id: carries its own shard.
    (
        @prefix = $prefix:literal,
        @no_owner,
        $(@proto = $proto:ty,)?
        $(#[$m:meta])*
        $type_vis:vis struct $name:ident;
    ) => {
        $crate::model::define_model_id_base!{
            @prefix = $prefix,
            $(@proto = $proto,)?
            $(#[$m])*
            $type_vis struct $name;
        }

        impl $crate::model::RootId for $name {}

        impl $name {
            pub fn generate() -> $crate::model::ValidShardedId<Self> {
                // Freshly generated ids are wrapped without going through validation.
                let raw = $crate::model::generate_raw_id($prefix);
                $crate::model::ValidShardedId::from_string_unsafe(raw)
            }
        }
    };
    // Owned id: shard is taken from the owner's id.
    (
        @prefix = $prefix:literal,
        $(@proto = $proto:ty,)?
        $(#[$m:meta])*
        $type_vis:vis struct $name:ident;
    ) => {
        $crate::model::define_model_id_base!{
            @prefix = $prefix,
            $(@proto = $proto,)?
            $(#[$m])*
            $type_vis struct $name;
        }

        impl $name {
            pub fn generate(owner: &$crate::model::ValidShardedId<impl $crate::model::RootId>) -> $crate::model::ValidShardedId<Self> {
                // Freshly generated ids are wrapped without going through validation.
                let raw = $crate::model::generate_model_id($prefix, owner);
                $crate::model::ValidShardedId::from_string_unsafe(raw)
            }

            pub fn from(value: String) -> Self {
                Self(value)
            }
        }
    };
}
422
423pub(crate) use {define_model_id, define_model_id_base};
424
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    // Root id type shared by the tests below.
    define_model_id! {
        @prefix = "owner",
        @no_owner,
        pub struct OwnerId;
    }
    // Test that Shard generates its encoded string correctly: `Display` is
    // the bare number, `encoded()` is zero-padded to four characters.
    #[test]
    fn test_shard_encoding() {
        let shard = Shard::from(123);
        assert_eq!("123", shard.to_string());
        assert_eq!("0123", shard.encoded());
    }

    // A generated id starts with the requested prefix followed by the
    // owner's shard.
    #[test]
    fn test_model_id_generation() -> Result<()> {
        let base = ValidShardedId::<OwnerId>::from_string_unsafe(
            "owner_049342352".into(),
        );

        assert_eq!("0493", base.shard().encoded());
        let id1 = generate_model_id("trig", &base);
        assert!(id1.len() > 4);
        assert!(id1.starts_with("trig_0493"));
        Ok(())
    }

    // Exercises an owned id type produced by the macro: prefix, shard
    // inheritance, timestamp extraction, ordering and validation.
    #[test]
    fn test_mode_id_macro() -> Result<()> {
        define_model_id! {
            @prefix = "som",
            pub struct SomeId;
        }

        let owner = OwnerId::generate();

        let id1 = SomeId::generate(&owner);
        assert!(id1.timestamp_ms().is_some());
        assert!(id1.timestamp_ms().unwrap() > 0);

        assert!(id1.to_string().starts_with("som_"));
        assert!(id1.value().starts_with("som_"));
        assert_eq!(id1.shard(), owner.shard());

        // Ids are lexicographically ordered by creation time; sleep so the
        // second id gets a strictly later millisecond timestamp.
        std::thread::sleep(std::time::Duration::from_millis(2));

        let id2 = SomeId::generate(&owner);
        assert!(id2 > id1);
        assert!(id2.timestamp_ms().unwrap() > id1.timestamp_ms().unwrap());
        assert_eq!(id2.shard(), owner.shard());

        // Invalid ids: a wrong prefix must fail validation.
        let id1 = SomeId::from("nothing_1234".into());
        assert!(id1.validated().is_err());
        Ok(())
    }
}
487}