1#![doc = include_str!("../tests/readme_example.rs")]
13#![cfg_attr(
17 feature = "document-features",
18 doc = ::document_features::document_features!()
19)]
20#![cfg_attr(docsrs, feature(doc_cfg))]
21#![allow(clippy::result_large_err)]
22
23use diesel::query_dsl::methods::SelectDsl;
24use diesel::{RunQueryDsl, query_dsl::methods::LoadQuery};
25#[doc(inline)]
26pub use diesel_enums_proc_macro::*;
27use owo_colors::OwoColorize;
28use std::{
29 collections::HashMap,
30 fmt::{Debug, Display},
31 hash::Hash,
32};
33use thiserror::Error;
34
35mod test_runners;
36pub use test_runners::*;
37
38pub trait DbEnum: Debug + Hash + Eq + Copy + Sized {
45 #[doc(hidden)]
46 const VARIANT_MAPPINGS: &[(Self::IdType, &str)];
47 #[doc(hidden)]
48 const RUST_ENUM_NAME: &str;
49 #[doc(hidden)]
50 const TABLE_NAME: &str;
51
52 #[doc(hidden)]
54 type Table: diesel::Table + SelectDsl<(Self::IdColumn, Self::NameColumn)> + Default;
55 #[doc(hidden)]
56 type IdColumn: diesel::Column + Default;
57 #[doc(hidden)]
58 type NameColumn: diesel::Column + Default;
59 #[doc(hidden)]
60 type IdType: Copy + Into<i64> + PartialEq + 'static;
61
62 fn db_name(self) -> &'static str;
64 fn from_db_name(name: &str) -> Result<Self, UnknownVariantError>;
66
67 fn db_id(self) -> Self::IdType;
69 fn from_db_id(id: Self::IdType) -> Result<Self, UnknownIdError>;
71
72 #[cold]
76 #[inline(never)]
77 #[track_caller]
78 fn check_db_mapping<'query, Conn>(conn: &mut Conn) -> Result<(), DbEnumError>
79 where
80 <Self::Table as SelectDsl<(Self::IdColumn, Self::NameColumn)>>::Output:
82 LoadQuery<'query, Conn, (Self::IdType, String)>,
83 Conn: diesel::Connection,
84 {
85 let db_variants: Vec<(Self::IdType, String)> = Self::Table::default()
86 .select((Self::IdColumn::default(), Self::NameColumn::default()))
87 .load(conn)
88 .unwrap_or_else(|e| {
89 panic!(
90 "\n ❌ Failed to load the variants for the rust enum `{}` from the database table `{}`: {}",
91 Self::RUST_ENUM_NAME,
92 Self::TABLE_NAME,
93 e
94 )
95 });
96
97 let mut error =
98 DbEnumError::new(Self::RUST_ENUM_NAME, DbEnumSource::Table(Self::TABLE_NAME));
99
100 let mut rust_variants_set: HashMap<&str, Self::IdType> = Self::VARIANT_MAPPINGS
101 .iter()
102 .map(|(id, name)| (*name, *id))
103 .collect();
104
105 for (id, name) in db_variants {
106 let rust_variant_id = if let Some(id) = rust_variants_set.remove(name.as_str()) {
107 id
108 } else {
109 error.missing_from_rust.push(name);
110 continue;
111 };
112
113 if id != rust_variant_id {
114 error.id_mismatches.push(IdMismatch {
115 variant: name,
116 expected: id.into(),
117 found: rust_variant_id.into(),
118 });
119 }
120 }
121
122 error.missing_from_db.extend(
123 rust_variants_set
124 .into_keys()
125 .map(|v| v.to_string()),
126 );
127
128 if error.is_clean() { Ok(()) } else { Err(error) }
129 }
130}
131
132#[cfg(feature = "postgres")]
133use diesel::connection::LoadConnection;
/// A rust enum whose variants are mirrored by the labels of a postgres enum type.
///
/// NOTE(review): presumably implemented by this crate's derive macros — confirm.
#[cfg(feature = "postgres")]
pub trait PgEnum: Debug + Sized {
    /// The postgres label of every rust variant, compared against the type's labels by
    /// `check_db_mapping`.
    #[doc(hidden)]
    const VARIANT_MAPPINGS: &[&str];
    /// Name of the rust enum, used in error and panic messages.
    #[doc(hidden)]
    const RUST_ENUM_NAME: &str;
    /// Name of the postgres enum type. It is interpolated unquoted into the SQL cast
    /// `NULL::<name>`, so it must be a valid type reference exactly as written.
    #[doc(hidden)]
    const PG_ENUM_NAME: &str;

