1use std::{marker::PhantomData, str::FromStr};
2
3use base64::{
4 display::Base64Display,
5 engine::{general_purpose::URL_SAFE_NO_PAD, GeneralPurpose},
6 Engine,
7};
8use schemars::JsonSchema;
9use sqlx::{postgres::PgTypeInfo, Database};
10use thiserror::Error;
11use uuid::Uuid;
12
/// Declares a prefixed object ID type.
///
/// `make_object_id!(TeamId, tm)` expands to:
/// - a private module `tm` containing a zero-sized marker struct `TeamId` whose
///   `ObjectIdPrefix::prefix()` returns `"tm"`, and
/// - a public type alias `TeamId = ObjectId<tm::TeamId>` in the invoking module.
#[macro_export]
macro_rules! make_object_id {
    ($id_name:ident, $prefix_ident:ident) => {
        mod $prefix_ident {
            #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
            pub struct $id_name;

            impl $crate::object_id::ObjectIdPrefix for $id_name {
                fn prefix() -> &'static str {
                    // The module identifier doubles as the textual prefix.
                    stringify!($prefix_ident)
                }
            }
        }

        pub type $id_name = $crate::object_id::ObjectId<$prefix_ident::$id_name>;
    };
}
32
33#[derive(Debug, Error)]
35pub enum ObjectIdError {
36 #[error("Invalid ID prefix, expected {0}")]
38 InvalidPrefix(&'static str),
39
40 #[error("Failed to decode object ID")]
42 DecodeFailure,
43}
44
/// Supplies the string prefix that distinguishes one kind of [`ObjectId`] from another.
///
/// Implementors are zero-sized marker types, such as those generated by
/// `make_object_id!`. `Copy` implies `Clone`, and `Ord` implies `Eq`, `PartialEq`
/// and `PartialOrd`, so all of those are available on any implementor.
pub trait ObjectIdPrefix: Copy + Ord + std::hash::Hash {
    /// The text written before the encoded UUID in the string form of an ID.
    fn prefix() -> &'static str;
}
52
53#[derive(Hash, PartialOrd, Ord, Eq)]
57pub struct ObjectId<PREFIX: ObjectIdPrefix>(pub Uuid, PhantomData<PREFIX>);
58
59impl<PREFIX: ObjectIdPrefix> JsonSchema for ObjectId<PREFIX> {
60 fn schema_name() -> String {
61 format!("ObjectId<{}>", PREFIX::prefix())
62 }
63
64 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
65 String::json_schema(gen)
66 }
67
68 fn is_referenceable() -> bool {
69 false
70 }
71
72 fn schema_id() -> std::borrow::Cow<'static, str> {
73 format!(concat!(module_path!(), "::ObjectId<{}>"), PREFIX::prefix()).into()
74 }
75}
76
77impl<PREFIX: ObjectIdPrefix> Clone for ObjectId<PREFIX> {
78 fn clone(&self) -> Self {
79 Self(self.0, PhantomData)
80 }
81}
82
83impl<PREFIX: ObjectIdPrefix> Copy for ObjectId<PREFIX> {}
84
85impl<PREFIX: ObjectIdPrefix> ObjectId<PREFIX> {
86 pub fn new() -> Self {
88 Self(uuid::Uuid::now_v7(), PhantomData)
89 }
90
91 pub const fn from_uuid(u: Uuid) -> Self {
93 Self(u, PhantomData)
94 }
95
96 pub const fn into_inner(self) -> Uuid {
98 self.0
99 }
100
101 pub const fn as_uuid(&self) -> &Uuid {
103 &self.0
104 }
105
106 pub const fn nil() -> Self {
108 Self(Uuid::nil(), PhantomData)
109 }
110
111 pub fn display_without_prefix(&self) -> Base64Display<GeneralPurpose> {
113 base64::display::Base64Display::new(self.0.as_bytes(), &URL_SAFE_NO_PAD)
114 }
115}
116
117impl<PREFIX: ObjectIdPrefix> Default for ObjectId<PREFIX> {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl<PREFIX: ObjectIdPrefix> PartialEq for ObjectId<PREFIX> {
124 fn eq(&self, other: &Self) -> bool {
125 self.0 == other.0
126 }
127}
128
129impl<PREFIX: ObjectIdPrefix> PartialEq<Uuid> for ObjectId<PREFIX> {
130 fn eq(&self, other: &Uuid) -> bool {
131 &self.0 == other
132 }
133}
134
135impl<PREFIX: ObjectIdPrefix> AsRef<Uuid> for ObjectId<PREFIX> {
136 fn as_ref(&self) -> &Uuid {
137 &self.0
138 }
139}
140
141impl<PREFIX: ObjectIdPrefix> From<Uuid> for ObjectId<PREFIX> {
142 fn from(u: Uuid) -> Self {
143 Self(u, PhantomData)
144 }
145}
146
147impl<PREFIX: ObjectIdPrefix> From<ObjectId<PREFIX>> for Uuid {
148 fn from(data: ObjectId<PREFIX>) -> Self {
149 data.0
150 }
151}
152
153impl<PREFIX: ObjectIdPrefix> std::fmt::Debug for ObjectId<PREFIX> {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 f.debug_tuple("ObjectId")
156 .field(&self.to_string())
157 .field(&self.0)
158 .finish()
159 }
160}
161
162impl<PREFIX: ObjectIdPrefix> std::fmt::Display for ObjectId<PREFIX> {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.write_str(PREFIX::prefix())?;
165 self.display_without_prefix().fmt(f)
166 }
167}
168
169fn decode_suffix(s: &str) -> Result<Uuid, ObjectIdError> {
170 let bytes = URL_SAFE_NO_PAD
171 .decode(s)
172 .map_err(|_| ObjectIdError::DecodeFailure)?;
173 Uuid::from_slice(&bytes).map_err(|_| ObjectIdError::DecodeFailure)
174}
175
176impl<PREFIX: ObjectIdPrefix> FromStr for ObjectId<PREFIX> {
177 type Err = ObjectIdError;
178
179 fn from_str(s: &str) -> Result<Self, Self::Err> {
180 let expected_prefix = PREFIX::prefix();
181 if !s.starts_with(expected_prefix) {
182 return Err(ObjectIdError::InvalidPrefix(expected_prefix));
183 }
184
185 decode_suffix(&s[expected_prefix.len()..]).map(Self::from_uuid)
186 }
187}
188
189impl<PREFIX: ObjectIdPrefix> serde::Serialize for ObjectId<PREFIX> {
191 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
192 where
193 S: serde::Serializer,
194 {
195 let s = self.to_string();
196 serializer.serialize_str(&s)
197 }
198}
199
200struct ObjectIdVisitor<PREFIX: ObjectIdPrefix>(PhantomData<PREFIX>);
201
202impl<PREFIX: ObjectIdPrefix> Default for ObjectIdVisitor<PREFIX> {
203 fn default() -> Self {
204 Self(Default::default())
205 }
206}
207
208impl<'de, PREFIX: ObjectIdPrefix> serde::de::Visitor<'de> for ObjectIdVisitor<PREFIX> {
209 type Value = ObjectId<PREFIX>;
210
211 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
212 formatter.write_str("an object ID starting with ")?;
213 formatter.write_str(PREFIX::prefix())
214 }
215
216 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
217 where
218 E: serde::de::Error,
219 {
220 match Self::Value::from_str(v) {
221 Ok(id) => Ok(id),
222 Err(e) => {
223 Uuid::from_str(v)
226 .map(ObjectId::<PREFIX>::from_uuid)
227 .map_err(|_| e)
229 }
230 }
231 .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(v), &self))
232 }
233}
234
235impl<'de, PREFIX: ObjectIdPrefix> serde::Deserialize<'de> for ObjectId<PREFIX> {
237 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
238 where
239 D: serde::Deserializer<'de>,
240 {
241 deserializer.deserialize_str(ObjectIdVisitor::default())
242 }
243}
244
245impl<PREFIX: ObjectIdPrefix> sqlx::Type<sqlx::Postgres> for ObjectId<PREFIX> {
247 fn type_info() -> <sqlx::Postgres as Database>::TypeInfo {
248 <sqlx::types::Uuid as sqlx::Type<sqlx::Postgres>>::type_info()
249 }
250}
251
252impl<PREFIX: ObjectIdPrefix> sqlx::postgres::PgHasArrayType for ObjectId<PREFIX> {
253 fn array_type_info() -> PgTypeInfo {
254 <sqlx::types::Uuid as sqlx::postgres::PgHasArrayType>::array_type_info()
255 }
256}
257
258impl<'q, PREFIX: ObjectIdPrefix> sqlx::Encode<'q, sqlx::Postgres> for ObjectId<PREFIX> {
259 fn encode_by_ref(
260 &self,
261 buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>,
262 ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync>> {
263 <sqlx::types::Uuid as sqlx::Encode<'_, sqlx::Postgres>>::encode_by_ref(&self.0, buf)
264 }
265}
266
267impl<'r, PREFIX: ObjectIdPrefix> sqlx::Decode<'r, sqlx::Postgres> for ObjectId<PREFIX> {
268 fn decode(
269 value: <sqlx::Postgres as sqlx::Database>::ValueRef<'r>,
270 ) -> Result<Self, sqlx::error::BoxDynError> {
271 let u = <sqlx::types::Uuid as sqlx::Decode<'r, sqlx::Postgres>>::decode(value)?;
272 Ok(Self(u, PhantomData))
273 }
274}
275
#[cfg(test)]
mod tests {
    use axum::{extract::Path, response::IntoResponse, Router};

    use super::*;

    make_object_id!(TeamId, tm);

    #[test]
    fn to_from_str() {
        let id = TeamId::new();

        let s = id.to_string();
        let id2 = TeamId::from_str(&s).unwrap();
        assert_eq!(id, id2, "ID converts to string and back");
    }

    #[test]
    fn nil_display_is_prefix_plus_zero_base64() {
        // 16 zero bytes encode to 22 'A's in unpadded base64; pins the wire format.
        assert_eq!(TeamId::nil().to_string(), format!("tm{}", "A".repeat(22)));
        assert_eq!(
            TeamId::from_str(&TeamId::nil().to_string()).unwrap(),
            TeamId::nil()
        );
    }

    #[test]
    fn from_str_rejects_wrong_prefix() {
        let id = TeamId::new();
        let wrong = format!("xx{}", id.display_without_prefix());
        assert!(matches!(
            TeamId::from_str(&wrong),
            Err(ObjectIdError::InvalidPrefix("tm"))
        ));
    }

    #[test]
    fn from_str_rejects_bad_suffix() {
        assert!(matches!(
            TeamId::from_str("tm"),
            Err(ObjectIdError::DecodeFailure)
        ));
        assert!(matches!(
            TeamId::from_str("tm!!!"),
            Err(ObjectIdError::DecodeFailure)
        ));
    }

    #[test]
    fn serde() {
        let id = TeamId::new();
        let json_str = serde_json::to_string(&id).unwrap();
        let id2: TeamId = serde_json::from_str(&json_str).unwrap();
        assert_eq!(id, id2, "Value serializes and deserializes to itself");
    }

    #[test]
    fn serde_accepts_raw_uuid_string() {
        // Exercises the visitor's fallback for bare UUID strings.
        let id = TeamId::new();
        let json_str = serde_json::to_string(&id.into_inner().to_string()).unwrap();
        let id2: TeamId = serde_json::from_str(&json_str).unwrap();
        assert_eq!(id, id2, "Raw UUID string deserializes to the same ID");
    }

    #[test]
    fn serde_rejects_garbage() {
        assert!(serde_json::from_str::<TeamId>("\"xxnot-an-id\"").is_err());
    }

    #[test]
    fn can_use_in_axum_path() {
        async fn get_id(Path(_id): Path<TeamId>) -> impl IntoResponse {
            "ok"
        }

        let _ = Router::<()>::new().route("/:id", axum::routing::get(get_id));
    }
}