Skip to main content

fraiseql_core/runtime/executor/
pipeline.rs

1//! Multi-root query pipelining — parallel execution of independent query roots.
2//!
3//! Dispatches multi-root GraphQL queries concurrently using
4//! [`futures::future::try_join_all`], then merges the results into a single
5//! `{ "data": { ... } }` envelope.
6//!
7//! # Example
8//!
9//! ```text
10//! { users { id name } posts { id title } }
11//! ```
12//!
13//! Without pipelining: `t_users + t_posts` latency (sequential).
14//! With pipelining:    `max(t_users, t_posts)` latency (concurrent).
15
16use std::sync::atomic::{AtomicU64, Ordering};
17
18use super::Executor;
19use crate::{
20    db::traits::DatabaseAdapter,
21    error::Result,
22    graphql::{FieldSelection, GraphQLArgument, ParsedQuery},
23};
24
25// ── Prometheus counter ────────────────────────────────────────────────────────
26
static MULTI_ROOT_QUERIES_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Total multi-root GraphQL queries dispatched via the parallel execution path.
///
/// The counter is a process-wide `static` that is only ever incremented (by
/// `Executor::execute_parallel`), so consumers should compare deltas between
/// two reads rather than rely on absolute values.
#[must_use]
pub fn multi_root_queries_total() -> u64 {
    // Relaxed is sufficient: this is an independent monotonic counter and is
    // not used to synchronize access to any other memory.
    MULTI_ROOT_QUERIES_TOTAL.load(Ordering::Relaxed)
}
33
34// ── Result types ──────────────────────────────────────────────────────────────
35
36/// Result for a single root field in a pipelined execution.
37#[derive(Debug)]
38pub struct RootFieldResult {
39    /// Response key for this field (alias if provided, otherwise field name).
40    pub field_name: String,
41    /// Resolved data value.
42    pub data:       serde_json::Value,
43}
44
45/// Aggregated result from a multi-root parallel execution.
46#[derive(Debug)]
47pub struct PipelineResult {
48    /// Results for each root field, in the order they were requested.
49    pub fields:   Vec<RootFieldResult>,
50    /// `true` when results were produced by the parallel path.
51    pub parallel: bool,
52}
53
54impl PipelineResult {
55    /// Merge all field results into a single JSON map.
56    ///
57    /// Returns a `serde_json::Map` suitable for embedding in a `"data"` envelope.
58    #[must_use]
59    pub fn merge_into_data_map(&self) -> serde_json::Map<String, serde_json::Value> {
60        self.fields.iter().map(|f| (f.field_name.clone(), f.data.clone())).collect()
61    }
62}
63
64// ── Detection helpers ─────────────────────────────────────────────────────────
65
66/// Returns `true` when the query has more than one root field selection.
67///
68/// Only applies to anonymous queries and `query { ... }` operations; mutations
69/// and subscriptions are not affected.
70#[must_use]
71pub const fn is_multi_root(parsed: &ParsedQuery) -> bool {
72    parsed.selections.len() > 1
73}
74
75/// Returns the response key (alias or field name) for every root-level selection.
76#[must_use]
77pub fn extract_root_field_names(parsed: &ParsedQuery) -> Vec<&str> {
78    parsed.selections.iter().map(|s| s.response_key()).collect()
79}
80
81// ── Query-string serializer ───────────────────────────────────────────────────
82
83/// Serialize a root `FieldSelection` to a valid GraphQL query string.
84///
85/// Produces `{ fieldName(arg: value) { sub1 sub2 { ... } } }`.
86/// Variables are preserved as `$varName` references; inline values are
87/// converted from their stored JSON representation to GraphQL syntax.
88pub(super) fn field_selection_to_query(field: &FieldSelection) -> String {
89    format!("{{ {} }}", serialize_field(field))
90}
91
92fn serialize_field(field: &FieldSelection) -> String {
93    let mut s = String::new();
94
95    // Alias prefix
96    if let Some(alias) = &field.alias {
97        s.push_str(alias);
98        s.push_str(": ");
99    }
100    s.push_str(&field.name);
101
102    // Arguments
103    if !field.arguments.is_empty() {
104        s.push('(');
105        let args: Vec<String> = field.arguments.iter().map(serialize_arg).collect();
106        s.push_str(&args.join(", "));
107        s.push(')');
108    }
109
110    // Nested sub-selections
111    if !field.nested_fields.is_empty() {
112        s.push_str(" { ");
113        let sub: Vec<String> = field.nested_fields.iter().map(serialize_field).collect();
114        s.push_str(&sub.join(" "));
115        s.push_str(" }");
116    }
117
118    s
119}
120
121fn serialize_arg(arg: &GraphQLArgument) -> String {
122    format!("{}: {}", arg.name, arg_value_to_graphql(arg))
123}
124
125/// Convert a stored `GraphQLArgument` back to a GraphQL-syntax value.
126fn arg_value_to_graphql(arg: &GraphQLArgument) -> String {
127    match arg.value_type.as_str() {
128        "variable" => {
129            // value_json is stored as a JSON string e.g. `"\"$varName\""`.
130            // Parse it to get the raw `$varName`.
131            serde_json::from_str::<String>(&arg.value_json)
132                .unwrap_or_else(|_| arg.value_json.clone())
133        },
134        "object" => {
135            // JSON objects use quoted keys; GraphQL objects don't.
136            serde_json::from_str::<serde_json::Value>(&arg.value_json)
137                .map_or_else(|_| arg.value_json.clone(), |v| json_value_to_graphql(&v))
138        },
139        "enum" => {
140            // Strip surrounding JSON quotes from enum values.
141            serde_json::from_str::<String>(&arg.value_json)
142                .unwrap_or_else(|_| arg.value_json.clone())
143        },
144        // int, float, boolean, null, string, list — value_json is already valid GraphQL.
145        _ => arg.value_json.clone(),
146    }
147}
148
149/// Recursively convert a `serde_json::Value` to GraphQL value syntax.
150fn json_value_to_graphql(val: &serde_json::Value) -> String {
151    match val {
152        serde_json::Value::Object(map) => {
153            let pairs: Vec<String> =
154                map.iter().map(|(k, v)| format!("{k}: {}", json_value_to_graphql(v))).collect();
155            format!("{{{}}}", pairs.join(", "))
156        },
157        serde_json::Value::Array(arr) => {
158            let items: Vec<String> = arr.iter().map(json_value_to_graphql).collect();
159            format!("[{}]", items.join(", "))
160        },
161        serde_json::Value::String(s) => format!("\"{s}\""),
162        serde_json::Value::Number(n) => n.to_string(),
163        serde_json::Value::Bool(b) => b.to_string(),
164        serde_json::Value::Null => "null".to_string(),
165    }
166}
167
168// ── Parallel execution ────────────────────────────────────────────────────────
169
170impl<A: DatabaseAdapter> Executor<A> {
171    /// Execute all root fields of a multi-root query concurrently.
172    ///
173    /// Each root field is dispatched as an independent single-root query.
174    /// Results are awaited with [`futures::future::try_join_all`] and merged
175    /// into a `PipelineResult`.
176    ///
177    /// # Errors
178    ///
179    /// Returns the first error encountered across all concurrent sub-queries.
180    pub async fn execute_parallel(
181        &self,
182        parsed: &ParsedQuery,
183        variables: Option<&serde_json::Value>,
184    ) -> Result<PipelineResult> {
185        MULTI_ROOT_QUERIES_TOTAL.fetch_add(1, Ordering::Relaxed);
186
187        // Pre-compute synthetic single-root query strings (owned — avoids borrow
188        // lifetime entanglement between iterations and the final zip).
189        let field_queries: Vec<(String, String)> = parsed
190            .selections
191            .iter()
192            .map(|f| (f.response_key().to_string(), field_selection_to_query(f)))
193            .collect();
194
195        // Create all futures in a Vec; each borrows `self` and a slice of `field_queries`.
196        // Both borrows are valid for the lifetime of `execute_parallel`.
197        let futs: Vec<_> = field_queries
198            .iter()
199            .map(|(_, query)| self.execute_regular_query(query.as_str(), variables))
200            .collect();
201
202        // Drive all futures concurrently (single-threaded cooperative multitasking).
203        let results = futures::future::try_join_all(futs).await?;
204
205        // Extract the per-field `data` from each `{"data":{"field":[...]}}` Value response.
206        let fields = results
207            .into_iter()
208            .zip(field_queries.iter())
209            .map(|(response, (field_name, _))| {
210                let data = response["data"][field_name.as_str()].clone();
211                Ok(RootFieldResult {
212                    field_name: field_name.clone(),
213                    data,
214                })
215            })
216            .collect::<Result<Vec<_>>>()?;
217
218        Ok(PipelineResult {
219            fields,
220            parallel: true,
221        })
222    }
223}
224
225// ── Tests ─────────────────────────────────────────────────────────────────────
226
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use std::sync::Arc;

    use async_trait::async_trait;

    use super::*;
    use crate::{
        db::{
            WhereClause,
            types::{DatabaseType, JsonbValue, OrderByClause, PoolMetrics},
        },
        graphql::parse_query,
        runtime::Executor,
        schema::{CompiledSchema, QueryDefinition, SqlProjectionHint},
    };

    // ── helpers ───────────────────────────────────────────────────────────────

    fn parsed(query: &str) -> ParsedQuery {
        parse_query(query).expect("valid query")
    }

    /// Build a schema with one list-returning query per `(name, sql_source)` pair.
    fn make_schema_with_queries(names: &[(&str, &str)]) -> CompiledSchema {
        let mut schema = CompiledSchema::default();
        for (name, sql_source) in names {
            let mut qd = QueryDefinition::new(*name, "SomeType");
            qd.sql_source = Some((*sql_source).to_string());
            qd.returns_list = true;
            schema.queries.push(qd);
        }
        schema
    }

    /// Adapter stub: projection queries return a single `{"id": 1}` row,
    /// every other query returns no rows.
    struct MockAdapter;

    // Reason: DatabaseAdapter is defined with #[async_trait]; all implementations must match
    // its transformed method signatures to satisfy the trait contract
    // async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
    #[async_trait]
    impl crate::db::traits::DatabaseAdapter for MockAdapter {
        async fn execute_where_query(
            &self,
            _view: &str,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            Ok(vec![])
        }

        async fn execute_with_projection(
            &self,
            _view: &str,
            _projection: Option<&SqlProjectionHint>,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            Ok(vec![JsonbValue::new(serde_json::json!({"id": 1}))])
        }

        fn database_type(&self) -> DatabaseType {
            DatabaseType::SQLite
        }

        async fn health_check(&self) -> crate::error::Result<()> {
            Ok(())
        }

        fn pool_metrics(&self) -> PoolMetrics {
            PoolMetrics {
                total_connections:  1,
                idle_connections:   1,
                active_connections: 0,
                waiting_requests:   0,
            }
        }

        async fn execute_raw_query(
            &self,
            _sql: &str,
        ) -> crate::error::Result<Vec<std::collections::HashMap<String, serde_json::Value>>>
        {
            Ok(vec![])
        }

        async fn execute_parameterized_aggregate(
            &self,
            _sql: &str,
            _params: &[serde_json::Value],
        ) -> crate::error::Result<Vec<std::collections::HashMap<String, serde_json::Value>>>
        {
            Ok(vec![])
        }
    }

    fn make_executor(names: &[(&str, &str)]) -> Executor<MockAdapter> {
        let schema = make_schema_with_queries(names);
        Executor::new(schema, Arc::new(MockAdapter))
    }

    // ── detection tests ───────────────────────────────────────────────────────

    #[test]
    fn test_is_multi_root_single() {
        assert!(!is_multi_root(&parsed("{ users { id } }")));
    }

    #[test]
    fn test_is_multi_root_two_roots() {
        assert!(is_multi_root(&parsed("{ users { id } posts { id } }")));
    }

    #[test]
    fn test_is_multi_root_three_roots() {
        assert!(is_multi_root(&parsed("{ users { id } posts { id } orders { id } }")));
    }

    #[test]
    fn test_extract_root_field_names_single() {
        let p = parsed("{ users { id } }");
        assert_eq!(extract_root_field_names(&p), vec!["users"]);
    }

    #[test]
    fn test_extract_root_field_names_two() {
        let p = parsed("{ users { id } posts { id } }");
        assert_eq!(extract_root_field_names(&p), vec!["users", "posts"]);
    }

    #[test]
    fn test_extract_root_field_names_uses_alias() {
        let p = parsed("{ u: users { id } posts { id } }");
        assert_eq!(extract_root_field_names(&p), vec!["u", "posts"]);
    }

    // ── serializer tests ──────────────────────────────────────────────────────

    #[test]
    fn test_serializer_simple_field() {
        let p = parsed("{ users { id name } }");
        let field = &p.selections[0];
        let q = field_selection_to_query(field);
        assert!(q.contains("users"), "missing field name: {q}");
        assert!(q.contains("id"), "missing subfield: {q}");
        assert!(q.contains("name"), "missing subfield: {q}");
    }

    #[test]
    fn test_serializer_scalar_arg() {
        let p = parsed("{ users(limit: 10) { id } }");
        let field = &p.selections[0];
        let q = field_selection_to_query(field);
        assert!(q.contains("limit"), "missing arg: {q}");
        assert!(q.contains("10"), "missing value: {q}");
    }

    #[test]
    fn test_serializer_preserves_alias() {
        let p = parsed("{ u: users { id } }");
        let q = field_selection_to_query(&p.selections[0]);
        assert!(q.contains("u: users"), "missing alias: {q}");
        parse_query(&q).expect("aliased synthetic query must be valid GraphQL");
    }

    #[test]
    fn test_serializer_roundtrip_is_parseable() {
        let original = "{ users { id name } }";
        let p = parsed(original);
        let synthetic = field_selection_to_query(&p.selections[0]);
        // The synthetic query should be re-parseable
        parse_query(&synthetic).expect("synthetic query must be valid GraphQL");
    }

    #[test]
    fn test_serializer_nested_roundtrip_is_parseable() {
        let p = parsed("{ users(limit: 10) { id posts { id title } } }");
        let synthetic = field_selection_to_query(&p.selections[0]);
        let reparsed =
            parse_query(&synthetic).expect("nested synthetic query must be valid GraphQL");
        assert_eq!(reparsed.selections.len(), 1);
        assert_eq!(reparsed.selections[0].response_key(), "users");
    }

    // ── parallel execution tests ──────────────────────────────────────────────

    #[tokio::test]
    async fn test_execute_parallel_returns_all_fields() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ users { id } posts { id } }");
        let result = exec.execute_parallel(&p, None).await.unwrap();
        assert_eq!(result.fields.len(), 2);
        assert!(result.fields.iter().any(|f| f.field_name == "users"));
        assert!(result.fields.iter().any(|f| f.field_name == "posts"));
        assert!(result.parallel);
    }

    #[tokio::test]
    async fn test_execute_parallel_preserves_request_order() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ posts { id } users { id } }");
        let result = exec.execute_parallel(&p, None).await.unwrap();
        let names: Vec<&str> = result.fields.iter().map(|f| f.field_name.as_str()).collect();
        assert_eq!(names, vec!["posts", "users"]);
    }

    #[tokio::test]
    async fn test_execute_parallel_merges_data_correctly() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ users { id } posts { id } }");
        let result = exec.execute_parallel(&p, None).await.unwrap();
        let merged = result.merge_into_data_map();
        assert!(merged.contains_key("users"), "missing users key");
        assert!(merged.contains_key("posts"), "missing posts key");
        // Both queries return lists, so the extracted data must not be `null`.
        assert!(merged["users"].is_array(), "users data not extracted: {merged:?}");
        assert!(merged["posts"].is_array(), "posts data not extracted: {merged:?}");
    }

    #[tokio::test]
    async fn test_single_root_unaffected() {
        let exec = make_executor(&[("users", "v_users")]);
        let val = exec.execute("{ users { id } }", None).await.unwrap();
        assert!(val["data"]["users"].is_array());
    }

    #[tokio::test]
    async fn test_multi_root_counter_increments() {
        let before = multi_root_queries_total();
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let p = parsed("{ users { id } posts { id } }");
        exec.execute_parallel(&p, None).await.unwrap();
        assert!(multi_root_queries_total() > before);
    }
}