    /// The postgres label of this variant.
    fn db_name(self) -> &'static str;
    /// Looks up a variant by its postgres label.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownVariantError`] when no variant has the given label.
    fn from_db_name(name: &str) -> Result<Self, UnknownVariantError>;

    /// Loads every label of `PG_ENUM_NAME` (via `enum_range`) and compares them with
    /// `VARIANT_MAPPINGS`.
    ///
    /// `#[cold]` / `#[inline(never)]` keep this one-off consistency check out of hot code.
    ///
    /// # Errors
    ///
    /// Returns a [`DbEnumError`] listing labels with no rust variant
    /// (`missing_from_rust`) and rust variants with no label (`missing_from_db`, in
    /// `VARIANT_MAPPINGS` order).
    ///
    /// # Panics
    ///
    /// Panics, including the database error, if the labels cannot be loaded. Thanks to
    /// `#[track_caller]` the panic is reported at the caller's location.
    #[cold]
    #[inline(never)]
    #[track_caller]
    fn check_db_mapping<Conn>(conn: &mut Conn) -> Result<(), DbEnumError>
    where
        Conn: diesel::Connection<Backend = diesel::pg::Pg> + LoadConnection,
    {
        use std::collections::HashSet;

        let mut error = DbEnumError::new(
            Self::RUST_ENUM_NAME,
            DbEnumSource::CustomEnum(Self::PG_ENUM_NAME),
        );

        // Rust variant labels not yet matched by a postgres label.
        let mut variants_set: HashSet<&str> = Self::VARIANT_MAPPINGS.iter().copied().collect();

        let query = format!(
            "SELECT unnest(enum_range(NULL::{})) AS variant",
            Self::PG_ENUM_NAME
        );

        // Panic directly in this body: closures do not inherit `#[track_caller]`.
        let pg_variants: Vec<DeserializedPgEnum> = match diesel::sql_query(query).load(conn) {
            Ok(rows) => rows,
            Err(e) => panic!(
                "\n ❌ Failed to load the variants for the postgres enum `{}` (rust enum `{}`): {}",
                Self::PG_ENUM_NAME,
                Self::RUST_ENUM_NAME,
                e
            ),
        };

        for DeserializedPgEnum { variant } in pg_variants {
            if !variants_set.remove(variant.as_str()) {
                error.missing_from_rust.push(variant);
            }
        }

        // Whatever is still unmatched is missing from postgres. Walk the declaration list
        // rather than the hash set so the report order is deterministic.
        error.missing_from_db.extend(
            Self::VARIANT_MAPPINGS
                .iter()
                .filter(|&&name| variants_set.remove(name))
                .map(|&name| name.to_owned()),
        );

        if error.is_clean() { Ok(()) } else { Err(error) }
    }
}
209
#[cfg(feature = "postgres")]
/// One row of the `SELECT unnest(enum_range(NULL::<type>)) AS variant` query issued by
/// `PgEnum::check_db_mapping`.
#[derive(diesel::deserialize::QueryableByName)]
struct DeserializedPgEnum {
    /// A single label of the postgres enum, read as SQL `Text` from the `variant` column.
    #[diesel(sql_type = diesel::sql_types::Text)]
    variant: String,
}
216
/// Error returned by `from_db_name` when the given name matches no variant.
#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
#[error("No variant named `{variant}` exists for the enum `{enum_name}`")]
pub struct UnknownVariantError {
    /// Name of the rust enum that was being looked up.
    pub enum_name: &'static str,
    /// The name that matched no variant.
    pub variant: String,
}
224
/// Error returned by `DbEnum::from_db_id` when the given id matches no variant.
#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
#[error("The id `{id}` does not match any variant for the enum `{enum_name}`")]
pub struct UnknownIdError {
    /// Name of the rust enum that was being looked up.
    pub enum_name: &'static str,
    /// The id that matched no variant. NOTE(review): presumably converted from
    /// `DbEnum::IdType` via its `Into<i64>` bound — confirm in the macro output.
    pub id: i64,
}
232
/// A variant present both in rust and in the database, but with different ids.
///
/// `DbEnum::check_db_mapping` fills `expected` with the database id and `found` with
/// the id declared on the rust side.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct IdMismatch {
    // Kept private so it can only be read through `variant()`.
    variant: String,
    /// The id stored in the database.
    pub expected: i64,
    /// The id declared by the rust enum.
    pub found: i64,
}

