Skip to main content

diesel_enums/
lib.rs

1//! # Diesel-enums
2//!
3//! `diesel-enums` can be used to create mappings between Rust enums and database tables with fixed values, as well as custom postgres enums.
4//!
5//! It creates a seamless interface with the `diesel` API, and generates the logic to enforce the correctness of the mapping.
6//!
//! Refer to the documentation for [`DbEnum`](macro@DbEnum) and [`PgEnum`](macro@PgEnum) to learn more about usage.
8//!
9//! # Full Example With Sqlite
10//!
11//!```rust
12#![doc = include_str!("../tests/readme_example.rs")]
13//!```
14//!
15//! # Features
16#![cfg_attr(
17		feature = "document-features",
18		doc = ::document_features::document_features!()
19)]
20#![cfg_attr(docsrs, feature(doc_cfg))]
21#![allow(clippy::result_large_err)]
22
23use diesel::query_dsl::methods::SelectDsl;
24use diesel::{RunQueryDsl, query_dsl::methods::LoadQuery};
25#[doc(inline)]
26pub use diesel_enums_proc_macro::*;
27use owo_colors::OwoColorize;
28use std::{
29	collections::HashMap,
30	fmt::{Debug, Display},
31	hash::Hash,
32};
33use thiserror::Error;
34
35mod test_runners;
36pub use test_runners::*;
37
38/// Maps a Rust enum to a database table with fixed values.
39///
40/// When derived with the corresponding macro, it implements [`FromSqlRow`](diesel::FromSqlRow), [`AsExpression`](diesel::AsExpression), [`HasTable`](diesel::associations::HasTable) and [`Identifiable`](diesel::associations::Identifiable),
41/// as well as [`FromSql`](diesel::deserialize::FromSql) and [`ToSql`](diesel::serialize::ToSql) with the target SQL type.
42///
43/// It implements the [`check_db_mapping`](DbEnum::check_db_mapping) method, which can be used to verify that the mapping with the database is valid. By default, the macro automatically generates a test that calls this method and panics on error.
44pub trait DbEnum: Debug + Hash + Eq + Copy + Sized {
45	#[doc(hidden)]
46	const VARIANT_MAPPINGS: &[(Self::IdType, &str)];
47	#[doc(hidden)]
48	const RUST_ENUM_NAME: &str;
49	#[doc(hidden)]
50	const TABLE_NAME: &str;
51
52	// Constrain that we are allowed to select these two columns from this table
53	#[doc(hidden)]
54	type Table: diesel::Table + SelectDsl<(Self::IdColumn, Self::NameColumn)> + Default;
55	#[doc(hidden)]
56	type IdColumn: diesel::Column + Default;
57	#[doc(hidden)]
58	type NameColumn: diesel::Column + Default;
59	#[doc(hidden)]
60	type IdType: Copy + Into<i64> + PartialEq + 'static;
61
62	/// Returns the database name for this variant.
63	fn db_name(self) -> &'static str;
64	/// Attempts to create an instance from a string, if this matches one of the variants in the database.
65	fn from_db_name(name: &str) -> Result<Self, UnknownVariantError>;
66
67	/// Returns the database ID for this variant.
68	fn db_id(self) -> Self::IdType;
69	/// Attempts to create an instance from a number, if this matches the ID of one of the variants in the database.
70	fn from_db_id(id: Self::IdType) -> Result<Self, UnknownIdError>;
71
72	/// Verifies that this enum is mapped correctly to a database table.
73	///
74	/// By default, the derive macro automatically generates a test that calls this method and panics on error.
75	#[cold]
76	#[inline(never)]
77	#[track_caller]
78	fn check_db_mapping<'query, Conn>(conn: &mut Conn) -> Result<(), DbEnumError>
79	where
80		// Constrain that the output of that select query can be loaded into a tuple of (Self::IdType, String)
81		<Self::Table as SelectDsl<(Self::IdColumn, Self::NameColumn)>>::Output:
82			LoadQuery<'query, Conn, (Self::IdType, String)>,
83		Conn: diesel::Connection,
84	{
85		let db_variants: Vec<(Self::IdType, String)> = Self::Table::default()
86			.select((Self::IdColumn::default(), Self::NameColumn::default()))
87			.load(conn)
88			.unwrap_or_else(|e| {
89				panic!(
90					"\n ❌ Failed to load the variants for the rust enum `{}` from the database table `{}`: {}",
91					Self::RUST_ENUM_NAME,
92					Self::TABLE_NAME,
93					e
94				)
95			});
96
97		let mut error =
98			DbEnumError::new(Self::RUST_ENUM_NAME, DbEnumSource::Table(Self::TABLE_NAME));
99
100		let mut rust_variants_set: HashMap<&str, Self::IdType> = Self::VARIANT_MAPPINGS
101			.iter()
102			.map(|(id, name)| (*name, *id))
103			.collect();
104
105		for (id, name) in db_variants {
106			let rust_variant_id = if let Some(id) = rust_variants_set.remove(name.as_str()) {
107				id
108			} else {
109				error.missing_from_rust.push(name);
110				continue;
111			};
112
113			if id != rust_variant_id {
114				error.id_mismatches.push(IdMismatch {
115					variant: name,
116					expected: id.into(),
117					found: rust_variant_id.into(),
118				});
119			}
120		}
121
122		error.missing_from_db.extend(
123			rust_variants_set
124				.into_keys()
125				.map(|v| v.to_string()),
126		);
127
128		if error.is_clean() { Ok(()) } else { Err(error) }
129	}
130}
131
132#[cfg(feature = "postgres")]
133use diesel::connection::LoadConnection;
#[cfg(feature = "postgres")]
/// Maps a Rust enum to a custom enum in postgres.
///
/// When derived with the corresponding macro, it implements [`FromSqlRow`](diesel::FromSqlRow) and [`AsExpression`](diesel::AsExpression),
/// as well as [`FromSql`](diesel::deserialize::FromSql) and [`ToSql`](diesel::serialize::ToSql) with the target SQL type.
///
/// It implements the [`check_db_mapping`](PgEnum::check_db_mapping) method, which can be used to verify that the mapping with the database is valid. By default, the macro automatically generates a test that calls this method and panics on error.
///
/// **NOTE**: It is necessary to add the following to the `diesel.toml` configuration:
///
/// ```toml
/// custom_type_derives = ["diesel::query_builder::QueryId"]
/// ```
pub trait PgEnum: Debug + Sized {
	/// Database name of every Rust variant (presumably generated by the derive macro).
	#[doc(hidden)]
	const VARIANT_MAPPINGS: &[&str];
	/// Name of the Rust enum, used in error messages.
	#[doc(hidden)]
	const RUST_ENUM_NAME: &str;
	/// Name of the postgres enum type; interpolated into the variant-listing query.
	#[doc(hidden)]
	const PG_ENUM_NAME: &str;

