postrust_graphql/input/
mutation.rs

1//! Mutation input types for inserts and updates.
2//!
3//! Provides input type generation for GraphQL mutations based on table metadata.
4
5use crate::types::{pg_type_to_graphql, GraphQLType};
6use postrust_core::schema_cache::{Column, Table};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Represents a field in an insert input type.
11#[derive(Debug, Clone)]
12pub struct InsertField {
13    /// Field name
14    pub name: String,
15    /// GraphQL type
16    pub graphql_type: GraphQLType,
17    /// Whether the field is required (no default value and not nullable)
18    pub required: bool,
19    /// Field description
20    pub description: Option<String>,
21}
22
23impl InsertField {
24    /// Create an InsertField from a column.
25    pub fn from_column(column: &Column) -> Self {
26        let graphql_type = pg_type_to_graphql(&column.nominal_type);
27
28        // A field is required if:
29        // 1. It's not nullable AND
30        // 2. It has no default value AND
31        // 3. It's not a primary key with serial/auto-increment default
32        let has_auto_default = column.default.as_ref().map_or(false, |d| {
33            d.contains("nextval") || d.contains("gen_random_uuid")
34        });
35
36        let required = !column.nullable && column.default.is_none() && !has_auto_default;
37
38        Self {
39            name: column.name.clone(),
40            description: column.description.clone(),
41            graphql_type,
42            required,
43        }
44    }
45
46    /// Get the GraphQL type string for input.
47    pub fn type_string(&self) -> String {
48        let base = format!("{}", self.graphql_type);
49        if self.required {
50            format!("{}!", base)
51        } else {
52            base
53        }
54    }
55}
56
57/// Represents a field in an update input type.
58#[derive(Debug, Clone)]
59pub struct UpdateField {
60    /// Field name
61    pub name: String,
62    /// GraphQL type
63    pub graphql_type: GraphQLType,
64    /// Field description
65    pub description: Option<String>,
66    /// Whether this is a primary key (cannot be updated)
67    pub is_pk: bool,
68}
69
70impl UpdateField {
71    /// Create an UpdateField from a column.
72    pub fn from_column(column: &Column) -> Self {
73        let graphql_type = pg_type_to_graphql(&column.nominal_type);
74
75        Self {
76            name: column.name.clone(),
77            description: column.description.clone(),
78            graphql_type,
79            is_pk: column.is_pk,
80        }
81    }
82
83    /// Get the GraphQL type string for input (always nullable for updates).
84    pub fn type_string(&self) -> String {
85        format!("{}", self.graphql_type)
86    }
87
88    /// Check if this field can be updated (non-PK fields only).
89    pub fn is_updatable(&self) -> bool {
90        !self.is_pk
91    }
92}
93
94/// Represents an insert input type for a table.
95#[derive(Debug, Clone)]
96pub struct InsertInput {
97    /// Table being inserted into
98    pub table_name: String,
99    /// GraphQL type name (e.g., "UsersInsertInput")
100    pub type_name: String,
101    /// Fields that can be inserted
102    pub fields: Vec<InsertField>,
103}
104
105impl InsertInput {
106    /// Create an InsertInput from a table.
107    pub fn from_table(table: &Table) -> Self {
108        let type_name = format!("{}InsertInput", to_pascal_case(&table.name));
109
110        let fields = table
111            .columns
112            .values()
113            .map(InsertField::from_column)
114            .collect();
115
116        Self {
117            table_name: table.name.clone(),
118            type_name,
119            fields,
120        }
121    }
122
123    /// Get required fields.
124    pub fn required_fields(&self) -> Vec<&InsertField> {
125        self.fields.iter().filter(|f| f.required).collect()
126    }
127
128    /// Get optional fields.
129    pub fn optional_fields(&self) -> Vec<&InsertField> {
130        self.fields.iter().filter(|f| !f.required).collect()
131    }
132
133    /// Check if the table has any required fields.
134    pub fn has_required_fields(&self) -> bool {
135        self.fields.iter().any(|f| f.required)
136    }
137}
138
139/// Represents an update input type for a table.
140#[derive(Debug, Clone)]
141pub struct UpdateInput {
142    /// Table being updated
143    pub table_name: String,
144    /// GraphQL type name (e.g., "UsersSetInput")
145    pub type_name: String,
146    /// Fields that can be updated
147    pub fields: Vec<UpdateField>,
148}
149
150impl UpdateInput {
151    /// Create an UpdateInput from a table.
152    pub fn from_table(table: &Table) -> Self {
153        let type_name = format!("{}SetInput", to_pascal_case(&table.name));
154
155        let fields = table
156            .columns
157            .values()
158            .filter(|c| !c.is_pk) // Exclude primary keys from update
159            .map(UpdateField::from_column)
160            .collect();
161
162        Self {
163            table_name: table.name.clone(),
164            type_name,
165            fields,
166        }
167    }
168
169    /// Get updatable fields.
170    pub fn updatable_fields(&self) -> Vec<&UpdateField> {
171        self.fields.iter().filter(|f| f.is_updatable()).collect()
172    }
173}
174
175/// A dynamic input value that can hold different types.
176#[derive(Debug, Clone, Serialize, Deserialize)]
177#[serde(untagged)]
178pub enum InputValue {
179    /// Null value
180    Null,
181    /// Boolean value
182    Bool(bool),
183    /// Integer value
184    Int(i64),
185    /// Float value
186    Float(f64),
187    /// String value
188    String(String),
189    /// JSON object value
190    Object(HashMap<String, InputValue>),
191    /// JSON array value
192    Array(Vec<InputValue>),
193}
194
195impl InputValue {
196    /// Check if this is null.
197    pub fn is_null(&self) -> bool {
198        matches!(self, Self::Null)
199    }
200
201    /// Try to get as string.
202    pub fn as_string(&self) -> Option<&str> {
203        match self {
204            Self::String(s) => Some(s),
205            _ => None,
206        }
207    }
208
209    /// Try to get as i64.
210    pub fn as_int(&self) -> Option<i64> {
211        match self {
212            Self::Int(i) => Some(*i),
213            _ => None,
214        }
215    }
216
217    /// Try to get as f64.
218    pub fn as_float(&self) -> Option<f64> {
219        match self {
220            Self::Float(f) => Some(*f),
221            Self::Int(i) => Some(*i as f64),
222            _ => None,
223        }
224    }
225
226    /// Try to get as bool.
227    pub fn as_bool(&self) -> Option<bool> {
228        match self {
229            Self::Bool(b) => Some(*b),
230            _ => None,
231        }
232    }
233
234    /// Convert to SQL string representation.
235    pub fn to_sql_string(&self) -> String {
236        match self {
237            Self::Null => "NULL".to_string(),
238            Self::Bool(b) => if *b { "true" } else { "false" }.to_string(),
239            Self::Int(i) => i.to_string(),
240            Self::Float(f) => f.to_string(),
241            Self::String(s) => s.clone(),
242            Self::Object(o) => serde_json::to_string(o).unwrap_or_default(),
243            Self::Array(a) => serde_json::to_string(a).unwrap_or_default(),
244        }
245    }
246}
247
/// Helper to convert a snake_case identifier to PascalCase.
///
/// The input is split on every non-alphanumeric character. `_`, `-`,
/// spaces, `$`, etc. all act as word separators and are dropped. The first
/// character of each word is upper-cased and the rest is kept unchanged.
/// Empty words, from leading, trailing or repeated separators, contribute
/// nothing.
///
/// Splitting on more than `_` keeps the result usable as a GraphQL name,
/// which may only contain `[_A-Za-z0-9]`. For example, `user-accounts`
/// becomes `UserAccounts` rather than the invalid `User-accounts`.
///
/// Note: an input starting with a digit (e.g. `2fa_codes`) still yields a
/// result starting with a digit, which is not a valid GraphQL name on its own.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
260
261/// Check if a table is insertable based on its permissions.
262pub fn is_insertable(table: &Table) -> bool {
263    table.insertable
264}
265
266/// Check if a table is updatable based on its permissions.
267pub fn is_updatable(table: &Table) -> bool {
268    table.updatable
269}
270
271/// Check if a table is deletable based on its permissions.
272pub fn is_deletable(table: &Table) -> bool {
273    table.deletable
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use pretty_assertions::assert_eq;

    // ------------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------------

    /// Baseline column: NOT NULL, no default, no description, not a primary
    /// key. Fixtures override individual fields with struct-update syntax.
    fn column(name: &str, data_type: &str, nominal_type: &str) -> Column {
        Column {
            name: name.into(),
            description: None,
            nullable: false,
            data_type: data_type.into(),
            nominal_type: nominal_type.into(),
            max_len: None,
            default: None,
            enum_values: vec![],
            is_pk: false,
            position: 0,
        }
    }

    /// `public.users` with four columns:
    /// - `id`: primary key with a sequence default
    /// - `name`: NOT NULL text without default (the only required insert field)
    /// - `email`: nullable text
    /// - `created_at`: NOT NULL timestamptz defaulting to `now()`
    fn create_test_table() -> Table {
        let id = Column {
            description: Some("Primary key".into()),
            default: Some("nextval('users_id_seq')".into()),
            is_pk: true,
            position: 1,
            ..column("id", "integer", "int4")
        };
        let name = Column {
            description: Some("User name".into()),
            position: 2,
            ..column("name", "text", "text")
        };
        let email = Column {
            nullable: true,
            position: 3,
            ..column("email", "text", "text")
        };
        let created_at = Column {
            default: Some("now()".into()),
            position: 4,
            ..column("created_at", "timestamptz", "timestamptz")
        };

        // `IndexMap` iterates in insertion order, and several assertions
        // below rely on this column order.
        let mut columns = IndexMap::new();
        for (key, col) in [
            ("id", id),
            ("name", name),
            ("email", email),
            ("created_at", created_at),
        ] {
            columns.insert(key.into(), col);
        }

        Table {
            schema: "public".into(),
            name: "users".into(),
            description: Some("User accounts".into()),
            is_view: false,
            insertable: true,
            updatable: true,
            deletable: true,
            pk_cols: vec!["id".into()],
            columns,
        }
    }

    /// The users fixture with every mutation flag switched off.
    fn create_readonly_table() -> Table {
        Table {
            insertable: false,
            updatable: false,
            deletable: false,
            ..create_test_table()
        }
    }

    /// Insert field built from the named column of the users fixture.
    fn insert_field(name: &str) -> InsertField {
        let table = create_test_table();
        let col = table.columns.get(name).expect("fixture column exists");
        InsertField::from_column(col)
    }

    /// Update field built from the named column of the users fixture.
    fn update_field(name: &str) -> UpdateField {
        let table = create_test_table();
        let col = table.columns.get(name).expect("fixture column exists");
        UpdateField::from_column(col)
    }

    // ------------------------------------------------------------------------
    // InsertField
    // ------------------------------------------------------------------------

    #[test]
    fn test_insert_field_required() {
        let field = insert_field("name");
        assert_eq!(field.name, "name");
        assert!(field.required, "NOT NULL without default must be required");
        assert_eq!(field.type_string(), "String!");
    }

    #[test]
    fn test_insert_field_optional_nullable() {
        let field = insert_field("email");
        assert_eq!(field.name, "email");
        assert!(!field.required, "nullable column must be optional");
        assert_eq!(field.type_string(), "String");
    }

    #[test]
    fn test_insert_field_optional_with_default() {
        let field = insert_field("created_at");
        assert_eq!(field.name, "created_at");
        assert!(!field.required, "column with default must be optional");
        assert_eq!(field.type_string(), "DateTime");
    }

    #[test]
    fn test_insert_field_auto_pk() {
        let field = insert_field("id");
        assert_eq!(field.name, "id");
        assert!(!field.required, "sequence-backed key must be optional");
    }

    // ------------------------------------------------------------------------
    // UpdateField
    // ------------------------------------------------------------------------

    #[test]
    fn test_update_field_non_pk() {
        let field = update_field("name");
        assert_eq!(field.name, "name");
        assert!(!field.is_pk);
        assert!(field.is_updatable());
        // Update fields never carry the non-null marker.
        assert_eq!(field.type_string(), "String");
    }

    #[test]
    fn test_update_field_pk() {
        let field = update_field("id");
        assert_eq!(field.name, "id");
        assert!(field.is_pk);
        assert!(!field.is_updatable());
    }

    // ------------------------------------------------------------------------
    // InsertInput
    // ------------------------------------------------------------------------

    #[test]
    fn test_insert_input_from_table() {
        let input = InsertInput::from_table(&create_test_table());
        assert_eq!(input.table_name, "users");
        assert_eq!(input.type_name, "UsersInsertInput");
        assert_eq!(input.fields.len(), 4);
    }

    #[test]
    fn test_insert_input_required_fields() {
        let input = InsertInput::from_table(&create_test_table());
        let required = input.required_fields();
        let names: Vec<&str> = required.iter().map(|f| f.name.as_str()).collect();
        // Only `name` is NOT NULL without a default.
        assert_eq!(names, vec!["name"]);
    }

    #[test]
    fn test_insert_input_optional_fields() {
        let input = InsertInput::from_table(&create_test_table());
        // id, email, created_at
        assert_eq!(input.optional_fields().len(), 3);
    }

    #[test]
    fn test_insert_input_has_required_fields() {
        let input = InsertInput::from_table(&create_test_table());
        assert!(input.has_required_fields());
    }

    // ------------------------------------------------------------------------
    // UpdateInput
    // ------------------------------------------------------------------------

    #[test]
    fn test_update_input_from_table() {
        let input = UpdateInput::from_table(&create_test_table());
        assert_eq!(input.table_name, "users");
        assert_eq!(input.type_name, "UsersSetInput");
        // The primary key is left out.
        assert_eq!(input.fields.len(), 3);
    }

    #[test]
    fn test_update_input_excludes_pk() {
        let input = UpdateInput::from_table(&create_test_table());
        assert!(input.fields.iter().all(|f| f.name != "id"));
    }

    #[test]
    fn test_update_input_updatable_fields() {
        let input = UpdateInput::from_table(&create_test_table());
        assert_eq!(input.updatable_fields().len(), 3);
    }

    // ------------------------------------------------------------------------
    // InputValue
    // ------------------------------------------------------------------------

    #[test]
    fn test_input_value_null() {
        let value = InputValue::Null;
        assert!(value.is_null());
        assert_eq!(value.to_sql_string(), "NULL");
    }

    #[test]
    fn test_input_value_bool() {
        let yes = InputValue::Bool(true);
        assert_eq!(yes.as_bool(), Some(true));
        assert_eq!(yes.to_sql_string(), "true");

        assert_eq!(InputValue::Bool(false).to_sql_string(), "false");
    }

    #[test]
    fn test_input_value_int() {
        let value = InputValue::Int(42);
        assert_eq!(value.as_int(), Some(42));
        // Integers coerce to floats.
        assert_eq!(value.as_float(), Some(42.0));
        assert_eq!(value.to_sql_string(), "42");
    }

    #[test]
    fn test_input_value_float() {
        let value = InputValue::Float(3.14);
        assert_eq!(value.as_float(), Some(3.14));
        assert_eq!(value.to_sql_string(), "3.14");
    }

    #[test]
    fn test_input_value_string() {
        let value = InputValue::String("hello".to_string());
        assert_eq!(value.as_string(), Some("hello"));
        assert_eq!(value.to_sql_string(), "hello");
    }

    // ------------------------------------------------------------------------
    // Table permissions
    // ------------------------------------------------------------------------

    #[test]
    fn test_is_insertable() {
        assert!(is_insertable(&create_test_table()));
        assert!(!is_insertable(&create_readonly_table()));
    }

    #[test]
    fn test_is_updatable() {
        assert!(is_updatable(&create_test_table()));
        assert!(!is_updatable(&create_readonly_table()));
    }

    #[test]
    fn test_is_deletable() {
        assert!(is_deletable(&create_test_table()));
        assert!(!is_deletable(&create_readonly_table()));
    }

    // ------------------------------------------------------------------------
    // PascalCase
    // ------------------------------------------------------------------------

    #[test]
    fn test_to_pascal_case() {
        for (input, expected) in [
            ("users", "Users"),
            ("user_accounts", "UserAccounts"),
            ("my_table_name", "MyTableName"),
        ] {
            assert_eq!(to_pascal_case(input), expected);
        }
    }
}