impl IdMismatch {
    /// Name of the variant whose ids disagree.
    pub fn variant(&self) -> &str {
        self.variant.as_str()
    }
}
247
/// Report describing how a rust enum and its database counterpart (a table or a
/// postgres enum type) differ.
///
/// Returned by `DbEnum::check_db_mapping` and `PgEnum::check_db_mapping` when at least
/// one discrepancy was found. `Display` is implemented by hand to render a readable,
/// colorized report.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Error)]
pub struct DbEnumError {
    // Name of the rust enum that was checked.
    rust_enum_name: &'static str,
    // The database object (table or custom enum type) it was compared against.
    db_source: DbEnumSource,
    /// Variants declared in rust that have no counterpart in the database.
    pub missing_from_db: Vec<String>,
    /// Variants present in the database that have no counterpart in rust.
    pub missing_from_rust: Vec<String>,
    /// Variants present on both sides whose ids differ (only produced for tables).
    pub id_mismatches: Vec<IdMismatch>,
}
264
265impl DbEnumError {
266 pub(crate) fn new(rust_enum: &'static str, db_source: DbEnumSource) -> Self {
267 Self {
268 rust_enum_name: rust_enum,
269 db_source,
270 missing_from_db: vec![],
271 missing_from_rust: vec![],
272 id_mismatches: vec![],
273 }
274 }
275
276 pub(crate) fn is_clean(&self) -> bool {
277 self.missing_from_db.is_empty()
278 && self.missing_from_rust.is_empty()
279 && self.id_mismatches.is_empty()
280 }
281}
282
/// The database object a rust enum was checked against.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum DbEnumSource {
    /// A postgres enum type, identified by its type name (used by `PgEnum`).
    CustomEnum(&'static str),
    /// A lookup table, identified by its table name (used by `DbEnum`).
    Table(&'static str),
}

impl DbEnumSource {
    /// The name of the table or enum type.
    ///
    /// Returns the stored `&'static str` directly, so the result is not tied to the
    /// lifetime of `self`.
    pub(crate) fn name(&self) -> &'static str {
        match *self {
            Self::CustomEnum(name) | Self::Table(name) => name,
        }
    }

    /// A human-readable word for the kind of database object (`"enum"` or `"table"`).
    pub(crate) fn db_type(&self) -> &'static str {
        match self {
            Self::CustomEnum(_) => "enum",
            Self::Table(_) => "table",
        }
    }
}
304
/// Placeholder diesel table and column types.
///
/// NOTE(review): presumably referenced by the derive macros' generated code as stand-ins
/// (e.g. for `DbEnum::Table` / `IdColumn` / `NameColumn`) when the real types cannot be
/// produced, so compilation proceeds far enough to surface the macro's own error —
/// confirm against `diesel_enums_proc_macro`.
///
/// None of these methods are meant to run: every one of them is `unimplemented!()`.
#[doc(hidden)]
pub mod __macro_fallbacks {
    /// Stand-in table. Derives `Default`, matching the `Default` bound `DbEnum` places on
    /// its `Table` type.
    #[derive(Default)]
    pub struct DummyTable;
    /// Stand-in column belonging to [`DummyTable`]. Derives `Default`, matching the bound
    /// on `DbEnum::IdColumn` / `NameColumn`.
    #[derive(Default)]
    pub struct DummyColumn;