	/// Returns the database name for this variant.
	fn db_name(self) -> &'static str;
	/// Attempts to create an instance from a string, if this matches one of the variants in the database.
	fn from_db_name(name: &str) -> Result<Self, UnknownVariantError>;

	/// Verifies that this enum is mapped correctly to the database enum.
	///
	/// By default, the derive macro automatically generates a test that calls this method and panics on error.
	///
	/// # Errors
	///
	/// Returns a [`DbEnumError`] listing every postgres label with no Rust variant and every
	/// Rust variant with no postgres label. Variants missing from the database are listed in
	/// `VARIANT_MAPPINGS` order, so the error is identical from run to run.
	///
	/// # Panics
	///
	/// Panics (including the underlying diesel error) if the enum's labels cannot be loaded.
	#[cold]
	#[inline(never)]
	#[track_caller]
	fn check_db_mapping<Conn>(conn: &mut Conn) -> Result<(), DbEnumError>
	where
		Conn: diesel::Connection<Backend = diesel::pg::Pg> + LoadConnection,
	{
		use diesel::RunQueryDsl;
		use std::collections::HashSet;

		let mut error = DbEnumError::new(
			Self::RUST_ENUM_NAME,
			DbEnumSource::CustomEnum(Self::PG_ENUM_NAME),
		);

		// PG_ENUM_NAME is a compile-time constant supplied by the derive, not user input,
		// so interpolating it into the SQL text is acceptable here.
		let pg_variants: Vec<DeserializedPgEnum> = diesel::sql_query(format!(
			"SELECT unnest(enum_range(NULL::{})) AS variant",
			Self::PG_ENUM_NAME
		))
		.load(conn)
		.unwrap_or_else(|e| {
			panic!(
				"\n ❌ Failed to load the variants for the postgres enum `{}`: {}",
				Self::PG_ENUM_NAME,
				e
			)
		});

		// Rust variants that have not (yet) been matched with a postgres label.
		let mut unmatched: HashSet<&str> = Self::VARIANT_MAPPINGS.iter().copied().collect();

		for DeserializedPgEnum { variant } in pg_variants {
			if !unmatched.remove(variant.as_str()) {
				error.missing_from_rust.push(variant);
			}
		}

		// Walk the declaration-ordered names instead of draining the HashSet, whose
		// iteration order is random; `remove` also guarantees each name is reported once.
		error.missing_from_db.extend(
			Self::VARIANT_MAPPINGS
				.iter()
				.copied()
				.filter(|&name| unmatched.remove(name))
				.map(str::to_owned),
		);

		if error.is_clean() { Ok(()) } else { Err(error) }
	}
}
209
#[cfg(feature = "postgres")]
/// Row type for the label-listing query issued by [`PgEnum::check_db_mapping`].
///
/// The field must keep the name `variant`: it is matched against the `AS variant`
/// column alias of that query by `QueryableByName`.
#[derive(diesel::deserialize::QueryableByName)]
struct DeserializedPgEnum {
	/// A single label of the postgres enum, read as text.
	#[diesel(sql_type = diesel::sql_types::Text)]
	variant: String,
}
216
217/// An error that can occur when trying to create an instance of a [`DbEnum`] or [`PgEnum`] from a string.
218#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
219#[error("No variant named `{variant}` exists for the enum `{enum_name}`")]
220pub struct UnknownVariantError {
221	pub enum_name: &'static str,
222	pub variant: String,
223}
224
225/// An error that can occur when trying to create an instance of a [`DbEnum`] from a number.
226#[derive(Error, Debug, Clone, PartialEq, Eq, Hash)]
227#[error("The id `{id}` does not match any variant for the enum `{enum_name}`")]
228pub struct UnknownIdError {
229	pub enum_name: &'static str,
230	pub id: i64,
231}
232
/// A variant that exists on both sides but whose ID differs between Rust and the database.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct IdMismatch {
	// Kept private so it is read-only from outside; exposed through `variant()`.
	variant: String,
	/// The ID stored in the database for this variant.
	pub expected: i64,
	/// The ID the Rust enum assigns to this variant.
	pub found: i64,
}

