prax_query/operations/create.rs

1//! Create operation for inserting new records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::FilterValue;
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// A create operation for inserting a new record.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let user = client
16///     .user()
17///     .create(user::Create {
18///         email: "new@example.com".into(),
19///         name: Some("New User".into()),
20///     })
21///     .exec()
22///     .await?;
23/// ```
24pub struct CreateOperation<E: QueryEngine, M: Model> {
25    engine: E,
26    columns: Vec<String>,
27    values: Vec<FilterValue>,
28    select: Select,
29    _model: PhantomData<M>,
30}
31
32impl<E: QueryEngine, M: Model> CreateOperation<E, M> {
33    /// Create a new Create operation.
34    pub fn new(engine: E) -> Self {
35        Self {
36            engine,
37            columns: Vec::new(),
38            values: Vec::new(),
39            select: Select::All,
40            _model: PhantomData,
41        }
42    }
43
44    /// Set a column value.
45    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
46        self.columns.push(column.into());
47        self.values.push(value.into());
48        self
49    }
50
51    /// Set multiple column values from an iterator.
52    pub fn set_many(
53        mut self,
54        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
55    ) -> Self {
56        for (col, val) in values {
57            self.columns.push(col.into());
58            self.values.push(val.into());
59        }
60        self
61    }
62
63    /// Select specific fields to return.
64    pub fn select(mut self, select: impl Into<Select>) -> Self {
65        self.select = select.into();
66        self
67    }
68
69    /// Build the SQL query.
70    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
71        let mut sql = String::new();
72
73        // INSERT INTO clause
74        sql.push_str("INSERT INTO ");
75        sql.push_str(M::TABLE_NAME);
76
77        // Columns
78        sql.push_str(" (");
79        sql.push_str(&self.columns.join(", "));
80        sql.push(')');
81
82        // VALUES
83        sql.push_str(" VALUES (");
84        let placeholders: Vec<_> = (1..=self.values.len())
85            .map(|i| format!("${}", i))
86            .collect();
87        sql.push_str(&placeholders.join(", "));
88        sql.push(')');
89
90        // RETURNING clause
91        sql.push_str(" RETURNING ");
92        sql.push_str(&self.select.to_sql());
93
94        (sql, self.values.clone())
95    }
96
97    /// Execute the create operation and return the created record.
98    pub async fn exec(self) -> QueryResult<M>
99    where
100        M: Send + 'static,
101    {
102        let (sql, params) = self.build_sql();
103        self.engine.execute_insert::<M>(&sql, params).await
104    }
105}
106
107/// Create many records at once.
108pub struct CreateManyOperation<E: QueryEngine, M: Model> {
109    engine: E,
110    columns: Vec<String>,
111    rows: Vec<Vec<FilterValue>>,
112    skip_duplicates: bool,
113    _model: PhantomData<M>,
114}
115
116impl<E: QueryEngine, M: Model> CreateManyOperation<E, M> {
117    /// Create a new CreateMany operation.
118    pub fn new(engine: E) -> Self {
119        Self {
120            engine,
121            columns: Vec::new(),
122            rows: Vec::new(),
123            skip_duplicates: false,
124            _model: PhantomData,
125        }
126    }
127
128    /// Set the columns for insertion.
129    pub fn columns(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
130        self.columns = columns.into_iter().map(Into::into).collect();
131        self
132    }
133
134    /// Add a row of values.
135    pub fn row(mut self, values: impl IntoIterator<Item = impl Into<FilterValue>>) -> Self {
136        self.rows.push(values.into_iter().map(Into::into).collect());
137        self
138    }
139
140    /// Add multiple rows.
141    pub fn rows(
142        mut self,
143        rows: impl IntoIterator<Item = impl IntoIterator<Item = impl Into<FilterValue>>>,
144    ) -> Self {
145        for row in rows {
146            self.rows.push(row.into_iter().map(Into::into).collect());
147        }
148        self
149    }
150
151    /// Skip records that violate unique constraints.
152    pub fn skip_duplicates(mut self) -> Self {
153        self.skip_duplicates = true;
154        self
155    }
156
157    /// Build the SQL query.
158    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
159        let mut sql = String::new();
160        let mut all_params = Vec::new();
161
162        // INSERT INTO clause
163        sql.push_str("INSERT INTO ");
164        sql.push_str(M::TABLE_NAME);
165
166        // Columns
167        sql.push_str(" (");
168        sql.push_str(&self.columns.join(", "));
169        sql.push(')');
170
171        // VALUES
172        sql.push_str(" VALUES ");
173
174        let mut value_groups = Vec::new();
175        let mut param_idx = 1;
176
177        for row in &self.rows {
178            let placeholders: Vec<_> = row
179                .iter()
180                .map(|v| {
181                    all_params.push(v.clone());
182                    let placeholder = format!("${}", param_idx);
183                    param_idx += 1;
184                    placeholder
185                })
186                .collect();
187            value_groups.push(format!("({})", placeholders.join(", ")));
188        }
189
190        sql.push_str(&value_groups.join(", "));
191
192        // ON CONFLICT for skip_duplicates
193        if self.skip_duplicates {
194            sql.push_str(" ON CONFLICT DO NOTHING");
195        }
196
197        (sql, all_params)
198    }
199
200    /// Execute the create operation and return the number of created records.
201    pub async fn exec(self) -> QueryResult<u64> {
202        let (sql, params) = self.build_sql();
203        self.engine.execute_raw(&sql, params).await
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;

    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Query engine stub: `execute_raw` returns a preset row count, while
    /// `execute_insert` and `query_one` always fail with a not-found error.
    #[derive(Clone)]
    struct MockEngine {
        insert_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self { insert_count: 0 }
        }

        fn with_count(count: u64) -> Self {
            Self { insert_count: count }
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.insert_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // ========== CreateOperation Tests ==========

    #[test]
    fn test_create_new() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_basic() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .set("email", "alice@example.com");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_create_single_field() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(name)"));
        assert!(sql.contains("VALUES ($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_create_with_set_many() {
        let values = vec![
            ("name", FilterValue::String("Bob".to_string())),
            ("email", FilterValue::String("bob@test.com".to_string())),
            ("age", FilterValue::Int(25)),
        ];
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set_many(values);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(name, email, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_create_with_select() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, name"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_create_with_null_value() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .set("nickname", FilterValue::Null);

        // Only the parameters matter here; ignoring the SQL avoids an
        // `unused_variables` warning.
        let (_, params) = op.build_sql();

        assert_eq!(params.len(), 2);
        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_create_with_boolean_value() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true));

        let (_, params) = op.build_sql();

        assert_eq!(params[0], FilterValue::Bool(true));
    }

    #[test]
    fn test_create_with_numeric_values() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("count", FilterValue::Int(42))
            .set("price", FilterValue::Float(99.99));

        let (_, params) = op.build_sql();

        assert_eq!(params[0], FilterValue::Int(42));
        assert_eq!(params[1], FilterValue::Float(99.99));
    }

    #[test]
    fn test_create_with_json_value() {
        let json = serde_json::json!({"key": "value", "nested": {"a": 1}});
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json.clone()));

        let (_, params) = op.build_sql();

        assert_eq!(params[0], FilterValue::Json(json));
    }

    #[tokio::test]
    async fn test_create_exec() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice");

        let result = op.exec().await;

        // MockEngine returns not_found error for execute_insert
        assert!(result.is_err());
    }

    // ========== CreateManyOperation Tests ==========

    #[test]
    fn test_create_many_new() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(!sql.contains("RETURNING")); // CreateMany doesn't return
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_many() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .row(["Alice", "alice@example.com"])
            .row(["Bob", "bob@example.com"]);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2), ($3, $4)"));
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_create_many_single_row() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name"])
            .row(["Alice"]);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("VALUES ($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_create_many_skip_duplicates() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .row(["Alice", "alice@example.com"])
            .skip_duplicates();

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT DO NOTHING"));
    }

    #[test]
    fn test_create_many_without_skip_duplicates() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name"])
            .row(["Alice"]);

        let (sql, _) = op.build_sql();

        assert!(!sql.contains("ON CONFLICT"));
    }

    #[test]
    fn test_create_many_with_rows() {
        let rows = vec![
            vec!["Alice", "alice@test.com"],
            vec!["Bob", "bob@test.com"],
            vec!["Charlie", "charlie@test.com"],
        ];
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .rows(rows);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("VALUES ($1, $2), ($3, $4), ($5, $6)"));
        assert_eq!(params.len(), 6);
    }

    #[test]
    fn test_create_many_param_ordering() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["a", "b"])
            .row(["1", "2"])
            .row(["3", "4"]);

        let (_, params) = op.build_sql();

        // Params should be ordered: row1.a, row1.b, row2.a, row2.b
        assert_eq!(params[0], FilterValue::String("1".to_string()));
        assert_eq!(params[1], FilterValue::String("2".to_string()));
        assert_eq!(params[2], FilterValue::String("3".to_string()));
        assert_eq!(params[3], FilterValue::String("4".to_string()));
    }

    #[tokio::test]
    async fn test_create_many_exec() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(3))
            .columns(["name"])
            .row(["Alice"])
            .row(["Bob"])
            .row(["Charlie"]);

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 3);
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_create_sql_structure() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .select(Select::fields(["id"]));

        let (sql, _) = op.build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let columns_pos = sql.find("(name)").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let returning_pos = sql.find("RETURNING").unwrap();

        assert!(insert_pos < columns_pos);
        assert!(columns_pos < values_pos);
        assert!(values_pos < returning_pos);
    }

    #[test]
    fn test_create_many_sql_structure() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["name", "email"])
            .row(["Alice", "alice@test.com"])
            .skip_duplicates();

        let (sql, _) = op.build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let columns_pos = sql.find("(name, email)").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let conflict_pos = sql.find("ON CONFLICT").unwrap();

        assert!(insert_pos < columns_pos);
        assert!(columns_pos < values_pos);
        assert!(values_pos < conflict_pos);
    }

    #[test]
    fn test_create_table_name() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, _) = op.build_sql();

        assert!(sql.contains("test_models"));
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_create_method_chaining() {
        let op = CreateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Alice")
            .set("email", "alice@test.com")
            .select(Select::fields(["id", "name"]));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(name, email)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("RETURNING id, name"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_create_many_method_chaining() {
        let op = CreateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .columns(["a", "b"])
            .row(["1", "2"])
            .row(["3", "4"])
            .skip_duplicates();

        let (sql, params) = op.build_sql();

        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert_eq!(params.len(), 4);
    }
}
614