use std::sync::atomic::{AtomicU64, Ordering};
use super::Executor;
use crate::{
db::traits::DatabaseAdapter,
error::Result,
graphql::{FieldSelection, GraphQLArgument, ParsedQuery},
};
/// Process-wide count of multi-root queries dispatched through
/// [`Executor::execute_parallel`]. Relaxed ordering is sufficient: this is a
/// monotonically increasing metric with no synchronization role.
static MULTI_ROOT_QUERIES_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Returns the total number of multi-root queries executed so far in this
/// process. Intended for metrics/observability endpoints.
#[must_use]
pub fn multi_root_queries_total() -> u64 {
    MULTI_ROOT_QUERIES_TOTAL.load(Ordering::Relaxed)
}
/// The resolved payload for a single root field of a multi-root query.
#[derive(Debug)]
pub struct RootFieldResult {
    // Response key of the root field (the alias when one was given,
    // otherwise the field name).
    pub field_name: String,
    // The field's slice of the GraphQL `data` object, as returned by the
    // per-field sub-query.
    pub data: serde_json::Value,
}
/// Aggregate outcome of executing every root field of a parsed query.
#[derive(Debug)]
pub struct PipelineResult {
    // One entry per root field, in selection order.
    pub fields: Vec<RootFieldResult>,
    // `true` when the fields were executed concurrently (set by
    // `execute_parallel`); callers may use this for metrics/tracing.
    pub parallel: bool,
}
impl PipelineResult {
    /// Flattens the per-field results into a single JSON object keyed by
    /// response key, suitable for use as the GraphQL `data` payload.
    #[must_use]
    pub fn merge_into_data_map(&self) -> serde_json::Map<String, serde_json::Value> {
        let mut merged = serde_json::Map::new();
        for field in &self.fields {
            merged.insert(field.field_name.clone(), field.data.clone());
        }
        merged
    }
}
/// Reports whether the query selects two or more root fields and therefore
/// qualifies for the parallel multi-root pipeline.
#[must_use]
pub const fn is_multi_root(parsed: &ParsedQuery) -> bool {
    parsed.selections.len() >= 2
}
/// Collects the response key of every root selection, in query order.
#[must_use]
pub fn extract_root_field_names(parsed: &ParsedQuery) -> Vec<&str> {
    let mut names = Vec::with_capacity(parsed.selections.len());
    for selection in &parsed.selections {
        names.push(selection.response_key());
    }
    names
}
/// Wraps a single root field in a synthetic standalone query document
/// (`{ <field> }`) so it can be executed through the regular query path.
pub(super) fn field_selection_to_query(field: &FieldSelection) -> String {
    let body = serialize_field(field);
    format!("{{ {body} }}")
}
/// Serializes one field selection back into GraphQL source text:
/// `alias: name(args) { nested }`, with the alias, argument list, and
/// selection-set segments emitted only when present.
fn serialize_field(field: &FieldSelection) -> String {
    let mut out = String::new();
    match &field.alias {
        Some(alias) => {
            out.push_str(alias);
            out.push_str(": ");
            out.push_str(&field.name);
        },
        None => out.push_str(&field.name),
    }
    if !field.arguments.is_empty() {
        let rendered: Vec<String> = field.arguments.iter().map(serialize_arg).collect();
        out.push('(');
        out.push_str(&rendered.join(", "));
        out.push(')');
    }
    if !field.nested_fields.is_empty() {
        let children: Vec<String> = field.nested_fields.iter().map(serialize_field).collect();
        out.push_str(" { ");
        out.push_str(&children.join(" "));
        out.push_str(" }");
    }
    out
}
/// Renders one argument as `name: value` GraphQL syntax.
fn serialize_arg(arg: &GraphQLArgument) -> String {
    let value = arg_value_to_graphql(arg);
    format!("{}: {value}", arg.name)
}
/// Converts a stored argument value back into GraphQL literal syntax,
/// dispatching on the argument's recorded `value_type`.
///
/// On any deserialization failure the raw stored JSON is passed through
/// unchanged as a best-effort fallback.
fn arg_value_to_graphql(arg: &GraphQLArgument) -> String {
    match arg.value_type.as_str() {
        // Variables (`$name`) and enum values are stored JSON-encoded as
        // strings; decode to strip the surrounding quotes so they render
        // bare in GraphQL syntax. (These two arms were previously duplicated
        // verbatim — merged into a single or-pattern.)
        "variable" | "enum" => serde_json::from_str::<String>(&arg.value_json)
            .unwrap_or_else(|_| arg.value_json.clone()),
        // Input objects need GraphQL's unquoted-key object syntax rather
        // than raw JSON.
        "object" => serde_json::from_str::<serde_json::Value>(&arg.value_json)
            .map_or_else(|_| arg.value_json.clone(), |v| json_value_to_graphql(&v)),
        // Remaining scalars/lists are stored as raw JSON, which is already
        // valid GraphQL literal syntax.
        _ => arg.value_json.clone(),
    }
}
/// Renders a JSON value as a GraphQL literal.
///
/// Objects use GraphQL's unquoted-key syntax (`{k: v}`); arrays recurse
/// element-wise; all scalars defer to serde_json's `Display`, which emits
/// `null`, `true`/`false`, number text, or a double-quoted string with
/// JSON escaping — GraphQL accepts the same escape sequences (`\"`, `\\`,
/// `\n`, `\uXXXX`, …), so this is valid GraphQL string syntax.
fn json_value_to_graphql(val: &serde_json::Value) -> String {
    match val {
        serde_json::Value::Object(map) => {
            let pairs: Vec<String> =
                map.iter().map(|(k, v)| format!("{k}: {}", json_value_to_graphql(v))).collect();
            format!("{{{}}}", pairs.join(", "))
        },
        serde_json::Value::Array(arr) => {
            let items: Vec<String> = arr.iter().map(json_value_to_graphql).collect();
            format!("[{}]", items.join(", "))
        },
        // BUG FIX: the previous `format!("\"{s}\"")` performed no escaping,
        // so strings containing `"`, `\`, or control characters produced
        // invalid (and injectable) GraphQL. `Display` on a scalar Value
        // emits a correctly escaped literal.
        scalar => scalar.to_string(),
    }
}
impl<A: DatabaseAdapter> Executor<A> {
    /// Executes every root field of `parsed` as its own single-root query,
    /// awaiting all of them concurrently, and collects the per-field results.
    ///
    /// Each root selection is re-serialized into a standalone query string
    /// via `field_selection_to_query`, dispatched through
    /// `execute_regular_query` (sharing the same `variables`), and the
    /// futures are joined with `try_join_all`, which fails fast on the
    /// first error.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any per-field sub-query.
    pub async fn execute_parallel(
        &self,
        parsed: &ParsedQuery,
        variables: Option<&serde_json::Value>,
    ) -> Result<PipelineResult> {
        // Metrics: count every multi-root execution; read back via
        // `multi_root_queries_total()`.
        MULTI_ROOT_QUERIES_TOTAL.fetch_add(1, Ordering::Relaxed);
        // (response_key, synthetic single-root query text) per root field.
        let field_queries: Vec<(String, String)> = parsed
            .selections
            .iter()
            .map(|f| (f.response_key().to_string(), field_selection_to_query(f)))
            .collect();
        let futs: Vec<_> = field_queries
            .iter()
            .map(|(_, query)| self.execute_regular_query(query.as_str(), variables))
            .collect();
        // Await all sub-queries together; `try_join_all` preserves input
        // order in its output, which keeps the zip below aligned.
        let results = futures::future::try_join_all(futs).await?;
        let fields = results
            .into_iter()
            .zip(field_queries.iter())
            .map(|(response, (field_name, _))| {
                // Each sub-response is a full GraphQL envelope; extract just
                // this field's slice of `data` (missing keys index to Null).
                let data = response["data"][field_name.as_str()].clone();
                Ok(RootFieldResult {
                    field_name: field_name.clone(),
                    data,
                })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(PipelineResult {
            fields,
            parallel: true,
        })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the multi-root pipeline: the multi-root predicate, the
    // field-selection serializer, and `execute_parallel` driven by a mock
    // database adapter.
    #![allow(clippy::unwrap_used)]
    use std::sync::Arc;
    use async_trait::async_trait;
    use super::*;
    use crate::{
        db::{
            WhereClause,
            types::{DatabaseType, JsonbValue, OrderByClause, PoolMetrics},
        },
        graphql::parse_query,
        runtime::Executor,
        schema::{CompiledSchema, QueryDefinition, SqlProjectionHint},
    };

    /// Parses `query`, panicking on syntax errors (tests supply valid input).
    fn parsed(query: &str) -> ParsedQuery {
        parse_query(query).expect("valid query")
    }

    /// Builds a schema containing one list-returning query definition per
    /// `(name, sql_source)` pair.
    fn make_schema_with_queries(names: &[(&str, &str)]) -> CompiledSchema {
        let mut schema = CompiledSchema::default();
        for (name, sql_source) in names {
            let mut qd = QueryDefinition::new(*name, "SomeType");
            qd.sql_source = Some((*sql_source).to_string());
            qd.returns_list = true;
            schema.queries.push(qd);
        }
        schema
    }

    /// Adapter stub: projection queries return a single `{"id": 1}` row;
    /// everything else returns empty results / healthy status.
    struct MockAdapter;
    #[async_trait]
    impl crate::db::traits::DatabaseAdapter for MockAdapter {
        async fn execute_where_query(
            &self,
            _view: &str,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            Ok(vec![])
        }
        async fn execute_with_projection(
            &self,
            _view: &str,
            _projection: Option<&SqlProjectionHint>,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            // Non-empty so callers can assert array-shaped data.
            Ok(vec![JsonbValue::new(serde_json::json!({"id": 1}))])
        }
        fn database_type(&self) -> DatabaseType {
            DatabaseType::SQLite
        }
        async fn health_check(&self) -> crate::error::Result<()> {
            Ok(())
        }
        fn pool_metrics(&self) -> PoolMetrics {
            PoolMetrics {
                total_connections: 1,
                idle_connections: 1,
                active_connections: 0,
                waiting_requests: 0,
            }
        }
        async fn execute_raw_query(
            &self,
            _sql: &str,
        ) -> crate::error::Result<Vec<std::collections::HashMap<String, serde_json::Value>>>
        {
            Ok(vec![])
        }
        async fn execute_parameterized_aggregate(
            &self,
            _sql: &str,
            _params: &[serde_json::Value],
        ) -> crate::error::Result<Vec<std::collections::HashMap<String, serde_json::Value>>>
        {
            Ok(vec![])
        }
    }

    /// Convenience: executor over the mock adapter with the given queries.
    fn make_executor(names: &[(&str, &str)]) -> Executor<MockAdapter> {
        let schema = make_schema_with_queries(names);
        Executor::new(schema, Arc::new(MockAdapter))
    }

    // --- is_multi_root / extract_root_field_names -------------------------

    #[test]
    fn test_is_multi_root_single() {
        assert!(!is_multi_root(&parsed("{ users { id } }")));
    }
    #[test]
    fn test_is_multi_root_two_roots() {
        assert!(is_multi_root(&parsed("{ users { id } posts { id } }")));
    }
    #[test]
    fn test_is_multi_root_three_roots() {
        assert!(is_multi_root(&parsed("{ users { id } posts { id } orders { id } }")));
    }
    #[test]
    fn test_extract_root_field_names_single() {
        let p = parsed("{ users { id } }");
        assert_eq!(extract_root_field_names(&p), vec!["users"]);
    }
    #[test]
    fn test_extract_root_field_names_two() {
        let p = parsed("{ users { id } posts { id } }");
        assert_eq!(extract_root_field_names(&p), vec!["users", "posts"]);
    }

    // --- field_selection_to_query serializer ------------------------------

    #[test]
    fn test_serializer_simple_field() {
        let p = parsed("{ users { id name } }");
        let field = &p.selections[0];
        let q = field_selection_to_query(field);
        assert!(q.contains("users"), "missing field name: {q}");
        assert!(q.contains("id"), "missing subfield: {q}");
        assert!(q.contains("name"), "missing subfield: {q}");
    }
    #[test]
    fn test_serializer_scalar_arg() {
        let p = parsed("{ users(limit: 10) { id } }");
        let field = &p.selections[0];
        let q = field_selection_to_query(field);
        assert!(q.contains("limit"), "missing arg: {q}");
        assert!(q.contains("10"), "missing value: {q}");
    }
    #[test]
    fn test_serializer_roundtrip_is_parseable() {
        // The synthetic single-root query must itself be valid GraphQL.
        let original = "{ users { id name } }";
        let p = parsed(original);
        let synthetic = field_selection_to_query(&p.selections[0]);
        parse_query(&synthetic).expect("synthetic query must be valid GraphQL");
    }

    // --- execute_parallel -------------------------------------------------

    #[tokio::test]
    async fn test_execute_parallel_returns_all_fields() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ users { id } posts { id } }");
        let result = exec.execute_parallel(&p, None).await.unwrap();
        assert_eq!(result.fields.len(), 2);
        assert!(result.fields.iter().any(|f| f.field_name == "users"));
        assert!(result.fields.iter().any(|f| f.field_name == "posts"));
        assert!(result.parallel);
    }
    #[tokio::test]
    async fn test_execute_parallel_merges_data_correctly() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ users { id } posts { id } }");
        let result = exec.execute_parallel(&p, None).await.unwrap();
        let merged = result.merge_into_data_map();
        assert!(merged.contains_key("users"), "missing users key");
        assert!(merged.contains_key("posts"), "missing posts key");
    }
    #[tokio::test]
    async fn test_single_root_unaffected() {
        // Single-root queries keep flowing through the regular execute path.
        let exec = make_executor(&[("users", "v_users")]);
        let val = exec.execute("{ users { id } }", None).await.unwrap();
        assert!(val["data"]["users"].is_array());
    }
    #[tokio::test]
    async fn test_multi_root_counter_increments() {
        // The counter is process-global, so only assert it grew (other
        // tests may bump it concurrently).
        let before = multi_root_queries_total();
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ users { id } posts { id } }");
        exec.execute_parallel(&p, None).await.unwrap();
        assert!(multi_root_queries_total() > before);
    }
}