Skip to main content

filigree/
object_id.rs

1use std::{marker::PhantomData, str::FromStr};
2
3use base64::{
4    display::Base64Display,
5    engine::{general_purpose::URL_SAFE_NO_PAD, GeneralPurpose},
6    Engine,
7};
8use schemars::JsonSchema;
9use sqlx::{postgres::PgTypeInfo, Database};
10use thiserror::Error;
11use uuid::Uuid;
12
/// Declare a prefixed ObjectId type.
///
/// `make_object_id!(TeamId, tm)` expands to a private module named `tm` containing a zero-sized
/// marker struct `TeamId` that implements `ObjectIdPrefix` (its prefix is the text of the module
/// identifier, here `"tm"`), followed by a public alias `TeamId = ObjectId<tm::TeamId>`.
///
/// The expansion refers to `$crate::object_id`, so this module must be reachable at that path
/// from the crate root.
#[macro_export]
macro_rules! make_object_id {
    ($alias:ident, $tag:ident) => {
        mod $tag {
            #[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
            pub struct $alias;

            impl $crate::object_id::ObjectIdPrefix for $alias {
                fn prefix() -> &'static str {
                    // The module identifier doubles as the textual prefix.
                    stringify!($tag)
                }
            }
        }

        /// The ObjectId type alias for this model.
        pub type $alias = $crate::object_id::ObjectId<$tag::$alias>;
    };
}
32
33/// An error related to parsing an ObjectId
34#[derive(Debug, Error)]
35pub enum ObjectIdError {
36    /// The prefix in the parsed ID did not match the expected prefix
37    #[error("Invalid ID prefix, expected {0}")]
38    InvalidPrefix(&'static str),
39
40    /// Some other parsing error, such as invalid base64
41    #[error("Failed to decode object ID")]
42    DecodeFailure,
43}
44
/// A marker type that provides the prefix used in the serialized form of an [`ObjectId`].
///
/// The `make_object_id!` macro generates zero-sized structs implementing this trait; the
/// prefix is prepended to the encoded UUID when an ID is displayed and is required when one is
/// parsed.
pub trait ObjectIdPrefix:
    Clone + Copy + Eq + PartialEq + PartialOrd + Ord + std::hash::Hash
{
    /// The short prefix for this ID type.
    fn prefix() -> &'static str;
}
52
53/// A type that is internally stored as a UUID but externally as a
54/// more accessible string with a prefix indicating its type. This uses
55/// UUID v7 so that the output will be lexicographically sortable.
56#[derive(Hash, PartialOrd, Ord, Eq)]
57pub struct ObjectId<PREFIX: ObjectIdPrefix>(pub Uuid, PhantomData<PREFIX>);
58
59impl<PREFIX: ObjectIdPrefix> JsonSchema for ObjectId<PREFIX> {
60    fn schema_name() -> String {
61        format!("ObjectId<{}>", PREFIX::prefix())
62    }
63
64    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
65        String::json_schema(gen)
66    }
67
68    fn is_referenceable() -> bool {
69        false
70    }
71
72    fn schema_id() -> std::borrow::Cow<'static, str> {
73        format!(concat!(module_path!(), "::ObjectId<{}>"), PREFIX::prefix()).into()
74    }
75}
76
77impl<PREFIX: ObjectIdPrefix> Clone for ObjectId<PREFIX> {
78    fn clone(&self) -> Self {
79        Self(self.0, PhantomData)
80    }
81}
82
83impl<PREFIX: ObjectIdPrefix> Copy for ObjectId<PREFIX> {}
84
85impl<PREFIX: ObjectIdPrefix> ObjectId<PREFIX> {
86    /// Create a new ObjectId with a timestamp of now
87    pub fn new() -> Self {
88        Self(uuid::Uuid::now_v7(), PhantomData)
89    }
90
91    /// Create a new ObjectId from a UUID
92    pub const fn from_uuid(u: Uuid) -> Self {
93        Self(u, PhantomData)
94    }
95
96    /// Return the inner Uuid
97    pub const fn into_inner(self) -> Uuid {
98        self.0
99    }
100
101    /// Return a reference to the inner Uuid
102    pub const fn as_uuid(&self) -> &Uuid {
103        &self.0
104    }
105
106    /// Return an ObjectId corresponding to the "all zeroes" UUID
107    pub const fn nil() -> Self {
108        Self(Uuid::nil(), PhantomData)
109    }
110
111    /// Writes the UUID portion of the object ID, without the prefix
112    pub fn display_without_prefix(&self) -> Base64Display<GeneralPurpose> {
113        base64::display::Base64Display::new(self.0.as_bytes(), &URL_SAFE_NO_PAD)
114    }
115}
116
117impl<PREFIX: ObjectIdPrefix> Default for ObjectId<PREFIX> {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl<PREFIX: ObjectIdPrefix> PartialEq for ObjectId<PREFIX> {
124    fn eq(&self, other: &Self) -> bool {
125        self.0 == other.0
126    }
127}
128
129impl<PREFIX: ObjectIdPrefix> PartialEq<Uuid> for ObjectId<PREFIX> {
130    fn eq(&self, other: &Uuid) -> bool {
131        &self.0 == other
132    }
133}
134
135impl<PREFIX: ObjectIdPrefix> AsRef<Uuid> for ObjectId<PREFIX> {
136    fn as_ref(&self) -> &Uuid {
137        &self.0
138    }
139}
140
141impl<PREFIX: ObjectIdPrefix> From<Uuid> for ObjectId<PREFIX> {
142    fn from(u: Uuid) -> Self {
143        Self(u, PhantomData)
144    }
145}
146
147impl<PREFIX: ObjectIdPrefix> From<ObjectId<PREFIX>> for Uuid {
148    fn from(data: ObjectId<PREFIX>) -> Self {
149        data.0
150    }
151}
152
153impl<PREFIX: ObjectIdPrefix> std::fmt::Debug for ObjectId<PREFIX> {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        f.debug_tuple("ObjectId")
156            .field(&self.to_string())
157            .field(&self.0)
158            .finish()
159    }
160}
161
162impl<PREFIX: ObjectIdPrefix> std::fmt::Display for ObjectId<PREFIX> {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.write_str(PREFIX::prefix())?;
165        self.display_without_prefix().fmt(f)
166    }
167}
168
169fn decode_suffix(s: &str) -> Result<Uuid, ObjectIdError> {
170    let bytes = URL_SAFE_NO_PAD
171        .decode(s)
172        .map_err(|_| ObjectIdError::DecodeFailure)?;
173    Uuid::from_slice(&bytes).map_err(|_| ObjectIdError::DecodeFailure)
174}
175
176impl<PREFIX: ObjectIdPrefix> FromStr for ObjectId<PREFIX> {
177    type Err = ObjectIdError;
178
179    fn from_str(s: &str) -> Result<Self, Self::Err> {
180        let expected_prefix = PREFIX::prefix();
181        if !s.starts_with(expected_prefix) {
182            return Err(ObjectIdError::InvalidPrefix(expected_prefix));
183        }
184
185        decode_suffix(&s[expected_prefix.len()..]).map(Self::from_uuid)
186    }
187}
188
189/// Serialize into string form with the prefix
190impl<PREFIX: ObjectIdPrefix> serde::Serialize for ObjectId<PREFIX> {
191    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
192    where
193        S: serde::Serializer,
194    {
195        let s = self.to_string();
196        serializer.serialize_str(&s)
197    }
198}
199
200struct ObjectIdVisitor<PREFIX: ObjectIdPrefix>(PhantomData<PREFIX>);
201
202impl<PREFIX: ObjectIdPrefix> Default for ObjectIdVisitor<PREFIX> {
203    fn default() -> Self {
204        Self(Default::default())
205    }
206}
207
208impl<'de, PREFIX: ObjectIdPrefix> serde::de::Visitor<'de> for ObjectIdVisitor<PREFIX> {
209    type Value = ObjectId<PREFIX>;
210
211    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
212        formatter.write_str("an object ID starting with ")?;
213        formatter.write_str(PREFIX::prefix())
214    }
215
216    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
217    where
218        E: serde::de::Error,
219    {
220        match Self::Value::from_str(v) {
221            Ok(id) => Ok(id),
222            Err(e) => {
223                // See if it's in UUID format instead of the encoded format. This mostly happens when
224                // deserializing from a JSON object generated in Postgres with jsonb_build_object.
225                Uuid::from_str(v)
226                    .map(ObjectId::<PREFIX>::from_uuid)
227                    // Return the more descriptive original error instead of the UUID parsing error
228                    .map_err(|_| e)
229            }
230        }
231        .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
232    }
233}
234
235/// Deserialize from string form with the prefix.
236impl<'de, PREFIX: ObjectIdPrefix> serde::Deserialize<'de> for ObjectId<PREFIX> {
237    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
238    where
239        D: serde::Deserializer<'de>,
240    {
241        deserializer.deserialize_str(ObjectIdVisitor::default())
242    }
243}
244
245/// Store and retrieve in Postgres as a raw UUID
246impl<PREFIX: ObjectIdPrefix> sqlx::Type<sqlx::Postgres> for ObjectId<PREFIX> {
247    fn type_info() -> <sqlx::Postgres as Database>::TypeInfo {
248        <sqlx::types::Uuid as sqlx::Type<sqlx::Postgres>>::type_info()
249    }
250}
251
252impl<PREFIX: ObjectIdPrefix> sqlx::postgres::PgHasArrayType for ObjectId<PREFIX> {
253    fn array_type_info() -> PgTypeInfo {
254        <sqlx::types::Uuid as sqlx::postgres::PgHasArrayType>::array_type_info()
255    }
256}
257
258impl<'q, PREFIX: ObjectIdPrefix> sqlx::Encode<'q, sqlx::Postgres> for ObjectId<PREFIX> {
259    fn encode_by_ref(
260        &self,
261        buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>,
262    ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync>> {
263        <sqlx::types::Uuid as sqlx::Encode<'_, sqlx::Postgres>>::encode_by_ref(&self.0, buf)
264    }
265}
266
267impl<'r, PREFIX: ObjectIdPrefix> sqlx::Decode<'r, sqlx::Postgres> for ObjectId<PREFIX> {
268    fn decode(
269        value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
270    ) -> Result<Self, sqlx::error::BoxDynError> {
271        let u = <sqlx::types::Uuid as sqlx::Decode<'r, sqlx::Postgres>>::decode(value)?;
272        Ok(Self(u, PhantomData))
273    }
274}
275
#[cfg(test)]
mod tests {
    use axum::{extract::Path, response::IntoResponse, Router};

    use super::*;

    make_object_id!(TeamId, tm);

    #[test]
    fn to_from_str() {
        let id = TeamId::new();

        let s = id.to_string();
        let id2 = TeamId::from_str(&s).unwrap();
        assert_eq!(id, id2, "ID converts to string and back");
    }

    #[test]
    fn serde() {
        let id = TeamId::new();
        let json_str = serde_json::to_string(&id).unwrap();
        let id2: TeamId = serde_json::from_str(&json_str).unwrap();
        assert_eq!(id, id2, "Value serializes and deserializes to itself");
    }

    #[test]
    fn nil_display_format() {
        // 16 zero bytes encode to 22 'A' characters in unpadded base64url.
        let expected = format!("tm{}", "A".repeat(22));
        assert_eq!(TeamId::nil().to_string(), expected);
    }

    #[test]
    fn from_str_rejects_wrong_prefix() {
        let s = format!("xx{}", TeamId::new().display_without_prefix());
        assert!(matches!(
            TeamId::from_str(&s),
            Err(ObjectIdError::InvalidPrefix("tm"))
        ));
    }

    #[test]
    fn from_str_rejects_bad_suffix() {
        // Not valid base64.
        assert!(matches!(
            TeamId::from_str("tm!!!"),
            Err(ObjectIdError::DecodeFailure)
        ));
        // Valid (empty) base64, but not 16 bytes.
        assert!(matches!(
            TeamId::from_str("tm"),
            Err(ObjectIdError::DecodeFailure)
        ));
    }

    #[test]
    fn deserialize_accepts_plain_uuid() {
        let id = TeamId::new();
        let json = format!("\"{}\"", id.as_uuid());
        let id2: TeamId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, id2, "Hyphenated UUID string deserializes to the same ID");
    }

    #[test]
    fn deserialize_rejects_garbage() {
        assert!(serde_json::from_str::<TeamId>("\"not an id\"").is_err());
    }

    #[test]
    fn can_use_in_axum_path() {
        async fn get_id(Path(_id): Path<TeamId>) -> impl IntoResponse {
            "ok"
        }

        let _ = Router::<()>::new().route("/:id", axum::routing::get(get_id));
    }
}