prax_query/operations/
create.rs

1//! Create operation for inserting new records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::FilterValue;
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// A create operation for inserting a new record.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let user = client
16///     .user()
17///     .create(user::Create {
18///         email: "new@example.com".into(),
19///         name: Some("New User".into()),
20///     })
21///     .exec()
22///     .await?;
23/// ```
24pub struct CreateOperation<E: QueryEngine, M: Model> {
25    engine: E,
26    columns: Vec<String>,
27    values: Vec<FilterValue>,
28    select: Select,
29    _model: PhantomData<M>,
30}
31
32impl<E: QueryEngine, M: Model> CreateOperation<E, M> {
33    /// Create a new Create operation.
34    pub fn new(engine: E) -> Self {
35        Self {
36            engine,
37            columns: Vec::new(),
38            values: Vec::new(),
39            select: Select::All,
40            _model: PhantomData,
41        }
42    }
43
44    /// Set a column value.
45    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
46        self.columns.push(column.into());
47        self.values.push(value.into());
48        self
49    }
50
51    /// Set multiple column values from an iterator.
52    pub fn set_many(
53        mut self,
54        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
55    ) -> Self {
56        for (col, val) in values {
57            self.columns.push(col.into());
58            self.values.push(val.into());
59        }
60        self
61    }
62
63    /// Select specific fields to return.
64    pub fn select(mut self, select: impl Into<Select>) -> Self {
65        self.select = select.into();
66        self
67    }
68
69    /// Build the SQL query.
70    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
71        let mut sql = String::new();
72
73        // INSERT INTO clause
74        sql.push_str("INSERT INTO ");
75        sql.push_str(M::TABLE_NAME);
76
77        // Columns
78        sql.push_str(" (");
79        sql.push_str(&self.columns.join(", "));
80        sql.push(')');
81
82        // VALUES
83        sql.push_str(" VALUES (");
84        let placeholders: Vec<_> = (1..=self.values.len()).map(|i| format!("${}", i)).collect();
85        sql.push_str(&placeholders.join(", "));
86        sql.push(')');
87
88        // RETURNING clause
89        sql.push_str(" RETURNING ");
90        sql.push_str(&self.select.to_sql());
91
92        (sql, self.values.clone())
93    }
94
95    /// Execute the create operation and return the created record.
96    pub async fn exec(self) -> QueryResult<M>
97    where
98        M: Send + 'static,
99    {
100        let (sql, params) = self.build_sql();
101        self.engine.execute_insert::<M>(&sql, params).await
102    }
103}
104
105/// Create many records at once.
106pub struct CreateManyOperation<E: QueryEngine, M: Model> {
107    engine: E,
108    columns: Vec<String>,
109    rows: Vec<Vec<FilterValue>>,
110    skip_duplicates: bool,
111    _model: PhantomData<M>,
112}
113
114impl<E: QueryEngine, M: Model> CreateManyOperation<E, M> {
115    /// Create a new CreateMany operation.
116    pub fn new(engine: E) -> Self {
117        Self {
118            engine,
119            columns: Vec::new(),
120            rows: Vec::new(),
121            skip_duplicates: false,
122            _model: PhantomData,
123        }
124    }
125
126    /// Set the columns for insertion.
127    pub fn columns(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
128        self.columns = columns.into_iter().map(Into::into).collect();
129        self
130    }
131
132    /// Add a row of values.
133    pub fn row(mut self, values: impl IntoIterator<Item = impl Into<FilterValue>>) -> Self {
134        self.rows.push(values.into_iter().map(Into::into).collect());
135        self
136    }
137
138    /// Add multiple rows.
139    pub fn rows(
140        mut self,
141        rows: impl IntoIterator<Item = impl IntoIterator<Item = impl Into<FilterValue>>>,
142    ) -> Self {
143        for row in rows {
144            self.rows.push(row.into_iter().map(Into::into).collect());
145        }
146        self
147    }
148
149    /// Skip records that violate unique constraints.
150    pub fn skip_duplicates(mut self) -> Self {
151        self.skip_duplicates = true;
152        self
153    }
154
155    /// Build the SQL query.
156    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
157        let mut sql = String::new();
158        let mut all_params = Vec::new();
159
160        // INSERT INTO clause
161        sql.push_str("INSERT INTO ");
162        sql.push_str(M::TABLE_NAME);
163
164        // Columns
165        sql.push_str(" (");
166        sql.push_str(&self.columns.join(", "));
167        sql.push(')');
168
169        // VALUES
170        sql.push_str(" VALUES ");
171
172        let mut value_groups = Vec::new();
173        let mut param_idx = 1;
174
175        for row in &self.rows {
176            let placeholders: Vec<_> = row
177                .iter()
178                .map(|v| {
179                    all_params.push(v.clone());
180                    let placeholder = format!("${}", param_idx);
181                    param_idx += 1;
182                    placeholder
183                })
184                .collect();
185            value_groups.push(format!("({})", placeholders.join(", ")));
186        }
187
188        sql.push_str(&value_groups.join(", "));
189
190        // ON CONFLICT for skip_duplicates
191        if self.skip_duplicates {
192            sql.push_str(" ON CONFLICT DO NOTHING");
193        }
194
195        (sql, all_params)
196    }
197
198    /// Execute the create operation and return the number of created records.
199    pub async fn exec(self) -> QueryResult<u64> {
200        let (sql, params) = self.build_sql();
201        self.engine.execute_raw(&sql, params).await
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;

    /// Minimal model whose only purpose is to supply a table name.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub: never touches a database; `execute_raw` reports
    /// `insert_count` affected rows.
    #[derive(Clone)]
    struct MockEngine {
        insert_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self { insert_count: 0 }
        }

        fn with_count(count: u64) -> Self {
            Self {
                insert_count: count,
            }
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.insert_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    /// Fresh single-record operation on the default mock engine.
    fn create_op() -> CreateOperation<MockEngine, TestModel> {
        CreateOperation::new(MockEngine::new())
    }

    /// Fresh multi-record operation on the default mock engine.
    fn create_many_op() -> CreateManyOperation<MockEngine, TestModel> {
        CreateManyOperation::new(MockEngine::new())
    }

    // ========== CreateOperation Tests ==========

    #[test]
    fn test_create_new() {
        let (sql, params) = create_op().build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_basic() {
        let (sql, params) = create_op()
            .set("name", "Alice")
            .set("email", "alice@example.com")
            .build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_create_single_field() {
        let (sql, params) = create_op().set("name", "Alice").build_sql();

        assert!(sql.contains("(name)"));
        assert!(sql.contains("VALUES ($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_create_with_set_many() {
        let values = vec![
            ("name", FilterValue::String("Bob".to_string())),
            ("email", FilterValue::String("bob@test.com".to_string())),
            ("age", FilterValue::Int(25)),
        ];

        let (sql, params) = create_op().set_many(values).build_sql();

        assert!(sql.contains("(name, email, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_create_with_select() {
        let (sql, _) = create_op()
            .set("name", "Alice")
            .select(Select::fields(["id", "name"]))
            .build_sql();

        assert!(sql.contains("RETURNING id, name"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_create_with_null_value() {
        // Only the bound parameters matter here; the SQL text is not inspected.
        let (_, params) = create_op()
            .set("name", "Alice")
            .set("nickname", FilterValue::Null)
            .build_sql();

        assert_eq!(params.len(), 2);
        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_create_with_boolean_value() {
        let (_, params) = create_op()
            .set("active", FilterValue::Bool(true))
            .build_sql();

        assert_eq!(params[0], FilterValue::Bool(true));
    }

    #[test]
    fn test_create_with_numeric_values() {
        let (_, params) = create_op()
            .set("count", FilterValue::Int(42))
            .set("price", FilterValue::Float(99.99))
            .build_sql();

        assert_eq!(params[0], FilterValue::Int(42));
        assert_eq!(params[1], FilterValue::Float(99.99));
    }

    #[test]
    fn test_create_with_json_value() {
        let json = serde_json::json!({"key": "value", "nested": {"a": 1}});

        let (_, params) = create_op()
            .set("metadata", FilterValue::Json(json.clone()))
            .build_sql();

        assert_eq!(params[0], FilterValue::Json(json));
    }

    #[tokio::test]
    async fn test_create_exec() {
        let result = create_op().set("name", "Alice").exec().await;

        // MockEngine returns not_found error for execute_insert
        assert!(result.is_err());
    }

    // ========== CreateManyOperation Tests ==========

    #[test]
    fn test_create_many_new() {
        let (sql, params) = create_many_op().build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(!sql.contains("RETURNING")); // CreateMany doesn't return
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_many() {
        let (sql, params) = create_many_op()
            .columns(["name", "email"])
            .row(["Alice", "alice@example.com"])
            .row(["Bob", "bob@example.com"])
            .build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2), ($3, $4)"));
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_create_many_single_row() {
        let (sql, params) = create_many_op()
            .columns(["name"])
            .row(["Alice"])
            .build_sql();

        assert!(sql.contains("VALUES ($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_create_many_skip_duplicates() {
        let (sql, _) = create_many_op()
            .columns(["name", "email"])
            .row(["Alice", "alice@example.com"])
            .skip_duplicates()
            .build_sql();

        assert!(sql.contains("ON CONFLICT DO NOTHING"));
    }

    #[test]
    fn test_create_many_without_skip_duplicates() {
        let (sql, _) = create_many_op()
            .columns(["name"])
            .row(["Alice"])
            .build_sql();

        assert!(!sql.contains("ON CONFLICT"));
    }

    #[test]
    fn test_create_many_with_rows() {
        let rows = vec![
            vec!["Alice", "alice@test.com"],
            vec!["Bob", "bob@test.com"],
            vec!["Charlie", "charlie@test.com"],
        ];

        let (sql, params) = create_many_op()
            .columns(["name", "email"])
            .rows(rows)
            .build_sql();

        assert!(sql.contains("VALUES ($1, $2), ($3, $4), ($5, $6)"));
        assert_eq!(params.len(), 6);
    }

    #[test]
    fn test_create_many_param_ordering() {
        let (_, params) = create_many_op()
            .columns(["a", "b"])
            .row(["1", "2"])
            .row(["3", "4"])
            .build_sql();

        // Params should be ordered: row1.a, row1.b, row2.a, row2.b
        assert_eq!(params[0], FilterValue::String("1".to_string()));
        assert_eq!(params[1], FilterValue::String("2".to_string()));
        assert_eq!(params[2], FilterValue::String("3".to_string()));
        assert_eq!(params[3], FilterValue::String("4".to_string()));
    }

    #[tokio::test]
    async fn test_create_many_exec() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(3))
            .columns(["name"])
            .row(["Alice"])
            .row(["Bob"])
            .row(["Charlie"]);

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 3);
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_create_sql_structure() {
        let (sql, _) = create_op()
            .set("name", "Alice")
            .select(Select::fields(["id"]))
            .build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let columns_pos = sql.find("(name)").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let returning_pos = sql.find("RETURNING").unwrap();

        assert!(insert_pos < columns_pos);
        assert!(columns_pos < values_pos);
        assert!(values_pos < returning_pos);
    }

    #[test]
    fn test_create_many_sql_structure() {
        let (sql, _) = create_many_op()
            .columns(["name", "email"])
            .row(["Alice", "alice@test.com"])
            .skip_duplicates()
            .build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let columns_pos = sql.find("(name, email)").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let conflict_pos = sql.find("ON CONFLICT").unwrap();

        assert!(insert_pos < columns_pos);
        assert!(columns_pos < values_pos);
        assert!(values_pos < conflict_pos);
    }

    #[test]
    fn test_create_table_name() {
        let (sql, _) = create_op().build_sql();

        assert!(sql.contains("test_models"));
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_create_method_chaining() {
        let (sql, params) = create_op()
            .set("name", "Alice")
            .set("email", "alice@test.com")
            .select(Select::fields(["id", "name"]))
            .build_sql();

        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("RETURNING id, name"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_create_many_method_chaining() {
        let (sql, params) = create_many_op()
            .columns(["a", "b"])
            .row(["1", "2"])
            .row(["3", "4"])
            .skip_duplicates()
            .build_sql();

        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert_eq!(params.len(), 4);
    }
}