postrust_graphql/
handler.rs

1//! Axum handler for the /graphql endpoint.
2//!
3//! Provides GraphQL request handling using async-graphql with dynamic schema
4//! generation from the PostgreSQL schema cache.
5
6use crate::context::GraphQLContext;
7use crate::error::GraphQLError;
8use crate::schema::object::TableObjectType;
9use crate::schema::{build_schema, GeneratedSchema, MutationType, SchemaConfig};
10use crate::subscription::{
11    generate_subscription_fields, NotifyBroker, SubscriptionField as SubField, TableChangePayload,
12};
13use async_graphql::dynamic::*;
14use async_graphql::Value;
15use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
16use axum::extract::State;
17use axum::response::IntoResponse;
18use futures::stream::StreamExt;
19use postrust_core::schema_cache::SchemaCache;
20use sqlx::PgPool;
21use std::collections::HashMap;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24use tracing::{debug, info, trace};
25
26/// GraphQL execution state shared across requests.
27pub struct GraphQLState {
28    /// Database connection pool
29    pub pool: PgPool,
30    /// Schema cache
31    pub schema_cache: Arc<SchemaCache>,
32    /// Generated GraphQL schema
33    pub generated_schema: GeneratedSchema,
34    /// async-graphql Schema (built dynamically)
35    pub schema: Schema,
36    /// Schema configuration
37    pub config: SchemaConfig,
38    /// Subscription fields
39    pub subscription_fields: Vec<SubField>,
40    /// Notification broker for subscriptions
41    pub broker: Arc<RwLock<Option<NotifyBroker>>>,
42}
43
44impl GraphQLState {
45    /// Create new GraphQL state from schema cache.
46    pub fn new(
47        pool: PgPool,
48        schema_cache: Arc<SchemaCache>,
49        config: SchemaConfig,
50    ) -> Result<Self, GraphQLError> {
51        let generated_schema = build_schema(&schema_cache, &config);
52        let subscription_fields = if config.enable_subscriptions {
53            generate_subscription_fields(&schema_cache, &generated_schema)
54        } else {
55            Vec::new()
56        };
57        let schema = build_dynamic_schema(
58            &generated_schema,
59            &schema_cache,
60            if config.enable_subscriptions {
61                Some(subscription_fields.as_slice())
62            } else {
63                None
64            },
65        )?;
66
67        Ok(Self {
68            pool: pool.clone(),
69            schema_cache,
70            generated_schema,
71            schema,
72            config,
73            subscription_fields,
74            broker: Arc::new(RwLock::new(None)),
75        })
76    }
77
78    /// Rebuild the schema (e.g., after schema cache refresh).
79    pub fn rebuild(&mut self) -> Result<(), GraphQLError> {
80        self.generated_schema = build_schema(&self.schema_cache, &self.config);
81        self.subscription_fields = if self.config.enable_subscriptions {
82            generate_subscription_fields(&self.schema_cache, &self.generated_schema)
83        } else {
84            Vec::new()
85        };
86        self.schema = build_dynamic_schema(
87            &self.generated_schema,
88            &self.schema_cache,
89            if self.config.enable_subscriptions {
90                Some(self.subscription_fields.as_slice())
91            } else {
92                None
93            },
94        )?;
95        Ok(())
96    }
97
98    /// Initialize the subscription broker.
99    ///
100    /// This should be called after creating the state to enable subscriptions.
101    pub async fn init_subscriptions(&self) -> Result<(), crate::subscription::BrokerError> {
102        if !self.config.enable_subscriptions {
103            return Ok(());
104        }
105
106        let broker = NotifyBroker::new(self.pool.clone());
107
108        // Collect all channels to listen on
109        let channels: Vec<String> = self
110            .subscription_fields
111            .iter()
112            .map(|f| f.channel_name())
113            .collect();
114
115        if !channels.is_empty() {
116            broker.start(channels).await?;
117            info!(
118                "Subscription broker started with {} channels",
119                self.subscription_fields.len()
120            );
121        }
122
123        // Store the broker
124        let mut broker_guard = self.broker.write().await;
125        *broker_guard = Some(broker);
126
127        Ok(())
128    }
129
130    /// Stop the subscription broker.
131    pub async fn stop_subscriptions(&self) {
132        let broker_guard = self.broker.read().await;
133        if let Some(broker) = broker_guard.as_ref() {
134            broker.stop().await;
135        }
136    }
137
138    /// Get the notification broker.
139    pub async fn get_broker(&self) -> Option<Arc<RwLock<Option<NotifyBroker>>>> {
140        Some(Arc::clone(&self.broker))
141    }
142}
143
144/// Handle a GraphQL request.
145pub async fn graphql_handler(
146    State(state): State<Arc<GraphQLState>>,
147    ctx: GraphQLContext,
148    req: GraphQLRequest,
149) -> GraphQLResponse {
150    let request = req
151        .into_inner()
152        .data(ctx)
153        .data(state.pool.clone())
154        .data(Arc::clone(&state.broker));
155    state.schema.execute(request).await.into()
156}
157
158/// Handle GraphQL WebSocket subscription upgrade.
159///
160/// This should be called with a WebSocket upgrade request to enable
161/// GraphQL subscriptions over WebSocket.
162pub async fn graphql_ws_handler(
163    State(state): State<Arc<GraphQLState>>,
164    protocol: async_graphql_axum::GraphQLProtocol,
165    ws: axum::extract::WebSocketUpgrade,
166) -> impl IntoResponse {
167    let schema = state.schema.clone();
168    let pool = state.pool.clone();
169    let broker = Arc::clone(&state.broker);
170
171    ws.protocols(["graphql-transport-ws", "graphql-ws"])
172        .on_upgrade(move |socket| async move {
173            let mut data = async_graphql::Data::default();
174            data.insert(pool);
175            data.insert(broker);
176
177            async_graphql_axum::GraphQLWebSocket::new(socket, schema, protocol)
178                .with_data(data)
179                .serve()
180                .await
181        })
182}
183
184/// Handle GraphQL playground request.
185pub async fn graphql_playground() -> impl axum::response::IntoResponse {
186    axum::response::Html(async_graphql::http::playground_source(
187        async_graphql::http::GraphQLPlaygroundConfig::new("/graphql")
188            .subscription_endpoint("/graphql/ws"),
189    ))
190}
191
192/// Build the dynamic async-graphql schema from our generated schema.
193fn build_dynamic_schema(
194    generated: &GeneratedSchema,
195    _schema_cache: &SchemaCache,
196    subscription_fields: Option<&[SubField]>,
197) -> Result<Schema, GraphQLError> {
198    // Create object types for each table
199    let mut object_types: HashMap<String, Object> = HashMap::new();
200
201    for (type_name, obj) in &generated.object_types {
202        let table_obj = create_object_type(obj);
203        object_types.insert(type_name.clone(), table_obj);
204    }
205
206    // Create query type
207    let query = create_query_type(generated);
208
209    // Create mutation type
210    let mutation = if !generated.mutation_fields.is_empty() {
211        Some(create_mutation_type(generated))
212    } else {
213        None
214    };
215
216    // Create subscription type if enabled
217    let subscription = subscription_fields.map(create_subscription_type);
218
219    // Build schema
220    let mut builder = Schema::build(
221        "Query",
222        mutation.as_ref().map(|_| "Mutation"),
223        subscription.as_ref().map(|_| "Subscription"),
224    );
225
226    // Register all object types
227    for (_, obj) in object_types {
228        builder = builder.register(obj);
229    }
230
231    // Register query type
232    builder = builder.register(query);
233
234    // Register mutation type if present
235    if let Some(mutation) = mutation {
236        builder = builder.register(mutation);
237    }
238
239    // Register subscription type if present
240    if let Some(subscription) = subscription {
241        builder = builder.register(subscription);
242    }
243
244    // Register scalar types
245    builder = builder.register(create_bigint_scalar());
246    builder = builder.register(create_bigdecimal_scalar());
247    builder = builder.register(create_json_scalar());
248    builder = builder.register(create_uuid_scalar());
249    builder = builder.register(create_date_scalar());
250    builder = builder.register(create_datetime_scalar());
251    builder = builder.register(create_time_scalar());
252
253    // Register input types
254    builder = register_filter_input_types(builder);
255
256    builder
257        .finish()
258        .map_err(|e| GraphQLError::SchemaError(e.to_string()))
259}
260
261/// Create an object type from a TableObjectType.
262fn create_object_type(obj: &TableObjectType) -> Object {
263    let mut object = Object::new(&obj.name);
264
265    if let Some(desc) = obj.description() {
266        object = object.description(desc);
267    }
268
269    for field in &obj.fields {
270        let field_name = field.name.clone();
271        let field_type = graphql_type_ref(&field.type_string());
272
273        // Create field with resolver that extracts from parent async_graphql::Value
274        // The query resolver stores rows as FieldValue::value(Value::Object)
275        // so we use as_value() to get the Value and extract fields from the Object
276        let gql_field = Field::new(&field.name, field_type, move |ctx| {
277            let field_name = field_name.clone();
278            FieldFuture::new(async move {
279                // Get the parent value as async_graphql::Value using as_value()
280                if let Some(Value::Object(map)) = ctx.parent_value.as_value() {
281                    // Convert field name to async_graphql::Name for lookup
282                    let key = async_graphql::Name::new(&field_name);
283                    if let Some(val) = map.get(&key) {
284                        return Ok(Some(FieldValue::value(val.clone())));
285                    }
286                }
287
288                // Field not found or parent not a Value::Object
289                Ok(None)
290            })
291        });
292
293        let gql_field = if let Some(desc) = &field.description {
294            gql_field.description(desc)
295        } else {
296            gql_field
297        };
298
299        object = object.field(gql_field);
300    }
301
302    object
303}
304
305/// Create the Query type with all table query fields.
306fn create_query_type(generated: &GeneratedSchema) -> Object {
307    let mut query = Object::new("Query");
308
309    for field in &generated.query_fields {
310        let table_name = field.table_name.clone();
311        let type_name = field.type_name.clone();
312        let is_by_pk = field.is_by_pk;
313        let return_type = graphql_type_ref(&field.return_type);
314
315        let mut gql_field = Field::new(&field.name, return_type, move |ctx| {
316            let table_name = table_name.clone();
317            let type_name = type_name.clone();
318            FieldFuture::new(async move {
319                resolve_query(&ctx, &table_name, &type_name, is_by_pk).await
320            })
321        });
322
323        // Add standard query arguments
324        if !is_by_pk {
325            gql_field = gql_field
326                .argument(InputValue::new("filter", TypeRef::named("JSON")))
327                .argument(InputValue::new("orderBy", TypeRef::named_list("String")))
328                .argument(InputValue::new("limit", TypeRef::named("Int")))
329                .argument(InputValue::new("offset", TypeRef::named("Int")));
330        } else {
331            // Add PK arguments
332            gql_field = gql_field.argument(InputValue::new("id", TypeRef::named_nn("Int")));
333        }
334
335        if let Some(desc) = &field.description {
336            gql_field = gql_field.description(desc);
337        }
338
339        query = query.field(gql_field);
340    }
341
342    // Add introspection queries
343    query = query.field(
344        Field::new("_schema", TypeRef::named("String"), |_| {
345            FieldFuture::new(async move {
346                Ok(Some(Value::String("Postrust GraphQL Schema".to_string())))
347            })
348        })
349        .description("Schema introspection"),
350    );
351
352    query
353}
354
355/// Create the Mutation type with all mutation fields.
356fn create_mutation_type(generated: &GeneratedSchema) -> Object {
357    let mut mutation = Object::new("Mutation");
358
359    for field in &generated.mutation_fields {
360        let table_name = field.table_name.clone();
361        let mutation_type = field.mutation_type;
362        let return_type = graphql_type_ref(&field.return_type);
363
364        let mut gql_field = Field::new(&field.name, return_type, move |ctx| {
365            let table_name = table_name.clone();
366            FieldFuture::new(async move {
367                resolve_mutation(&ctx, &table_name, mutation_type).await
368            })
369        });
370
371        // Add mutation-specific arguments
372        match mutation_type {
373            MutationType::Insert | MutationType::InsertOne => {
374                gql_field = gql_field
375                    .argument(InputValue::new("objects", TypeRef::named_nn_list("JSON")));
376            }
377            MutationType::Update | MutationType::UpdateByPk => {
378                gql_field = gql_field
379                    .argument(InputValue::new("where", TypeRef::named("JSON")))
380                    .argument(InputValue::new("set", TypeRef::named_nn("JSON")));
381            }
382            MutationType::Delete | MutationType::DeleteByPk => {
383                gql_field = gql_field.argument(InputValue::new("where", TypeRef::named("JSON")));
384            }
385        }
386
387        if let Some(desc) = &field.description {
388            gql_field = gql_field.description(desc);
389        }
390
391        mutation = mutation.field(gql_field);
392    }
393
394    mutation
395}
396
397/// Create the Subscription type with all subscription fields.
398fn create_subscription_type(fields: &[SubField]) -> Subscription {
399    let mut subscription = Subscription::new("Subscription");
400
401    for field in fields {
402        let channel_name = field.channel_name();
403        let return_type = TypeRef::named(&field.return_type);
404        let field_name = field.name.clone();
405        let description = field.description.clone();
406
407        let gql_field = SubscriptionField::new(&field_name, return_type, move |ctx| {
408            let channel_name = channel_name.clone();
409            SubscriptionFieldFuture::new(async move {
410                let broker_arc = ctx.data::<Arc<RwLock<Option<NotifyBroker>>>>()?;
411                let broker_guard = broker_arc.read().await;
412
413                let broker = broker_guard
414                    .as_ref()
415                    .ok_or_else(|| async_graphql::Error::new("Subscription broker not initialized"))?;
416
417                let stream = broker
418                    .subscribe(&channel_name)
419                    .await
420                    .map_err(|e| async_graphql::Error::new(format!("Subscription error: {}", e)))?;
421
422                // Transform notification stream to GraphQL values
423                // Use FieldValue::value() so field resolvers can use as_value()
424                let value_stream = stream.filter_map(|notification| async move {
425                    match TableChangePayload::from_payload(&notification.payload) {
426                        Ok(payload) => {
427                            if let Some(data) = payload.data() {
428                                // Convert to async_graphql::Value so field resolvers can extract fields
429                                Some(Ok(FieldValue::value(json_to_value(data.clone()))))
430                            } else {
431                                None
432                            }
433                        }
434                        Err(e) => {
435                            debug!("Failed to parse notification payload: {}", e);
436                            None
437                        }
438                    }
439                });
440
441                Ok(value_stream)
442            })
443        });
444
445        let gql_field = if let Some(desc) = description {
446            gql_field.description(desc)
447        } else {
448            gql_field
449        };
450
451        subscription = subscription.field(gql_field);
452    }
453
454    subscription
455}
456
457/// Resolve a query field.
458async fn resolve_query<'a>(
459    ctx: &ResolverContext<'a>,
460    table_name: &str,
461    _type_name: &str,
462    is_by_pk: bool,
463) -> Result<Option<FieldValue<'a>>, async_graphql::Error> {
464    let pool = ctx.data::<PgPool>()?;
465    let gql_ctx = ctx.data::<GraphQLContext>()?;
466
467    debug!("Resolving query for table: {}", table_name);
468
469    // Extract pagination arguments
470    let limit: Option<i64> = ctx
471        .args
472        .try_get("limit")
473        .ok()
474        .and_then(|v| v.i64().ok());
475
476    let offset: Option<i64> = ctx
477        .args
478        .try_get("offset")
479        .ok()
480        .and_then(|v| v.i64().ok());
481
482    // Build simple query
483    let mut sql = format!(
484        "SELECT row_to_json(t) FROM (SELECT * FROM public.{}) t",
485        table_name
486    );
487
488    if let Some(limit) = limit {
489        sql.push_str(&format!(" LIMIT {}", limit));
490    }
491
492    if let Some(offset) = offset {
493        sql.push_str(&format!(" OFFSET {}", offset));
494    }
495
496    // Execute query - returns Vec<serde_json::Value>
497    let result = execute_query(pool, &sql, gql_ctx.role()).await?;
498
499    if is_by_pk {
500        // Return single item as Value::Object
501        // json_to_value converts serde_json to async_graphql Value
502        Ok(result.into_iter().next().map(|v| FieldValue::value(json_to_value(v))))
503    } else {
504        // Return list with each item as Value::Object
505        let items: Vec<FieldValue> = result
506            .into_iter()
507            .map(|v| FieldValue::value(json_to_value(v)))
508            .collect();
509        Ok(Some(FieldValue::list(items)))
510    }
511}
512
513/// Resolve a mutation field.
514async fn resolve_mutation<'a>(
515    ctx: &ResolverContext<'a>,
516    table_name: &str,
517    mutation_type: MutationType,
518) -> Result<Option<FieldValue<'a>>, async_graphql::Error> {
519    let pool = ctx.data::<PgPool>()?;
520    let gql_ctx = ctx.data::<GraphQLContext>()?;
521
522    debug!("Resolving mutation for table: {} type: {:?}", table_name, mutation_type);
523
524    let result = match mutation_type {
525        MutationType::Insert | MutationType::InsertOne => {
526            let objects = ctx
527                .args
528                .try_get("objects")
529                .ok()
530                .map(|v| accessor_to_json(&v))
531                .unwrap_or_else(|| serde_json::Value::Array(vec![]));
532
533            execute_insert(pool, table_name, gql_ctx.role(), objects, mutation_type).await?
534        }
535        MutationType::Update | MutationType::UpdateByPk => {
536            let set_value = ctx
537                .args
538                .try_get("set")
539                .ok()
540                .map(|v| accessor_to_json(&v))
541                .unwrap_or_else(|| serde_json::json!({}));
542
543            let where_clause = ctx
544                .args
545                .try_get("where")
546                .ok()
547                .map(|v| accessor_to_json(&v));
548
549            execute_update(pool, table_name, gql_ctx.role(), set_value, where_clause, mutation_type).await?
550        }
551        MutationType::Delete | MutationType::DeleteByPk => {
552            let where_clause = ctx
553                .args
554                .try_get("where")
555                .ok()
556                .map(|v| accessor_to_json(&v));
557
558            execute_delete(pool, table_name, gql_ctx.role(), where_clause, mutation_type).await?
559        }
560    };
561
562    Ok(result)
563}
564
565/// Execute a SQL query and return results as serde_json::Value.
566/// We keep data as serde_json::Value so field resolvers can use try_downcast_ref.
567async fn execute_query(
568    pool: &PgPool,
569    sql: &str,
570    role: &str,
571) -> Result<Vec<serde_json::Value>, async_graphql::Error> {
572    use sqlx::Row;
573
574    trace!("Executing SQL: {}", sql);
575
576    let mut conn = pool.acquire().await?;
577
578    // Set role
579    sqlx::query(&format!("SET LOCAL ROLE {}", postrust_sql::escape_ident(role)))
580        .execute(&mut *conn)
581        .await?;
582
583    // Execute query
584    let rows = sqlx::query(sql).fetch_all(&mut *conn).await?;
585
586    // Return raw JSON values - don't convert to async_graphql::Value
587    // This allows field resolvers to use try_downcast_ref::<serde_json::Value>()
588    let results: Vec<serde_json::Value> = rows
589        .iter()
590        .filter_map(|row| row.try_get::<serde_json::Value, _>(0).ok())
591        .collect();
592
593    Ok(results)
594}
595
596/// Execute an insert mutation.
597async fn execute_insert<'a>(
598    pool: &PgPool,
599    table_name: &str,
600    role: &str,
601    objects: serde_json::Value,
602    mutation_type: MutationType,
603) -> Result<Option<FieldValue<'a>>, async_graphql::Error> {
604    use sqlx::Row;
605
606    trace!("Insert mutation for {}: {:?}", table_name, objects);
607
608    // Handle both array and single object
609    let objects_array = match objects {
610        serde_json::Value::Array(arr) => arr,
611        serde_json::Value::Object(obj) => vec![serde_json::Value::Object(obj)],
612        _ => return Err(async_graphql::Error::new("objects must be an array or object")),
613    };
614
615    if objects_array.is_empty() {
616        return Err(async_graphql::Error::new("objects cannot be empty"));
617    }
618
619    let mut conn = pool.acquire().await?;
620
621    // Set role
622    sqlx::query(&format!("SET LOCAL ROLE {}", postrust_sql::escape_ident(role)))
623        .execute(&mut *conn)
624        .await?;
625
626    let mut inserted: Vec<FieldValue> = Vec::new();
627
628    for obj in objects_array {
629        if let serde_json::Value::Object(map) = obj {
630            // Build INSERT query
631            let columns: Vec<&str> = map.keys().map(|k| k.as_str()).collect();
632            let placeholders: Vec<String> = (1..=columns.len()).map(|i| format!("${}", i)).collect();
633
634            let sql = format!(
635                "INSERT INTO public.{} ({}) VALUES ({}) RETURNING row_to_json(public.{}.*)",
636                postrust_sql::escape_ident(table_name),
637                columns.iter().map(|c| postrust_sql::escape_ident(c)).collect::<Vec<_>>().join(", "),
638                placeholders.join(", "),
639                postrust_sql::escape_ident(table_name)
640            );
641
642            trace!("Executing INSERT SQL: {}", sql);
643
644            // Build query with parameters
645            let mut query = sqlx::query(&sql);
646            for col in &columns {
647                if let Some(val) = map.get(*col) {
648                    query = bind_json_value(query, val);
649                }
650            }
651
652            let row = query.fetch_one(&mut *conn).await?;
653            if let Ok(json_val) = row.try_get::<serde_json::Value, _>(0) {
654                inserted.push(FieldValue::value(json_to_value(json_val)));
655            }
656        }
657    }
658
659    // Return based on mutation type
660    match mutation_type {
661        MutationType::InsertOne => {
662            // Return single item
663            Ok(inserted.into_iter().next())
664        }
665        _ => {
666            // Return list
667            Ok(Some(FieldValue::list(inserted)))
668        }
669    }
670}
671
672/// Bind a JSON value to a sqlx query.
673fn bind_json_value<'q>(
674    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
675    value: &serde_json::Value,
676) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
677    match value {
678        serde_json::Value::Null => query.bind(None::<String>),
679        serde_json::Value::Bool(b) => query.bind(*b),
680        serde_json::Value::Number(n) => {
681            if let Some(i) = n.as_i64() {
682                query.bind(i)
683            } else if let Some(f) = n.as_f64() {
684                query.bind(f)
685            } else {
686                query.bind(n.to_string())
687            }
688        }
689        serde_json::Value::String(s) => query.bind(s.clone()),
690        _ => query.bind(value.to_string()),
691    }
692}
693
694/// Execute an update mutation.
695async fn execute_update<'a>(
696    pool: &PgPool,
697    table_name: &str,
698    role: &str,
699    set_value: serde_json::Value,
700    where_clause: Option<serde_json::Value>,
701    mutation_type: MutationType,
702) -> Result<Option<FieldValue<'a>>, async_graphql::Error> {
703    use sqlx::Row;
704
705    trace!("Update mutation for {}: {:?}", table_name, set_value);
706
707    let set_map = match set_value {
708        serde_json::Value::Object(map) => map,
709        _ => return Err(async_graphql::Error::new("set must be an object")),
710    };
711
712    if set_map.is_empty() {
713        return Err(async_graphql::Error::new("set cannot be empty"));
714    }
715
716    let mut conn = pool.acquire().await?;
717
718    // Set role
719    sqlx::query(&format!("SET LOCAL ROLE {}", postrust_sql::escape_ident(role)))
720        .execute(&mut *conn)
721        .await?;
722
723    // Build SET clause
724    let mut set_parts: Vec<String> = Vec::new();
725    let mut param_idx = 1;
726    for key in set_map.keys() {
727        set_parts.push(format!("{} = ${}", postrust_sql::escape_ident(key), param_idx));
728        param_idx += 1;
729    }
730
731    // Build WHERE clause
732    let (where_sql, where_values) = build_where_clause(where_clause.as_ref(), param_idx)?;
733
734    let sql = format!(
735        "UPDATE public.{} SET {} {} RETURNING row_to_json(public.{}.*)",
736        postrust_sql::escape_ident(table_name),
737        set_parts.join(", "),
738        where_sql,
739        postrust_sql::escape_ident(table_name)
740    );
741
742    trace!("Executing UPDATE SQL: {}", sql);
743
744    // Build query with parameters
745    let mut query = sqlx::query(&sql);
746
747    // Bind SET values
748    for val in set_map.values() {
749        query = bind_json_value(query, val);
750    }
751
752    // Bind WHERE values
753    for val in &where_values {
754        query = bind_json_value(query, val);
755    }
756
757    let rows = query.fetch_all(&mut *conn).await?;
758
759    let updated: Vec<FieldValue> = rows
760        .iter()
761        .filter_map(|row| row.try_get::<serde_json::Value, _>(0).ok())
762        .map(|v| FieldValue::value(json_to_value(v)))
763        .collect();
764
765    // Return based on mutation type
766    match mutation_type {
767        MutationType::UpdateByPk => {
768            Ok(updated.into_iter().next())
769        }
770        _ => {
771            Ok(Some(FieldValue::list(updated)))
772        }
773    }
774}
775
776/// Execute a delete mutation.
777async fn execute_delete<'a>(
778    pool: &PgPool,
779    table_name: &str,
780    role: &str,
781    where_clause: Option<serde_json::Value>,
782    mutation_type: MutationType,
783) -> Result<Option<FieldValue<'a>>, async_graphql::Error> {
784    use sqlx::Row;
785
786    trace!("Delete mutation for {}", table_name);
787
788    let mut conn = pool.acquire().await?;
789
790    // Set role
791    sqlx::query(&format!("SET LOCAL ROLE {}", postrust_sql::escape_ident(role)))
792        .execute(&mut *conn)
793        .await?;
794
795    // Build WHERE clause
796    let (where_sql, where_values) = build_where_clause(where_clause.as_ref(), 1)?;
797
798    let sql = format!(
799        "DELETE FROM public.{} {} RETURNING row_to_json(public.{}.*)",
800        postrust_sql::escape_ident(table_name),
801        where_sql,
802        postrust_sql::escape_ident(table_name)
803    );
804
805    trace!("Executing DELETE SQL: {}", sql);
806
807    // Build query with parameters
808    let mut query = sqlx::query(&sql);
809
810    // Bind WHERE values
811    for val in &where_values {
812        query = bind_json_value(query, val);
813    }
814
815    let rows = query.fetch_all(&mut *conn).await?;
816
817    let deleted: Vec<FieldValue> = rows
818        .iter()
819        .filter_map(|row| row.try_get::<serde_json::Value, _>(0).ok())
820        .map(|v| FieldValue::value(json_to_value(v)))
821        .collect();
822
823    // Return based on mutation type
824    match mutation_type {
825        MutationType::DeleteByPk => {
826            Ok(deleted.into_iter().next())
827        }
828        _ => {
829            Ok(Some(FieldValue::list(deleted)))
830        }
831    }
832}
833
834/// Build a WHERE clause from a JSON filter object.
835fn build_where_clause(
836    where_value: Option<&serde_json::Value>,
837    start_param_idx: usize,
838) -> Result<(String, Vec<serde_json::Value>), async_graphql::Error> {
839    let mut conditions: Vec<String> = Vec::new();
840    let mut values: Vec<serde_json::Value> = Vec::new();
841    let mut param_idx = start_param_idx;
842
843    if let Some(serde_json::Value::Object(map)) = where_value {
844        for (key, val) in map {
845            match val {
846                serde_json::Value::Object(op_map) => {
847                    // Handle operators like {eq: value}, {gt: value}, etc.
848                    for (op, op_val) in op_map {
849                        let condition = match op.as_str() {
850                            "eq" | "_eq" => format!("{} = ${}", postrust_sql::escape_ident(key), param_idx),
851                            "neq" | "_neq" => format!("{} != ${}", postrust_sql::escape_ident(key), param_idx),
852                            "gt" | "_gt" => format!("{} > ${}", postrust_sql::escape_ident(key), param_idx),
853                            "gte" | "_gte" => format!("{} >= ${}", postrust_sql::escape_ident(key), param_idx),
854                            "lt" | "_lt" => format!("{} < ${}", postrust_sql::escape_ident(key), param_idx),
855                            "lte" | "_lte" => format!("{} <= ${}", postrust_sql::escape_ident(key), param_idx),
856                            "like" | "_like" => format!("{} LIKE ${}", postrust_sql::escape_ident(key), param_idx),
857                            "ilike" | "_ilike" => format!("{} ILIKE ${}", postrust_sql::escape_ident(key), param_idx),
858                            "is_null" | "_is_null" => {
859                                if op_val.as_bool().unwrap_or(false) {
860                                    format!("{} IS NULL", postrust_sql::escape_ident(key))
861                                } else {
862                                    format!("{} IS NOT NULL", postrust_sql::escape_ident(key))
863                                }
864                            }
865                            _ => continue,
866                        };
867
868                        if !op.contains("is_null") {
869                            conditions.push(condition);
870                            values.push(op_val.clone());
871                            param_idx += 1;
872                        } else {
873                            conditions.push(condition);
874                        }
875                    }
876                }
877                _ => {
878                    // Direct equality: {field: value}
879                    conditions.push(format!("{} = ${}", postrust_sql::escape_ident(key), param_idx));
880                    values.push(val.clone());
881                    param_idx += 1;
882                }
883            }
884        }
885    }
886
887    let where_sql = if conditions.is_empty() {
888        String::new()
889    } else {
890        format!("WHERE {}", conditions.join(" AND "))
891    };
892
893    Ok((where_sql, values))
894}
895
896/// Convert a GraphQL type string to a TypeRef.
897fn graphql_type_ref(type_str: &str) -> TypeRef {
898    // Parse type string like "[Users!]!" or "String" or "Int!"
899    let is_list = type_str.starts_with('[');
900    let is_nn = type_str.ends_with('!');
901
902    // Strip outer modifiers: first the trailing !, then the brackets
903    let inner = if is_list {
904        let stripped = type_str
905            .trim_end_matches('!')  // Remove outer !
906            .trim_start_matches('[')  // Remove [
907            .trim_end_matches(']');   // Remove ]
908        stripped
909    } else {
910        type_str.trim_end_matches('!')
911    };
912
913    let inner_nn = inner.ends_with('!');
914    let base_type = inner.trim_end_matches('!');
915
916    if is_list {
917        if is_nn {
918            if inner_nn {
919                TypeRef::named_nn_list_nn(base_type)
920            } else {
921                TypeRef::named_list_nn(base_type)
922            }
923        } else if inner_nn {
924            TypeRef::named_nn_list(base_type)
925        } else {
926            TypeRef::named_list(base_type)
927        }
928    } else if is_nn {
929        TypeRef::named_nn(base_type)
930    } else {
931        TypeRef::named(base_type)
932    }
933}
934
935/// Convert ValueAccessor to JSON.
936fn accessor_to_json(accessor: &ValueAccessor<'_>) -> serde_json::Value {
937    // Use the deserialize method if available, or convert manually
938    if accessor.is_null() {
939        serde_json::Value::Null
940    } else if let Ok(b) = accessor.boolean() {
941        serde_json::Value::Bool(b)
942    } else if let Ok(i) = accessor.i64() {
943        serde_json::Value::Number(i.into())
944    } else if let Ok(f) = accessor.f64() {
945        serde_json::Number::from_f64(f)
946            .map(serde_json::Value::Number)
947            .unwrap_or(serde_json::Value::Null)
948    } else if let Ok(s) = accessor.string() {
949        serde_json::Value::String(s.to_string())
950    } else if let Ok(list) = accessor.list() {
951        serde_json::Value::Array(
952            list.iter()
953                .map(|v| accessor_to_json(&v))
954                .collect()
955        )
956    } else if let Ok(obj) = accessor.object() {
957        let map: serde_json::Map<String, serde_json::Value> = obj
958            .iter()
959            .map(|(k, v)| (k.to_string(), accessor_to_json(&v)))
960            .collect();
961        serde_json::Value::Object(map)
962    } else {
963        serde_json::Value::Null
964    }
965}
966
967/// Convert async-graphql Value to JSON.
968#[allow(dead_code)]
969fn value_to_json(value: &Value) -> serde_json::Value {
970    match value {
971        Value::Null => serde_json::Value::Null,
972        Value::Boolean(b) => serde_json::Value::Bool(*b),
973        Value::Number(n) => {
974            if let Some(i) = n.as_i64() {
975                serde_json::Value::Number(i.into())
976            } else if let Some(f) = n.as_f64() {
977                serde_json::Value::Number(serde_json::Number::from_f64(f).unwrap())
978            } else {
979                serde_json::Value::Null
980            }
981        }
982        Value::String(s) => serde_json::Value::String(s.clone()),
983        Value::List(arr) => {
984            serde_json::Value::Array(arr.iter().map(value_to_json).collect())
985        }
986        Value::Object(obj) => {
987            let map: serde_json::Map<String, serde_json::Value> = obj
988                .iter()
989                .map(|(k, v)| (k.to_string(), value_to_json(v)))
990                .collect();
991            serde_json::Value::Object(map)
992        }
993        Value::Binary(b) => serde_json::Value::String(base64::Engine::encode(
994            &base64::engine::general_purpose::STANDARD,
995            b,
996        )),
997        Value::Enum(e) => serde_json::Value::String(e.to_string()),
998    }
999}
1000
1001/// Convert JSON to async-graphql Value.
1002fn json_to_value(json: serde_json::Value) -> Value {
1003    match json {
1004        serde_json::Value::Null => Value::Null,
1005        serde_json::Value::Bool(b) => Value::Boolean(b),
1006        serde_json::Value::Number(n) => {
1007            if let Some(i) = n.as_i64() {
1008                Value::Number(i.into())
1009            } else if let Some(f) = n.as_f64() {
1010                Value::Number(async_graphql::Number::from_f64(f).unwrap())
1011            } else {
1012                Value::Null
1013            }
1014        }
1015        serde_json::Value::String(s) => Value::String(s),
1016        serde_json::Value::Array(arr) => {
1017            Value::List(arr.into_iter().map(json_to_value).collect())
1018        }
1019        serde_json::Value::Object(obj) => {
1020            let map: indexmap::IndexMap<async_graphql::Name, Value> = obj
1021                .into_iter()
1022                .map(|(k, v)| (async_graphql::Name::new(k), json_to_value(v)))
1023                .collect();
1024            Value::Object(map)
1025        }
1026    }
1027}
1028
1029/// Create BigInt scalar type.
1030fn create_bigint_scalar() -> Scalar {
1031    Scalar::new("BigInt")
1032        .description("64-bit integer")
1033        .specified_by_url("https://spec.graphql.org/draft/#sec-Int")
1034}
1035
1036/// Create BigDecimal scalar type.
1037fn create_bigdecimal_scalar() -> Scalar {
1038    Scalar::new("BigDecimal")
1039        .description("Arbitrary precision decimal number")
1040}
1041
1042/// Create JSON scalar type.
1043fn create_json_scalar() -> Scalar {
1044    Scalar::new("JSON")
1045        .description("Arbitrary JSON value")
1046        .specified_by_url("https://spec.graphql.org/draft/#sec-Scalars")
1047}
1048
1049/// Create UUID scalar type.
1050fn create_uuid_scalar() -> Scalar {
1051    Scalar::new("UUID").description("UUID string")
1052}
1053
1054/// Create Date scalar type.
1055fn create_date_scalar() -> Scalar {
1056    Scalar::new("Date").description("ISO 8601 date string (YYYY-MM-DD)")
1057}
1058
1059/// Create DateTime scalar type.
1060fn create_datetime_scalar() -> Scalar {
1061    Scalar::new("DateTime").description("ISO 8601 datetime string")
1062}
1063
1064/// Create Time scalar type.
1065fn create_time_scalar() -> Scalar {
1066    Scalar::new("Time").description("ISO 8601 time string (HH:MM:SS)")
1067}
1068
1069/// Register filter input types.
1070fn register_filter_input_types(builder: SchemaBuilder) -> SchemaBuilder {
1071    let string_filter = InputObject::new("StringFilterInput")
1072        .field(InputValue::new("eq", TypeRef::named("String")))
1073        .field(InputValue::new("neq", TypeRef::named("String")))
1074        .field(InputValue::new("like", TypeRef::named("String")))
1075        .field(InputValue::new("ilike", TypeRef::named("String")))
1076        .field(InputValue::new("in", TypeRef::named_list("String")))
1077        .field(InputValue::new("isNull", TypeRef::named("Boolean")));
1078
1079    let int_filter = InputObject::new("IntFilterInput")
1080        .field(InputValue::new("eq", TypeRef::named("Int")))
1081        .field(InputValue::new("neq", TypeRef::named("Int")))
1082        .field(InputValue::new("gt", TypeRef::named("Int")))
1083        .field(InputValue::new("gte", TypeRef::named("Int")))
1084        .field(InputValue::new("lt", TypeRef::named("Int")))
1085        .field(InputValue::new("lte", TypeRef::named("Int")))
1086        .field(InputValue::new("in", TypeRef::named_list("Int")));
1087
1088    let boolean_filter = InputObject::new("BooleanFilterInput")
1089        .field(InputValue::new("eq", TypeRef::named("Boolean")));
1090
1091    builder
1092        .register(string_filter)
1093        .register(int_filter)
1094        .register(boolean_filter)
1095}
1096
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use postrust_core::schema_cache::{Column, Table};
    use std::collections::{HashMap, HashSet};

    /// Build a `public.<name>` table with an integer `id` primary key (defaulting to
    /// `nextval('id_seq')`) and a non-null `name` text column.
    fn create_test_table(name: &str) -> Table {
        let mut columns = IndexMap::new();
        columns.insert(
            "id".into(),
            Column {
                name: "id".into(),
                description: None,
                nullable: false,
                data_type: "integer".into(),
                nominal_type: "int4".into(),
                max_len: None,
                default: Some("nextval('id_seq')".into()),
                enum_values: vec![],
                is_pk: true,
                position: 1,
            },
        );
        columns.insert(
            "name".into(),
            Column {
                name: "name".into(),
                description: None,
                nullable: false,
                data_type: "text".into(),
                nominal_type: "text".into(),
                max_len: None,
                default: None,
                enum_values: vec![],
                is_pk: false,
                position: 2,
            },
        );

        Table {
            schema: "public".into(),
            name: name.into(),
            description: None,
            is_view: false,
            insertable: true,
            updatable: true,
            deletable: true,
            pk_cols: vec!["id".into()],
            columns,
        }
    }

    /// Build a schema cache containing only the `public.users` test table.
    fn create_test_schema_cache() -> SchemaCache {
        let mut tables = HashMap::new();
        let users = create_test_table("users");
        tables.insert(users.qualified_identifier(), users);

        SchemaCache {
            tables,
            relationships: HashMap::new(),
            routines: HashMap::new(),
            timezones: HashSet::new(),
            pg_version: 150000,
        }
    }

    // ============================================================================
    // Type Reference Tests
    // ============================================================================
    // The parsed `TypeRef` is compared through its `Display` (SDL) rendering, which
    // should reproduce the input type expression.

    #[test]
    fn test_graphql_type_ref_simple() {
        assert_eq!(graphql_type_ref("String").to_string(), "String");
    }

    #[test]
    fn test_graphql_type_ref_non_null() {
        assert_eq!(graphql_type_ref("String!").to_string(), "String!");
    }

    #[test]
    fn test_graphql_type_ref_list() {
        assert_eq!(graphql_type_ref("[String]").to_string(), "[String]");
    }

    #[test]
    fn test_graphql_type_ref_list_non_null() {
        assert_eq!(graphql_type_ref("[String!]!").to_string(), "[String!]!");
    }

    #[test]
    fn test_graphql_type_ref_list_outer_non_null_only() {
        assert_eq!(graphql_type_ref("[String]!").to_string(), "[String]!");
    }

    #[test]
    fn test_graphql_type_ref_list_of_non_null_elements() {
        assert_eq!(graphql_type_ref("[String!]").to_string(), "[String!]");
    }

    // ============================================================================
    // Value Conversion Tests
    // ============================================================================

    #[test]
    fn test_value_to_json_null() {
        let value = Value::Null;
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::Value::Null);
    }

    #[test]
    fn test_value_to_json_boolean() {
        let value = Value::Boolean(true);
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_value_to_json_number() {
        let value = Value::Number(42.into());
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::json!(42));
    }

    #[test]
    fn test_value_to_json_string() {
        let value = Value::String("hello".to_string());
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::Value::String("hello".to_string()));
    }

    #[test]
    fn test_value_to_json_list() {
        let value = Value::List(vec![Value::Number(1.into()), Value::Number(2.into())]);
        let json = value_to_json(&value);
        assert_eq!(json, serde_json::json!([1, 2]));
    }

    #[test]
    fn test_json_to_value_null() {
        let json = serde_json::Value::Null;
        let value = json_to_value(json);
        assert!(matches!(value, Value::Null));
    }

    #[test]
    fn test_json_to_value_boolean() {
        let json = serde_json::Value::Bool(false);
        let value = json_to_value(json);
        assert!(matches!(value, Value::Boolean(false)));
    }

    #[test]
    fn test_json_to_value_number() {
        let json = serde_json::json!(123);
        let value = json_to_value(json);
        assert!(matches!(value, Value::Number(_)));
    }

    #[test]
    fn test_json_to_value_string() {
        let json = serde_json::Value::String("test".to_string());
        let value = json_to_value(json);
        assert!(matches!(value, Value::String(_)));
    }

    #[test]
    fn test_json_to_value_array() {
        let json = serde_json::json!([1, 2, 3]);
        let value = json_to_value(json);
        assert!(matches!(value, Value::List(_)));
    }

    #[test]
    fn test_json_to_value_object() {
        let json = serde_json::json!({"key": "value"});
        let value = json_to_value(json);
        assert!(matches!(value, Value::Object(_)));
    }

    #[test]
    fn test_json_value_roundtrip() {
        let json = serde_json::json!({
            "int": -7,
            "float": 2.5,
            "text": "s",
            "nothing": null,
            "flag": true,
            "nested": {"list": [1, "two", [3]]}
        });
        assert_eq!(value_to_json(&json_to_value(json.clone())), json);
    }

    // ============================================================================
    // Schema Building Tests
    // ============================================================================

    #[test]
    fn test_build_dynamic_schema() {
        let cache = create_test_schema_cache();
        let config = SchemaConfig::default();
        let generated = build_schema(&cache, &config);

        let result = build_dynamic_schema(&generated, &cache, None);
        assert!(result.is_ok(), "Schema build failed: {:?}", result.err());
    }

    #[test]
    fn test_create_object_type() {
        let table = create_test_table("users");
        let obj = TableObjectType::from_table(&table);
        let _gql_obj = create_object_type(&obj);
    }

    #[test]
    fn test_create_query_type() {
        let cache = create_test_schema_cache();
        let config = SchemaConfig::default();
        let generated = build_schema(&cache, &config);

        let _query = create_query_type(&generated);
    }

    #[test]
    fn test_create_mutation_type() {
        let cache = create_test_schema_cache();
        let config = SchemaConfig::default();
        let generated = build_schema(&cache, &config);

        let _mutation = create_mutation_type(&generated);
    }

    // ============================================================================
    // Scalar Tests
    // ============================================================================

    #[test]
    fn test_create_scalars() {
        // Smoke test: every scalar constructor builds without panicking.
        let _bigint = create_bigint_scalar();
        let _bigdecimal = create_bigdecimal_scalar();
        let _json = create_json_scalar();
        let _uuid = create_uuid_scalar();
        let _date = create_date_scalar();
        let _datetime = create_datetime_scalar();
        let _time = create_time_scalar();
    }

    // ============================================================================
    // Filter Input Type Tests
    // ============================================================================

    #[test]
    fn test_register_filter_input_types() {
        let cache = create_test_schema_cache();
        let config = SchemaConfig::default();
        let _generated = build_schema(&cache, &config);

        // Build a minimal schema with filter types
        let query = Object::new("Query").field(Field::new(
            "test",
            TypeRef::named("String"),
            |_| FieldFuture::new(async { Ok(None::<FieldValue>) }),
        ));

        let mut builder = Schema::build("Query", None::<&str>, None);
        builder = builder.register(query);
        builder = register_filter_input_types(builder);

        let result = builder.finish();
        assert!(result.is_ok());
    }

    // ============================================================================
    // Subscription Tests
    // ============================================================================

    #[test]
    fn test_build_schema_with_subscriptions() {
        let cache = create_test_schema_cache();
        let config = SchemaConfig {
            enable_subscriptions: true,
            ..SchemaConfig::default()
        };
        let generated = build_schema(&cache, &config);

        // Generate subscription fields
        let sub_fields = generate_subscription_fields(&cache, &generated);
        assert!(!sub_fields.is_empty(), "Should have subscription fields");

        // Build schema with subscriptions
        let result = build_dynamic_schema(&generated, &cache, Some(&sub_fields));
        assert!(result.is_ok(), "Schema with subscriptions should build");
    }

    #[test]
    fn test_subscription_field_generation() {
        let cache = create_test_schema_cache();
        let config = SchemaConfig::default();
        let generated = build_schema(&cache, &config);

        let fields = generate_subscription_fields(&cache, &generated);

        // Should have one subscription field for the users table
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].name, "users");
        assert_eq!(fields[0].table_name, "users");
        assert_eq!(fields[0].channel_name(), "postrust_public_users");
    }

    #[test]
    fn test_create_subscription_type() {
        use crate::subscription::SubscriptionField as SubField;

        let fields = vec![
            SubField::for_table("public", "users", "Users"),
            SubField::for_table("public", "orders", "Orders"),
        ];

        // Smoke test: building the subscription type does not panic.
        let _subscription = create_subscription_type(&fields);
    }
}