1use crate::context::GraphQLContext;
7use crate::error::GraphQLError;
8use crate::schema::object::TableObjectType;
9use crate::schema::{build_schema, GeneratedSchema, MutationType, SchemaConfig};
10use async_graphql::dynamic::*;
11use async_graphql::Value;
12use async_graphql_axum::{GraphQLRequest, GraphQLResponse};
13use axum::extract::State;
14use postrust_core::schema_cache::SchemaCache;
15use sqlx::PgPool;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tracing::debug;
19
20pub struct GraphQLState {
22 pub pool: PgPool,
24 pub schema_cache: Arc<SchemaCache>,
26 pub generated_schema: GeneratedSchema,
28 pub schema: Schema,
30 pub config: SchemaConfig,
32}
33
34impl GraphQLState {
35 pub fn new(
37 pool: PgPool,
38 schema_cache: Arc<SchemaCache>,
39 config: SchemaConfig,
40 ) -> Result<Self, GraphQLError> {
41 let generated_schema = build_schema(&schema_cache, &config);
42 let schema = build_dynamic_schema(&generated_schema, &schema_cache)?;
43
44 Ok(Self {
45 pool,
46 schema_cache,
47 generated_schema,
48 schema,
49 config,
50 })
51 }
52
53 pub fn rebuild(&mut self) -> Result<(), GraphQLError> {
55 self.generated_schema = build_schema(&self.schema_cache, &self.config);
56 self.schema = build_dynamic_schema(&self.generated_schema, &self.schema_cache)?;
57 Ok(())
58 }
59}
60
61pub async fn graphql_handler(
63 State(state): State<Arc<GraphQLState>>,
64 ctx: GraphQLContext,
65 req: GraphQLRequest,
66) -> GraphQLResponse {
67 let request = req.into_inner().data(ctx).data(state.pool.clone());
68 state.schema.execute(request).await.into()
69}
70
71pub async fn graphql_playground() -> impl axum::response::IntoResponse {
73 axum::response::Html(async_graphql::http::playground_source(
74 async_graphql::http::GraphQLPlaygroundConfig::new("/graphql"),
75 ))
76}
77
78fn build_dynamic_schema(
80 generated: &GeneratedSchema,
81 _schema_cache: &SchemaCache,
82) -> Result<Schema, GraphQLError> {
83 let mut object_types: HashMap<String, Object> = HashMap::new();
85
86 for (type_name, obj) in &generated.object_types {
87 let table_obj = create_object_type(obj);
88 object_types.insert(type_name.clone(), table_obj);
89 }
90
91 let query = create_query_type(generated);
93
94 let mutation = if !generated.mutation_fields.is_empty() {
96 Some(create_mutation_type(generated))
97 } else {
98 None
99 };
100
101 let mut builder = Schema::build("Query", mutation.as_ref().map(|_| "Mutation"), None);
103
104 for (_, obj) in object_types {
106 builder = builder.register(obj);
107 }
108
109 builder = builder.register(query);
111
112 if let Some(mutation) = mutation {
114 builder = builder.register(mutation);
115 }
116
117 builder = builder.register(create_bigint_scalar());
119 builder = builder.register(create_bigdecimal_scalar());
120 builder = builder.register(create_json_scalar());
121 builder = builder.register(create_uuid_scalar());
122 builder = builder.register(create_date_scalar());
123 builder = builder.register(create_datetime_scalar());
124 builder = builder.register(create_time_scalar());
125
126 builder = register_filter_input_types(builder);
128
129 builder
130 .finish()
131 .map_err(|e| GraphQLError::SchemaError(e.to_string()))
132}
133
134fn create_object_type(obj: &TableObjectType) -> Object {
136 let mut object = Object::new(&obj.name);
137
138 if let Some(desc) = obj.description() {
139 object = object.description(desc);
140 }
141
142 for field in &obj.fields {
143 let field_type = graphql_type_ref(&field.type_string());
144 let mut gql_field = Field::new(&field.name, field_type, |_| {
145 FieldFuture::new(async move { Ok(None::<FieldValue>) })
146 });
147
148 if let Some(desc) = &field.description {
149 gql_field = gql_field.description(desc);
150 }
151
152 object = object.field(gql_field);
153 }
154
155 object
156}
157
158fn create_query_type(generated: &GeneratedSchema) -> Object {
160 let mut query = Object::new("Query");
161
162 for field in &generated.query_fields {
163 let table_name = field.table_name.clone();
164 let is_by_pk = field.is_by_pk;
165 let return_type = graphql_type_ref(&field.return_type);
166
167 let mut gql_field = Field::new(&field.name, return_type, move |ctx| {
168 let table_name = table_name.clone();
169 FieldFuture::new(async move {
170 resolve_query(&ctx, &table_name, is_by_pk).await
171 })
172 });
173
174 if !is_by_pk {
176 gql_field = gql_field
177 .argument(InputValue::new("filter", TypeRef::named("JSON")))
178 .argument(InputValue::new("orderBy", TypeRef::named_list("String")))
179 .argument(InputValue::new("limit", TypeRef::named("Int")))
180 .argument(InputValue::new("offset", TypeRef::named("Int")));
181 } else {
182 gql_field = gql_field.argument(InputValue::new("id", TypeRef::named_nn("Int")));
184 }
185
186 if let Some(desc) = &field.description {
187 gql_field = gql_field.description(desc);
188 }
189
190 query = query.field(gql_field);
191 }
192
193 query = query.field(
195 Field::new("_schema", TypeRef::named("String"), |_| {
196 FieldFuture::new(async move {
197 Ok(Some(Value::String("Postrust GraphQL Schema".to_string())))
198 })
199 })
200 .description("Schema introspection"),
201 );
202
203 query
204}
205
206fn create_mutation_type(generated: &GeneratedSchema) -> Object {
208 let mut mutation = Object::new("Mutation");
209
210 for field in &generated.mutation_fields {
211 let table_name = field.table_name.clone();
212 let mutation_type = field.mutation_type;
213 let return_type = graphql_type_ref(&field.return_type);
214
215 let mut gql_field = Field::new(&field.name, return_type, move |ctx| {
216 let table_name = table_name.clone();
217 FieldFuture::new(async move {
218 resolve_mutation(&ctx, &table_name, mutation_type).await
219 })
220 });
221
222 match mutation_type {
224 MutationType::Insert | MutationType::InsertOne => {
225 gql_field = gql_field
226 .argument(InputValue::new("objects", TypeRef::named_nn_list("JSON")));
227 }
228 MutationType::Update | MutationType::UpdateByPk => {
229 gql_field = gql_field
230 .argument(InputValue::new("where", TypeRef::named("JSON")))
231 .argument(InputValue::new("set", TypeRef::named_nn("JSON")));
232 }
233 MutationType::Delete | MutationType::DeleteByPk => {
234 gql_field = gql_field.argument(InputValue::new("where", TypeRef::named("JSON")));
235 }
236 }
237
238 if let Some(desc) = &field.description {
239 gql_field = gql_field.description(desc);
240 }
241
242 mutation = mutation.field(gql_field);
243 }
244
245 mutation
246}
247
248async fn resolve_query(
250 ctx: &ResolverContext<'_>,
251 table_name: &str,
252 is_by_pk: bool,
253) -> Result<Option<Value>, async_graphql::Error> {
254 let pool = ctx.data::<PgPool>()?;
255 let gql_ctx = ctx.data::<GraphQLContext>()?;
256
257 debug!("Resolving query for table: {}", table_name);
258
259 let limit: Option<i64> = ctx
261 .args
262 .try_get("limit")
263 .ok()
264 .and_then(|v| v.i64().ok());
265
266 let offset: Option<i64> = ctx
267 .args
268 .try_get("offset")
269 .ok()
270 .and_then(|v| v.i64().ok());
271
272 let mut sql = format!(
274 "SELECT row_to_json(t) FROM (SELECT * FROM public.{}) t",
275 table_name
276 );
277
278 if let Some(limit) = limit {
279 sql.push_str(&format!(" LIMIT {}", limit));
280 }
281
282 if let Some(offset) = offset {
283 sql.push_str(&format!(" OFFSET {}", offset));
284 }
285
286 let result = execute_query(pool, &sql, gql_ctx.role()).await?;
288
289 if is_by_pk {
290 Ok(result.first().cloned())
291 } else {
292 Ok(Some(Value::List(result)))
293 }
294}
295
296async fn resolve_mutation(
298 ctx: &ResolverContext<'_>,
299 table_name: &str,
300 mutation_type: MutationType,
301) -> Result<Option<Value>, async_graphql::Error> {
302 let pool = ctx.data::<PgPool>()?;
303 let gql_ctx = ctx.data::<GraphQLContext>()?;
304
305 debug!("Resolving mutation for table: {} type: {:?}", table_name, mutation_type);
306
307 let result = match mutation_type {
308 MutationType::Insert | MutationType::InsertOne => {
309 let objects = ctx
310 .args
311 .try_get("objects")
312 .ok()
313 .map(|v| accessor_to_json(&v))
314 .unwrap_or_else(|| serde_json::Value::Array(vec![]));
315
316 execute_insert(pool, table_name, gql_ctx.role(), objects).await?
317 }
318 MutationType::Update | MutationType::UpdateByPk => {
319 let set_value = ctx
320 .args
321 .try_get("set")
322 .ok()
323 .map(|v| accessor_to_json(&v))
324 .unwrap_or_else(|| serde_json::json!({}));
325
326 execute_update(pool, table_name, gql_ctx.role(), set_value).await?
327 }
328 MutationType::Delete | MutationType::DeleteByPk => {
329 execute_delete(pool, table_name, gql_ctx.role()).await?
330 }
331 };
332
333 Ok(Some(result))
334}
335
336async fn execute_query(
338 pool: &PgPool,
339 sql: &str,
340 role: &str,
341) -> Result<Vec<Value>, async_graphql::Error> {
342 use sqlx::Row;
343
344 debug!("Executing SQL: {}", sql);
345
346 let mut conn = pool.acquire().await?;
347
348 sqlx::query(&format!("SET LOCAL ROLE {}", postrust_sql::escape_ident(role)))
350 .execute(&mut *conn)
351 .await?;
352
353 let rows = sqlx::query(sql).fetch_all(&mut *conn).await?;
355
356 let results: Vec<Value> = rows
358 .iter()
359 .filter_map(|row| {
360 row.try_get::<serde_json::Value, _>(0)
361 .ok()
362 .map(json_to_value)
363 })
364 .collect();
365
366 Ok(results)
367}
368
369async fn execute_insert(
371 _pool: &PgPool,
372 table_name: &str,
373 _role: &str,
374 objects: serde_json::Value,
375) -> Result<Value, async_graphql::Error> {
376 debug!("Insert mutation for {}: {:?}", table_name, objects);
378 Ok(Value::List(vec![]))
379}
380
381async fn execute_update(
383 _pool: &PgPool,
384 table_name: &str,
385 _role: &str,
386 set_value: serde_json::Value,
387) -> Result<Value, async_graphql::Error> {
388 debug!("Update mutation for {}: {:?}", table_name, set_value);
390 Ok(Value::List(vec![]))
391}
392
393async fn execute_delete(
395 _pool: &PgPool,
396 table_name: &str,
397 _role: &str,
398) -> Result<Value, async_graphql::Error> {
399 debug!("Delete mutation for {}", table_name);
401 Ok(Value::List(vec![]))
402}
403
404fn graphql_type_ref(type_str: &str) -> TypeRef {
406 let is_list = type_str.starts_with('[');
408 let is_nn = type_str.ends_with('!');
409
410 let inner = if is_list {
412 let stripped = type_str
413 .trim_end_matches('!') .trim_start_matches('[') .trim_end_matches(']'); stripped
417 } else {
418 type_str.trim_end_matches('!')
419 };
420
421 let inner_nn = inner.ends_with('!');
422 let base_type = inner.trim_end_matches('!');
423
424 if is_list {
425 if is_nn {
426 if inner_nn {
427 TypeRef::named_nn_list_nn(base_type)
428 } else {
429 TypeRef::named_list_nn(base_type)
430 }
431 } else if inner_nn {
432 TypeRef::named_nn_list(base_type)
433 } else {
434 TypeRef::named_list(base_type)
435 }
436 } else if is_nn {
437 TypeRef::named_nn(base_type)
438 } else {
439 TypeRef::named(base_type)
440 }
441}
442
443fn accessor_to_json(accessor: &ValueAccessor<'_>) -> serde_json::Value {
445 if accessor.is_null() {
447 serde_json::Value::Null
448 } else if let Ok(b) = accessor.boolean() {
449 serde_json::Value::Bool(b)
450 } else if let Ok(i) = accessor.i64() {
451 serde_json::Value::Number(i.into())
452 } else if let Ok(f) = accessor.f64() {
453 serde_json::Number::from_f64(f)
454 .map(serde_json::Value::Number)
455 .unwrap_or(serde_json::Value::Null)
456 } else if let Ok(s) = accessor.string() {
457 serde_json::Value::String(s.to_string())
458 } else if let Ok(list) = accessor.list() {
459 serde_json::Value::Array(
460 list.iter()
461 .map(|v| accessor_to_json(&v))
462 .collect()
463 )
464 } else if let Ok(obj) = accessor.object() {
465 let map: serde_json::Map<String, serde_json::Value> = obj
466 .iter()
467 .map(|(k, v)| (k.to_string(), accessor_to_json(&v)))
468 .collect();
469 serde_json::Value::Object(map)
470 } else {
471 serde_json::Value::Null
472 }
473}
474
475fn value_to_json(value: &Value) -> serde_json::Value {
477 match value {
478 Value::Null => serde_json::Value::Null,
479 Value::Boolean(b) => serde_json::Value::Bool(*b),
480 Value::Number(n) => {
481 if let Some(i) = n.as_i64() {
482 serde_json::Value::Number(i.into())
483 } else if let Some(f) = n.as_f64() {
484 serde_json::Value::Number(serde_json::Number::from_f64(f).unwrap())
485 } else {
486 serde_json::Value::Null
487 }
488 }
489 Value::String(s) => serde_json::Value::String(s.clone()),
490 Value::List(arr) => {
491 serde_json::Value::Array(arr.iter().map(value_to_json).collect())
492 }
493 Value::Object(obj) => {
494 let map: serde_json::Map<String, serde_json::Value> = obj
495 .iter()
496 .map(|(k, v)| (k.to_string(), value_to_json(v)))
497 .collect();
498 serde_json::Value::Object(map)
499 }
500 Value::Binary(b) => serde_json::Value::String(base64::Engine::encode(
501 &base64::engine::general_purpose::STANDARD,
502 b,
503 )),
504 Value::Enum(e) => serde_json::Value::String(e.to_string()),
505 }
506}
507
508fn json_to_value(json: serde_json::Value) -> Value {
510 match json {
511 serde_json::Value::Null => Value::Null,
512 serde_json::Value::Bool(b) => Value::Boolean(b),
513 serde_json::Value::Number(n) => {
514 if let Some(i) = n.as_i64() {
515 Value::Number(i.into())
516 } else if let Some(f) = n.as_f64() {
517 Value::Number(async_graphql::Number::from_f64(f).unwrap())
518 } else {
519 Value::Null
520 }
521 }
522 serde_json::Value::String(s) => Value::String(s),
523 serde_json::Value::Array(arr) => {
524 Value::List(arr.into_iter().map(json_to_value).collect())
525 }
526 serde_json::Value::Object(obj) => {
527 let map: indexmap::IndexMap<async_graphql::Name, Value> = obj
528 .into_iter()
529 .map(|(k, v)| (async_graphql::Name::new(k), json_to_value(v)))
530 .collect();
531 Value::Object(map)
532 }
533 }
534}
535
536fn create_bigint_scalar() -> Scalar {
538 Scalar::new("BigInt")
539 .description("64-bit integer")
540 .specified_by_url("https://spec.graphql.org/draft/#sec-Int")
541}
542
543fn create_bigdecimal_scalar() -> Scalar {
545 Scalar::new("BigDecimal")
546 .description("Arbitrary precision decimal number")
547}
548
549fn create_json_scalar() -> Scalar {
551 Scalar::new("JSON")
552 .description("Arbitrary JSON value")
553 .specified_by_url("https://spec.graphql.org/draft/#sec-Scalars")
554}
555
556fn create_uuid_scalar() -> Scalar {
558 Scalar::new("UUID").description("UUID string")
559}
560
561fn create_date_scalar() -> Scalar {
563 Scalar::new("Date").description("ISO 8601 date string (YYYY-MM-DD)")
564}
565
566fn create_datetime_scalar() -> Scalar {
568 Scalar::new("DateTime").description("ISO 8601 datetime string")
569}
570
571fn create_time_scalar() -> Scalar {
573 Scalar::new("Time").description("ISO 8601 time string (HH:MM:SS)")
574}
575
576fn register_filter_input_types(builder: SchemaBuilder) -> SchemaBuilder {
578 let string_filter = InputObject::new("StringFilterInput")
579 .field(InputValue::new("eq", TypeRef::named("String")))
580 .field(InputValue::new("neq", TypeRef::named("String")))
581 .field(InputValue::new("like", TypeRef::named("String")))
582 .field(InputValue::new("ilike", TypeRef::named("String")))
583 .field(InputValue::new("in", TypeRef::named_list("String")))
584 .field(InputValue::new("isNull", TypeRef::named("Boolean")));
585
586 let int_filter = InputObject::new("IntFilterInput")
587 .field(InputValue::new("eq", TypeRef::named("Int")))
588 .field(InputValue::new("neq", TypeRef::named("Int")))
589 .field(InputValue::new("gt", TypeRef::named("Int")))
590 .field(InputValue::new("gte", TypeRef::named("Int")))
591 .field(InputValue::new("lt", TypeRef::named("Int")))
592 .field(InputValue::new("lte", TypeRef::named("Int")))
593 .field(InputValue::new("in", TypeRef::named_list("Int")));
594
595 let boolean_filter = InputObject::new("BooleanFilterInput")
596 .field(InputValue::new("eq", TypeRef::named("Boolean")));
597
598 builder
599 .register(string_filter)
600 .register(int_filter)
601 .register(boolean_filter)
602}
603
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use postrust_core::schema_cache::{Column, Table};
    use std::collections::{HashMap, HashSet};

    /// Two-column fixture table: `id` (integer PK with a sequence default) and `name` (text).
    fn create_test_table(name: &str) -> Table {
        let mut columns = IndexMap::new();
        columns.insert(
            "id".into(),
            Column {
                name: "id".into(),
                description: None,
                nullable: false,
                data_type: "integer".into(),
                nominal_type: "int4".into(),
                max_len: None,
                default: Some("nextval('id_seq')".into()),
                enum_values: vec![],
                is_pk: true,
                position: 1,
            },
        );
        columns.insert(
            "name".into(),
            Column {
                name: "name".into(),
                description: None,
                nullable: false,
                data_type: "text".into(),
                nominal_type: "text".into(),
                max_len: None,
                default: None,
                enum_values: vec![],
                is_pk: false,
                position: 2,
            },
        );

        Table {
            schema: "public".into(),
            name: name.into(),
            description: None,
            is_view: false,
            insertable: true,
            updatable: true,
            deletable: true,
            pk_cols: vec!["id".into()],
            columns,
        }
    }

    /// Schema cache containing only the `users` fixture table.
    fn create_test_schema_cache() -> SchemaCache {
        let mut tables = HashMap::new();
        let users = create_test_table("users");
        tables.insert(users.qualified_identifier(), users);

        SchemaCache {
            tables,
            relationships: HashMap::new(),
            routines: HashMap::new(),
            timezones: HashSet::new(),
            pg_version: 150000,
        }
    }

    // `TypeRef`'s `Display` renders SDL notation. Comparing against the input
    // string pins both the nullability and the list shape of the parsed type.

    #[test]
    fn test_graphql_type_ref_simple() {
        assert_eq!(graphql_type_ref("String").to_string(), "String");
    }

    #[test]
    fn test_graphql_type_ref_non_null() {
        assert_eq!(graphql_type_ref("String!").to_string(), "String!");
    }

    #[test]
    fn test_graphql_type_ref_list() {
        assert_eq!(graphql_type_ref("[String]").to_string(), "[String]");
    }

    #[test]
    fn test_graphql_type_ref_list_non_null() {
        assert_eq!(graphql_type_ref("[String!]!").to_string(), "[String!]!");
    }

    #[test]
    fn test_graphql_type_ref_partial_nullability() {
        assert_eq!(graphql_type_ref("[Int!]").to_string(), "[Int!]");
        assert_eq!(graphql_type_ref("[Int]!").to_string(), "[Int]!");
    }

    #[test]
    fn test_value_to_json_null() {
        assert_eq!(value_to_json(&Value::Null), serde_json::Value::Null);
    }

    #[test]
    fn test_value_to_json_boolean() {
        assert_eq!(value_to_json(&Value::Boolean(true)), serde_json::Value::Bool(true));
    }

    #[test]
    fn test_value_to_json_number() {
        assert_eq!(value_to_json(&Value::Number(42.into())), serde_json::json!(42));
    }

    #[test]
    fn test_value_to_json_string() {
        let value = Value::String("hello".to_string());
        assert_eq!(value_to_json(&value), serde_json::Value::String("hello".to_string()));
    }

    #[test]
    fn test_value_to_json_list() {
        let value = Value::List(vec![Value::Number(1.into()), Value::Number(2.into())]);
        assert_eq!(value_to_json(&value), serde_json::json!([1, 2]));
    }

    #[test]
    fn test_json_to_value_null() {
        assert!(matches!(json_to_value(serde_json::Value::Null), Value::Null));
    }

    #[test]
    fn test_json_to_value_boolean() {
        assert!(matches!(
            json_to_value(serde_json::Value::Bool(false)),
            Value::Boolean(false)
        ));
    }

    #[test]
    fn test_json_to_value_number() {
        assert_eq!(json_to_value(serde_json::json!(123)), Value::Number(123.into()));
    }

    #[test]
    fn test_json_to_value_string() {
        assert_eq!(
            json_to_value(serde_json::Value::String("test".to_string())),
            Value::String("test".to_string())
        );
    }

    #[test]
    fn test_json_to_value_array() {
        let json = serde_json::json!([1, 2, 3]);
        let value = json_to_value(json.clone());
        assert!(matches!(value, Value::List(ref items) if items.len() == 3));
        assert_eq!(value_to_json(&value), json);
    }

    #[test]
    fn test_json_to_value_object() {
        let json = serde_json::json!({"key": "value"});
        let value = json_to_value(json.clone());
        assert!(matches!(value, Value::Object(_)));
        assert_eq!(value_to_json(&value), json);
    }

    #[test]
    fn test_json_value_round_trip_nested() {
        let json = serde_json::json!({"n": 1, "s": "x", "list": [true, null, {"k": -5}]});
        assert_eq!(value_to_json(&json_to_value(json.clone())), json);
    }

    #[test]
    fn test_build_dynamic_schema() {
        let cache = create_test_schema_cache();
        let generated = build_schema(&cache, &SchemaConfig::default());

        let result = build_dynamic_schema(&generated, &cache);
        assert!(result.is_ok(), "Schema build failed: {:?}", result.err());
    }

    #[test]
    fn test_create_object_type() {
        let table = create_test_table("users");
        let obj = TableObjectType::from_table(&table);
        let _gql_obj = create_object_type(&obj);
    }

    #[test]
    fn test_create_query_type() {
        let cache = create_test_schema_cache();
        let generated = build_schema(&cache, &SchemaConfig::default());
        let _query = create_query_type(&generated);
    }

    #[test]
    fn test_create_mutation_type() {
        let cache = create_test_schema_cache();
        let generated = build_schema(&cache, &SchemaConfig::default());
        let _mutation = create_mutation_type(&generated);
    }

    #[test]
    fn test_create_scalars() {
        let _bigint = create_bigint_scalar();
        let _json = create_json_scalar();
        let _uuid = create_uuid_scalar();
        let _datetime = create_datetime_scalar();
    }

    #[test]
    fn test_register_filter_input_types() {
        // A schema needs a non-empty Query root before it can be finished.
        let query = Object::new("Query").field(Field::new(
            "test",
            TypeRef::named("String"),
            |_| FieldFuture::new(async { Ok(None::<FieldValue>) }),
        ));

        let builder = Schema::build("Query", None::<&str>, None).register(query);
        let result = register_filter_input_types(builder).finish();
        assert!(result.is_ok());
    }
}
855}