1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::Arc;
5
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use uuid::Uuid;
8
9#[cfg(feature = "diesel")]
10use diesel::deserialize::{self, FromSql, Queryable};
11#[cfg(feature = "diesel")]
12use diesel::expression::AsExpression;
13#[cfg(feature = "diesel")]
14use diesel::pg::{Pg, PgValue};
15#[cfg(feature = "diesel")]
16use diesel::serialize::{self, Output, ToSql};
17#[cfg(feature = "diesel")]
18use diesel::sql_types::BigInt;
19
20#[cfg(feature = "sqlx")]
21use sqlx::{postgres::PgTypeInfo, Postgres, Type};
22
23use crate::{Codec, Config};
24
25thread_local! {
26 static CODEC_CACHE: RefCell<HashMap<String, Arc<Codec>>> = RefCell::new(HashMap::new());
27}
28
29pub fn clear_codec_cache() {
32 CODEC_CACHE.with(|cache| {
33 cache.borrow_mut().clear();
34 });
35}
36
37fn get_or_create_codec(name: &str) -> Arc<Codec> {
38 CODEC_CACHE.with(|cache| {
39 let mut cache = cache.borrow_mut();
40 if let Some(codec) = cache.get(name) {
41 codec.clone()
42 } else {
43 let codec = Arc::new(Codec::new(name, &Config::effective().unwrap()));
44 cache.insert(name.to_string(), codec.clone());
45 codec
46 }
47 })
48}
49
/// Compile-time tag that selects which codec a `Field` uses.
///
/// The string returned by `name` is the key used for the codec cache and is
/// passed to `Codec::new`, so different marker types returning the same name
/// share one codec on a given thread.
pub trait TypeMarker: fmt::Debug {
    /// Name identifying the codec for this marker type.
    fn name() -> &'static str;
}
53
/// Numeric identifier tagged with a [`TypeMarker`] that selects its codec.
///
/// In memory the value is a plain `u64`. Its text form (`Display`, serde,
/// `FromStr`) and UUID form go through the codec named by `T::name()`. The
/// diesel/sqlx impls store it in Postgres as a `BIGINT` via an `i64` cast.
///
/// `Clone`, `Copy`, `PartialEq` and `Eq` are derived, so `Field<T>` only
/// implements them when the marker type `T` implements them too.
#[cfg_attr(feature = "diesel", derive(AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = BigInt))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Field<T: TypeMarker> {
    /// Raw numeric id; it is only encoded when formatted or serialized.
    id: u64,
    /// Zero-sized tag carrying the marker type `T`.
    _marker: std::marker::PhantomData<T>,
}
95
96impl<T: TypeMarker + std::hash::Hash> std::hash::Hash for Field<T> {
98 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
99 self.id.hash(state);
100 self._marker.hash(state);
101 }
102}
103
104impl<T: TypeMarker> From<Field<T>> for u64 {
105 fn from(field: Field<T>) -> Self {
107 field.id
108 }
109}
110
111impl<T: TypeMarker> fmt::Display for Field<T> {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 let codec_name = T::name();
114 let codec = get_or_create_codec(codec_name);
115 write!(f, "{}", codec.encode(self.id))
116 }
117}
118
119impl<T: TypeMarker> Field<T> {
120 pub fn from(id: u64) -> Self {
124 Field {
125 id,
126 _marker: std::marker::PhantomData,
127 }
128 }
129
130 pub fn encode_uuid(self) -> Uuid {
132 let codec_name = T::name();
133 let codec = get_or_create_codec(codec_name);
134 codec.encode_uuid(self.id)
135 }
136
137 pub fn decode_uuid(uuid: Uuid) -> Result<Self, crate::codec::Error> {
139 let codec_name = T::name();
140 let codec = get_or_create_codec(codec_name);
141 let id = codec.decode_uuid(uuid)?;
142 Ok(Field::from(id))
143 }
144}
145
146impl<T: TypeMarker> Serialize for Field<T> {
148 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
149 where
150 S: Serializer,
151 {
152 let codec_name = T::name();
153 let codec = get_or_create_codec(codec_name);
154 serializer.serialize_str(&codec.encode(self.id))
155 }
156}
157
158impl<'de, T: TypeMarker> Deserialize<'de> for Field<T> {
160 fn deserialize<D>(deserializer: D) -> Result<Field<T>, D::Error>
161 where
162 D: Deserializer<'de>,
163 {
164 use serde::de::Error;
165
166 let text = String::deserialize(deserializer)?;
167 let codec_name = T::name();
168 let codec = get_or_create_codec(codec_name);
169 let id = codec.decode(&text).map_err(Error::custom)?;
170 Ok(Field {
171 id,
172 _marker: std::marker::PhantomData,
173 })
174 }
175}
176
177impl<T: TypeMarker> std::str::FromStr for Field<T> {
178 type Err = crate::codec::Error;
179
180 fn from_str(s: &str) -> Result<Self, Self::Err> {
181 let codec_name = T::name();
182 let codec = get_or_create_codec(codec_name);
183 let id = codec.decode(s)?;
184 Ok(Field::from(id))
185 }
186}
187
/// Writes the id to a Postgres `BIGINT` column.
///
/// The `u64` is reinterpreted as `i64` with `as`, so ids above `i64::MAX` are
/// stored as negative numbers; the matching `FromSql` impl casts back with
/// `as`, which restores the original `u64`.
#[cfg(feature = "diesel")]
impl<T> ToSql<BigInt, Pg> for Field<T>
where
    T: TypeMarker,
{
    fn to_sql(&self, out: &mut Output<'_, '_, Pg>) -> serialize::Result {
        let signed_id = self.id as i64;
        <i64 as ToSql<BigInt, Pg>>::to_sql(&signed_id, &mut out.reborrow())
    }
}
195
/// Reads a Postgres `BIGINT` and reinterprets it as the `u64` id with `as`.
#[cfg(feature = "diesel")]
impl<T> FromSql<BigInt, Pg> for Field<T>
where
    T: TypeMarker,
{
    fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> {
        <i64 as FromSql<BigInt, Pg>>::from_sql(bytes).map(|signed_id| Self::from(signed_id as u64))
    }
}
203
/// Builds a field from a single `BIGINT` column, using the same row type as `i64`.
#[cfg(feature = "diesel")]
impl<T: TypeMarker> Queryable<BigInt, Pg> for Field<T> {
    type Row = <i64 as Queryable<BigInt, Pg>>::Row;