    // NOTE(review): `i64` is a rust type rather than a diesel SQL type; it appears to
    // exist only to fill the associated-type slot — confirm this is intentional.
    impl diesel::query_builder::Query for DummyTable {
        type SqlType = i64;
    }

    // Both the default selection and the FROM clause are the dummy column; neither is
    // ever built because the methods are unimplemented.
    impl diesel::QuerySource for DummyTable {
        type DefaultSelection = DummyColumn;
        type FromClause = DummyColumn;

        fn default_selection(&self) -> Self::DefaultSelection {
            unimplemented!()
        }

        fn from_clause(&self) -> Self::FromClause {
            unimplemented!()
        }
    }

    // The column name "error" presumably makes any accidental use easy to spot in SQL.
    impl diesel::Column for DummyColumn {
        type Table = DummyTable;

        const NAME: &'static str = "error";
    }

    impl diesel::SelectableExpression<DummyTable> for DummyColumn {}

    impl diesel::AppearsOnTable<DummyTable> for DummyColumn {}

    impl diesel::Expression for DummyColumn {
        type SqlType = diesel::sql_types::Integer;
    }

    // NOTE(review): declared as an aggregate; presumably any choice satisfying the trait
    // bound works here since the column is never queried — confirm.
    impl diesel::expression::ValidGrouping<()> for DummyColumn {
        type IsAggregate = diesel::expression::is_aggregate::Yes;
    }

    impl diesel::Table for DummyTable {
        type AllColumns = DummyColumn;
        type PrimaryKey = DummyColumn;

        fn primary_key(&self) -> Self::PrimaryKey {
            unimplemented!()
        }

        fn all_columns() -> Self::AllColumns {
            unimplemented!()
        }
    }
}
360
361impl Display for DbEnumError {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 let _ = writeln!(
364 f,
365 "\n ❌ The rust enum `{}` and the database {} `{}` are out of sync: ",
366 self.rust_enum_name.bright_yellow(),
367 self.db_source.db_type(),
368 self.db_source.name().bright_cyan()
369 );
370
371 if !self.missing_from_db.is_empty() {
372 let _ = writeln!(
373 f,
374 "\n - Variants missing from the {}:",
375 "database".bright_cyan()
376 );
377
378 for variant in &self.missing_from_db {
379 let _ = writeln!(f, " • {variant}");
380 }
381 }
382
383 if !self.missing_from_rust.is_empty() {
384 let _ = writeln!(
385 f,
386 "\n - Variants missing from the {}:",
387 "rust enum".bright_yellow()
388 );
389
390 for variant in &self.missing_from_rust {
391 writeln!(f, " • {variant}").unwrap();
392 }
393 }
394
395 if !self.id_mismatches.is_empty() {
396 for IdMismatch {
397 variant,
398 expected,
399 found,
400 } in &self.id_mismatches
401 {
402 let _ = writeln!(
403 f,
404 "\n - Wrong id mapping for `{}`",
405 variant.bright_yellow()
406 );
407 let _ = writeln!(f, " Expected: {}", expected.bright_green());
408 let _ = writeln!(f, " Found: {}", found.bright_red());
409 }
410 }
411
412 Ok(())
413 }
414}