impl IdMismatch {
	/// Returns the variant's name for this mismatch.
	pub fn variant(&self) -> &str {
		self.variant.as_str()
	}
}
247
248/// An error that is produced when a rust enum does not match a database enum or table.
249///
250/// It includes the list of errors that may occur simultaneously, such as id mismatches as well as missing variants.
251#[derive(Clone, Debug, PartialEq, Eq, Hash, Error)]
252pub struct DbEnumError {
253	rust_enum_name: &'static str,
254	db_source: DbEnumSource,
255	/// The list of variants that are missing from the database.
256	pub missing_from_db: Vec<String>,
257	/// The list of variants that are missing from the Rust enum.
258	pub missing_from_rust: Vec<String>,
259	/// The list of ID mismatches between the Rust and database enum.
260	///
261	/// This is always empty for [`PgEnum`]s since they do not have IDs.
262	pub id_mismatches: Vec<IdMismatch>,
263}
264
265impl DbEnumError {
266	pub(crate) fn new(rust_enum: &'static str, db_source: DbEnumSource) -> Self {
267		Self {
268			rust_enum_name: rust_enum,
269			db_source,
270			missing_from_db: vec![],
271			missing_from_rust: vec![],
272			id_mismatches: vec![],
273		}
274	}
275
276	pub(crate) fn is_clean(&self) -> bool {
277		self.missing_from_db.is_empty()
278			&& self.missing_from_rust.is_empty()
279			&& self.id_mismatches.is_empty()
280	}
281}
282
/// The database object a Rust enum was checked against.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum DbEnumSource {
	/// A custom postgres enum type, identified by name.
	CustomEnum(&'static str),
	/// A database table, identified by name.
	Table(&'static str),
}

impl DbEnumSource {
	/// Returns the name of the enum type or table.
	///
	/// The name is `'static`, so it may outlive `self`.
	pub(crate) fn name(&self) -> &'static str {
		match *self {
			Self::CustomEnum(name) | Self::Table(name) => name,
		}
	}

	/// Returns the kind of database object, as used in error messages: `"enum"` or `"table"`.
	pub(crate) fn db_type(&self) -> &'static str {
		match self {
			Self::CustomEnum(_) => "enum",
			Self::Table(_) => "table",
		}
	}
}
304
305#[doc(hidden)]
306pub mod __macro_fallbacks {
307	#[derive(Default)]
308	pub struct DummyTable;
309	#[derive(Default)]
310	pub struct DummyColumn;
311
312	impl diesel::query_builder::Query for DummyTable {
313		type SqlType = i64;
314	}
315
316	impl diesel::QuerySource for DummyTable {
317		type DefaultSelection = DummyColumn;
318		type FromClause = DummyColumn;
319
320		fn default_selection(&self) -> Self::DefaultSelection {
321			unimplemented!()
322		}
323
324		fn from_clause(&self) -> Self::FromClause {
325			unimplemented!()
326		}
327	}
328
329	impl diesel::Column for DummyColumn {
330		type Table = DummyTable;
331
332		const NAME: &'static str = "error";
333	}
334
335	impl diesel::SelectableExpression<DummyTable> for DummyColumn {}
336
337	impl diesel::AppearsOnTable<DummyTable> for DummyColumn {}
338
339	impl diesel::Expression for DummyColumn {
340		type SqlType = diesel::sql_types::Integer;
341	}
342
343	impl diesel::expression::ValidGrouping<()> for DummyColumn {
344		type IsAggregate = diesel::expression::is_aggregate::Yes;
345	}
346
347	impl diesel::Table for DummyTable {
348		type AllColumns = DummyColumn;
349		type PrimaryKey = DummyColumn;
350
351		fn primary_key(&self) -> Self::PrimaryKey {
352			unimplemented!()
353		}
354
355		fn all_columns() -> Self::AllColumns {
356			unimplemented!()
357		}
358	}
359}
360
361impl Display for DbEnumError {
362	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363		let _ = writeln!(
364			f,
365			"\n ❌ The rust enum `{}` and the database {} `{}` are out of sync: ",
366			self.rust_enum_name.bright_yellow(),
367			self.db_source.db_type(),
368			self.db_source.name().bright_cyan()
369		);
370
371		if !self.missing_from_db.is_empty() {
372			let _ = writeln!(
373				f,
374				"\n  - Variants missing from the {}:",
375				"database".bright_cyan()
376			);
377
378			for variant in &self.missing_from_db {
379				let _ = writeln!(f, "    • {variant}");
380			}
381		}
382
383		if !self.missing_from_rust.is_empty() {
384			let _ = writeln!(
385				f,
386				"\n  - Variants missing from the {}:",
387				"rust enum".bright_yellow()
388			);
389
390			for variant in &self.missing_from_rust {
391				writeln!(f, "    • {variant}").unwrap();
392			}
393		}
394
395		if !self.id_mismatches.is_empty() {
396			for IdMismatch {
397				variant,
398				expected,
399				found,
400			} in &self.id_mismatches
401			{
402				let _ = writeln!(
403					f,
404					"\n  - Wrong id mapping for `{}`",
405					variant.bright_yellow()
406				);
407				let _ = writeln!(f, "    Expected: {}", expected.bright_green());
408				let _ = writeln!(f, "    Found: {}", found.bright_red());
409			}
410		}
411
412		Ok(())
413	}
414}