    fn build(row: Self::Row) -> deserialize::Result<Self> {
        i64::build(row).map(|signed_id| Self::from(signed_id as u64))
    }
}
216
/// Reports the Postgres type of a `Field` as exactly that of `i64`.
#[cfg(feature = "sqlx")]
impl<T> Type<Postgres> for Field<T>
where
    T: TypeMarker,
{
    fn type_info() -> PgTypeInfo {
        <i64 as Type<Postgres>>::type_info()
    }
}
226
/// Binds the id as a Postgres `i64`.
///
/// The `u64` is reinterpreted with `as`, so ids above `i64::MAX` are sent as
/// negative values; the matching `Decode` impl casts back with `as`.
#[cfg(feature = "sqlx")]
impl<'q, T> sqlx::Encode<'q, Postgres> for Field<T>
where
    T: TypeMarker,
{
    fn encode_by_ref(
        &self,
        buf: &mut sqlx::postgres::PgArgumentBuffer,
    ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
        <i64 as sqlx::Encode<'q, Postgres>>::encode_by_ref(&(self.id as i64), buf)
    }
}
238
/// Decodes a Postgres `i64` and reinterprets it as the `u64` id with `as`.
#[cfg(feature = "sqlx")]
impl<'r, T> sqlx::Decode<'r, Postgres> for Field<T>
where
    T: TypeMarker,
{
    fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
        <i64 as sqlx::Decode<'r, Postgres>>::decode(value)
            .map(|signed_id| Self::from(signed_id as u64))
    }
}