Skip to main content

rok_orm_factory/
lib.rs

1//! rok-orm-factory — model factories for tests and seeding.
2//!
3//! # Quick start
4//!
5//! Define a factory by implementing [`Factory`] on your model:
6//!
7//! ```rust,ignore
8//! use rok_orm_factory::{Factory, Faker};
9//!
10//! impl Factory for User {
11//!     fn definition() -> Self {
12//!         User {
13//!             id: 0,
14//!             name:  Faker::name(),
15//!             email: Faker::email(),
16//!             active: true,
17//!         }
18//!     }
19//! }
20//!
21//! // Make models in-memory (no DB)
22//! let user  = User::factory().make();
23//! let users = User::factory().count(5).make_many();
24//!
25//! // Override specific fields
26//! let admin = User::factory()
27//!     .with(|u| u.name = "Admin".to_string())
28//!     .make();
29//!
30//! // Persist to DB (requires `features = ["postgres"]` and a pool in scope)
31//! let user = User::factory().create().await?;
32//! let users = User::factory().count(3).create_many().await?;
33//! ```
34
35pub mod faker;
36
37pub use faker::Faker;
38
39// ── Factory trait ─────────────────────────────────────────────────────────────
40
41/// Implement this on a model to enable the factory DSL.
42///
43/// The `definition()` method returns a default instance with fake data.
44pub trait Factory: Sized + 'static {
45    /// Return a model instance filled with fake/default data.
46    fn definition() -> Self;
47
48    /// Access the [`FactoryBuilder`] for this model.
49    fn factory() -> FactoryBuilder<Self> {
50        FactoryBuilder::new()
51    }
52}
53
54// ── FactoryBuilder ────────────────────────────────────────────────────────────
55
/// Boxed closure that mutates a freshly built model; see [`FactoryBuilder::with`].
type Override<T> = Box<dyn Fn(&mut T)>;

/// Fluent builder for creating model instances.
///
/// Obtain one through [`Factory::factory`], configure it with
/// [`FactoryBuilder::count`] and [`FactoryBuilder::with`], then finish with
/// [`FactoryBuilder::make`] / [`FactoryBuilder::make_many`] (or the async
/// `create*` methods when the `postgres` feature is enabled).
// The `T: Factory` bound lives on the impl blocks rather than on the struct,
// so merely naming the type does not drag the bound along with it.
#[must_use = "a FactoryBuilder does nothing until `make`, `make_many` or a `create*` method is called"]
pub struct FactoryBuilder<T> {
    /// Number of instances produced by the batch methods.
    count: usize,
    /// Field overrides, applied in registration order to each generated model.
    overrides: Vec<Override<T>>,
}
62
63impl<T: Factory> FactoryBuilder<T> {
64    pub fn new() -> Self {
65        Self {
66            count: 1,
67            overrides: Vec::new(),
68        }
69    }
70
71    /// Set how many instances to create.
72    pub fn count(mut self, n: usize) -> Self {
73        self.count = n;
74        self
75    }
76
77    /// Apply a field override closure to each generated model.
78    pub fn with(mut self, f: impl Fn(&mut T) + 'static) -> Self {
79        self.overrides.push(Box::new(f));
80        self
81    }
82
83    fn build_one(&self) -> T {
84        let mut model = T::definition();
85        for ov in &self.overrides {
86            ov(&mut model);
87        }
88        model
89    }
90
91    /// Build one model in-memory using the factory definition.
92    pub fn make(self) -> T {
93        self.build_one()
94    }
95
96    /// Build `count` models in-memory.
97    pub fn make_many(self) -> Vec<T> {
98        let count = self.count;
99        (0..count).map(|_| self.build_one()).collect()
100    }
101}
102
103impl<T: Factory> Default for FactoryBuilder<T> {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109// ── Tests ─────────────────────────────────────────────────────────────────────
110
#[cfg(test)]
mod tests {
    use super::{Factory, Faker};

    #[derive(Debug, PartialEq, Clone)]
    struct Post {
        id: u64,
        title: String,
        body: String,
        published: bool,
    }

    impl Factory for Post {
        fn definition() -> Self {
            Post {
                id: 0,
                title: Faker::sentence(3),
                body: Faker::sentence(8),
                published: false,
            }
        }
    }

    #[test]
    fn make_returns_one_model() {
        let post = Post::factory().make();
        assert_eq!(post.id, 0);
        assert!(!post.title.is_empty());
    }

    #[test]
    fn make_many_returns_correct_count() {
        let posts = Post::factory().count(5).make_many();
        assert_eq!(posts.len(), 5);
    }

    #[test]
    fn make_many_with_zero_count_is_empty() {
        assert!(Post::factory().count(0).make_many().is_empty());
    }

    #[test]
    fn make_many_starts_each_model_from_a_fresh_definition() {
        // If overrides were applied cumulatively to a shared model, ids would be 1, 2, 3.
        let posts = Post::factory().count(3).with(|p| p.id += 1).make_many();
        assert!(posts.iter().all(|p| p.id == 1));
    }

    #[test]
    fn with_override_applies_to_each() {
        let posts = Post::factory()
            .count(3)
            .with(|p| p.published = true)
            .make_many();
        assert!(posts.iter().all(|p| p.published));
    }

    #[test]
    fn make_with_override_changes_field() {
        let post = Post::factory()
            .with(|p| p.title = "Custom Title".to_string())
            .make();
        assert_eq!(post.title, "Custom Title");
    }

    #[test]
    fn multiple_overrides_are_applied_in_order() {
        let post = Post::factory()
            .with(|p| p.id = 99)
            .with(|p| p.published = true)
            .make();
        assert_eq!(post.id, 99);
        assert!(post.published);

        // Overrides on the same field: the later one must observe the earlier
        // one's result and have the final say.
        let post = Post::factory()
            .with(|p| p.title = "first".to_string())
            .with(|p| p.title.push_str("-second"))
            .make();
        assert_eq!(post.title, "first-second");
    }

    #[test]
    fn default_builder_count_is_one() {
        use super::FactoryBuilder;
        let builder = FactoryBuilder::<Post>::new();
        let many = builder.make_many();
        assert_eq!(many.len(), 1);
    }

    // ── Faker ─────────────────────────────────────────────────────────────────

    #[test]
    fn faker_name_not_empty() {
        assert!(!Faker::name().is_empty());
    }

    #[test]
    fn faker_email_contains_at() {
        assert!(Faker::email().contains('@'));
    }

    #[test]
    fn faker_uuid_is_36_chars() {
        assert_eq!(Faker::uuid().len(), 36);
    }

    #[test]
    fn faker_integer_in_range() {
        for _ in 0..20 {
            let n = Faker::integer(10, 20);
            assert!((10..=20).contains(&n));
        }
    }

    #[test]
    fn faker_sentence_ends_with_period() {
        let s = Faker::sentence(4);
        assert!(s.ends_with('.'));
    }

    #[test]
    fn faker_phone_starts_with_plus1() {
        assert!(Faker::phone().starts_with("+1-"));
    }

    #[test]
    fn faker_password_starts_with_pass() {
        assert!(Faker::password().starts_with("pass-"));
    }
}
223
224// ── Helper: extract SqlValue from a serde_json::Value ──────────────────────
225
/// Convert one scalar `serde_json::Value` into the ORM's `SqlValue`.
///
/// Numbers become `Integer` when they fit in an `i64`, otherwise `Float`
/// (or `Null` if neither conversion succeeds). Strings and booleans map
/// directly. Everything else maps to `Null`.
#[cfg(feature = "postgres")]
fn json_to_sqlvalue(val: &serde_json::Value) -> rok_orm_core::SqlValue {
    use rok_orm_core::SqlValue;
    use serde_json::Value;

    match val {
        Value::Bool(flag) => SqlValue::Bool(*flag),
        Value::String(text) => SqlValue::Text(text.clone()),
        Value::Number(num) => match num.as_i64() {
            Some(int) => SqlValue::Integer(int),
            None => num.as_f64().map_or(SqlValue::Null, SqlValue::Float),
        },
        // NOTE(review): arrays and objects (e.g. Vec or nested-struct fields)
        // are silently stored as NULL here — confirm that is intended.
        Value::Null | Value::Array(_) | Value::Object(_) => SqlValue::Null,
    }
}
239
/// Serialize `model` to column/value pairs and insert it, returning the stored row.
///
/// Each column named by `T::columns()` is looked up in the model's
/// `serde_json` representation; a missing key becomes `NULL`. Each value is
/// converted with `json_to_sqlvalue`, and the pairs are passed to
/// `QueryBuilder::insert_sql`. The generated statement is suffixed with
/// `RETURNING *`, and the returned row is decoded back into a `T`.
///
/// # Errors
///
/// Returns [`sqlx::Error::Encode`] if `model` cannot be serialized to JSON.
/// Otherwise, returns any error produced while executing the query or
/// decoding the returned row.
#[cfg(feature = "postgres")]
async fn persist_model<T>(model: &T, pool: &sqlx::PgPool) -> Result<T, sqlx::Error>
where
    T: serde::Serialize
        + rok_orm_core::Model
        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
        + Send
        + Unpin,
{
    // A serde failure is an encoding problem, not a wire-protocol one. Boxing
    // the original error also keeps it reachable through `Error::source`.
    let json = serde_json::to_value(model).map_err(|e| sqlx::Error::Encode(Box::new(e)))?;

    // NOTE(review): every column from `T::columns()` is inserted, including any
    // primary key. Factory definitions that use a placeholder id (e.g. `0`)
    // will insert that value verbatim. Confirm `columns()` excludes generated
    // keys; otherwise repeated creates may collide.
    let cols = T::columns();
    let mut data: Vec<(&str, rok_orm_core::SqlValue)> = Vec::with_capacity(cols.len());
    for col in cols {
        let val = json_to_sqlvalue(json.get(col).unwrap_or(&serde_json::Value::Null));
        data.push((col, val));
    }

    let (sql, params) = rok_orm_core::QueryBuilder::<T>::insert_sql(T::table_name(), &data);
    let sql = format!("{sql} RETURNING *");

    let mut query = sqlx::query_as::<_, T>(&sql);
    for param in params {
        query = match param {
            rok_orm_core::SqlValue::Text(s) => query.bind(s),
            rok_orm_core::SqlValue::Integer(n) => query.bind(n),
            rok_orm_core::SqlValue::Float(f) => query.bind(f),
            rok_orm_core::SqlValue::Bool(b) => query.bind(b),
            // NULL, and any variant `json_to_sqlvalue` never produces, is bound
            // as a TEXT-typed NULL. NOTE(review): Postgres may reject a
            // TEXT-typed NULL for non-text columns — confirm against the schema.
            rok_orm_core::SqlValue::Null => query.bind(Option::<String>::None),
            _ => query.bind(Option::<String>::None),
        };
    }
    query.fetch_one(pool).await
}
276
277// ── Async create (postgres feature) ──────────────────────────────────────────
278
/// Build the error returned when no task-local pool is available.
#[cfg(feature = "postgres")]
fn no_pool_in_scope() -> sqlx::Error {
    sqlx::Error::Configuration(
        "no database pool in scope — use pool::with_pool() or OrmLayer"
            .to_string()
            .into(),
    )
}

#[cfg(feature = "postgres")]
impl<T> FactoryBuilder<T>
where
    T: Factory
        + serde::Serialize
        + rok_orm_core::Model
        + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow>
        + Send
        + Sync
        + Unpin
        + 'static,
{
    /// Persist one model to the database using the task-local pool.
    ///
    /// Requires a pool in scope via [`rok_orm::pool::with_pool`] or
    /// [`rok_orm::OrmLayer`].
    ///
    /// # Errors
    ///
    /// Returns [`sqlx::Error::Configuration`] when no pool is in scope.
    /// Otherwise, returns any error from serializing or inserting the model.
    pub async fn create(self) -> Result<T, sqlx::Error> {
        let pool = rok_orm::pool::try_current_pool().ok_or_else(no_pool_in_scope)?;
        self.create_with_pool(&pool).await
    }

    /// Persist one model to the database using the given pool.
    ///
    /// # Errors
    ///
    /// Returns any error from serializing or inserting the model.
    pub async fn create_with_pool(&self, pool: &sqlx::PgPool) -> Result<T, sqlx::Error> {
        let model = self.build_one();
        persist_model::<T>(&model, pool).await
    }

    /// Persist `count` models to the database using the task-local pool.
    ///
    /// Requires a pool in scope, like [`create`](Self::create).
    ///
    /// # Errors
    ///
    /// Returns [`sqlx::Error::Configuration`] when no pool is in scope. Also
    /// returns the first insert error; see
    /// [`create_many_with_pool`](Self::create_many_with_pool).
    pub async fn create_many(self) -> Result<Vec<T>, sqlx::Error> {
        let pool = rok_orm::pool::try_current_pool().ok_or_else(no_pool_in_scope)?;
        self.create_many_with_pool(&pool).await
    }

    /// Persist `count` models to the database using the given pool.
    ///
    /// Models are built and inserted one at a time. Each starts from a fresh
    /// definition with all overrides applied.
    ///
    /// # Errors
    ///
    /// Stops at the first failing insert and returns its error. Rows inserted
    /// before the failure are not rolled back by this method.
    pub async fn create_many_with_pool(
        &self,
        pool: &sqlx::PgPool,
    ) -> Result<Vec<T>, sqlx::Error> {
        let mut created = Vec::with_capacity(self.count);
        for _ in 0..self.count {
            created.push(self.create_with_pool(pool).await?);
        }
        Ok(created)
    }
}