// fraiseql_core/runtime/executor/pipeline.rs

//! Multi-root query pipelining — parallel execution of independent query roots.
//!
//! Each root field of a multi-root GraphQL query is re-serialized as its own
//! single-root query. The resulting futures are awaited together with
//! [`futures::future::try_join_all`], so they make progress concurrently within
//! the calling task. The per-field results can then be merged into one map
//! suitable for a single `{ "data": { ... } }` envelope.
//!
//! # Example
//!
//! ```text
//! { users { id name } posts { id title } }
//! ```
//!
//! Without pipelining: `t_users + t_posts` latency (sequential).
//! With pipelining:    `max(t_users, t_posts)` latency (concurrent).

16use std::sync::atomic::{AtomicU64, Ordering};
17
18use super::Executor;
19use crate::{
20    db::traits::DatabaseAdapter,
21    error::{FraiseQLError, Result},
22    graphql::{FieldSelection, GraphQLArgument, ParsedQuery},
23};

// ── Prometheus counter ────────────────────────────────────────────────────────

/// Process-wide tally of multi-root queries routed through
/// [`Executor::execute_parallel`].
///
/// The value is only ever incremented and guards no other memory, so
/// `Relaxed` ordering is sufficient for both the increment and the read.
static MULTI_ROOT_QUERIES_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Reports how many multi-root GraphQL queries have been dispatched through
/// the parallel execution path since the process started.
pub fn multi_root_queries_total() -> u64 {
    AtomicU64::load(&MULTI_ROOT_QUERIES_TOTAL, Ordering::Relaxed)
}

// ── Result types ──────────────────────────────────────────────────────────────

36/// Result for a single root field in a pipelined execution.
37#[derive(Debug)]
38pub struct RootFieldResult {
39    /// Response key for this field (alias if provided, otherwise field name).
40    pub field_name: String,
41    /// Resolved data value.
42    pub data:       serde_json::Value,
43}
44
45/// Aggregated result from a multi-root parallel execution.
46#[derive(Debug)]
47pub struct PipelineResult {
48    /// Results for each root field, in the order they were requested.
49    pub fields:   Vec<RootFieldResult>,
50    /// `true` when results were produced by the parallel path.
51    pub parallel: bool,
52}
53
54impl PipelineResult {
55    /// Merge all field results into a single JSON map.
56    ///
57    /// Returns a `serde_json::Map` suitable for embedding in a `"data"` envelope.
58    #[must_use]
59    pub fn merge_into_data_map(&self) -> serde_json::Map<String, serde_json::Value> {
60        self.fields.iter().map(|f| (f.field_name.clone(), f.data.clone())).collect()
61    }
62}

// ── Detection helpers ─────────────────────────────────────────────────────────

66/// Returns `true` when the query has more than one root field selection.
67///
68/// Only applies to anonymous queries and `query { ... }` operations; mutations
69/// and subscriptions are not affected.
70#[must_use]
71pub const fn is_multi_root(parsed: &ParsedQuery) -> bool {
72    parsed.selections.len() > 1
73}
74
75/// Returns the response key (alias or field name) for every root-level selection.
76#[must_use]
77pub fn extract_root_field_names(parsed: &ParsedQuery) -> Vec<&str> {
78    parsed.selections.iter().map(|s| s.response_key()).collect()
79}

// ── Query-string serializer ───────────────────────────────────────────────────

83/// Serialize a root `FieldSelection` to a valid GraphQL query string.
84///
85/// Produces `{ fieldName(arg: value) { sub1 sub2 { ... } } }`.
86/// Variables are preserved as `$varName` references; inline values are
87/// converted from their stored JSON representation to GraphQL syntax.
88pub(super) fn field_selection_to_query(field: &FieldSelection) -> String {
89    format!("{{ {} }}", serialize_field(field))
90}
91
92fn serialize_field(field: &FieldSelection) -> String {
93    let mut s = String::new();
94
95    // Alias prefix
96    if let Some(alias) = &field.alias {
97        s.push_str(alias);
98        s.push_str(": ");
99    }
100    s.push_str(&field.name);
101
102    // Arguments
103    if !field.arguments.is_empty() {
104        s.push('(');
105        let args: Vec<String> = field.arguments.iter().map(serialize_arg).collect();
106        s.push_str(&args.join(", "));
107        s.push(')');
108    }
109
110    // Nested sub-selections
111    if !field.nested_fields.is_empty() {
112        s.push_str(" { ");
113        let sub: Vec<String> = field.nested_fields.iter().map(serialize_field).collect();
114        s.push_str(&sub.join(" "));
115        s.push_str(" }");
116    }
117
118    s
119}
120
121fn serialize_arg(arg: &GraphQLArgument) -> String {
122    format!("{}: {}", arg.name, arg_value_to_graphql(arg))
123}
124
125/// Convert a stored `GraphQLArgument` back to a GraphQL-syntax value.
126fn arg_value_to_graphql(arg: &GraphQLArgument) -> String {
127    match arg.value_type.as_str() {
128        "variable" => {
129            // value_json is stored as a JSON string e.g. `"\"$varName\""`.
130            // Parse it to get the raw `$varName`.
131            serde_json::from_str::<String>(&arg.value_json)
132                .unwrap_or_else(|_| arg.value_json.clone())
133        },
134        "object" => {
135            // JSON objects use quoted keys; GraphQL objects don't.
136            serde_json::from_str::<serde_json::Value>(&arg.value_json)
137                .map_or_else(|_| arg.value_json.clone(), |v| json_value_to_graphql(&v))
138        },
139        "enum" => {
140            // Strip surrounding JSON quotes from enum values.
141            serde_json::from_str::<String>(&arg.value_json)
142                .unwrap_or_else(|_| arg.value_json.clone())
143        },
144        // int, float, boolean, null, string, list — value_json is already valid GraphQL.
145        _ => arg.value_json.clone(),
146    }
147}
148
149/// Recursively convert a `serde_json::Value` to GraphQL value syntax.
150fn json_value_to_graphql(val: &serde_json::Value) -> String {
151    match val {
152        serde_json::Value::Object(map) => {
153            let pairs: Vec<String> =
154                map.iter().map(|(k, v)| format!("{k}: {}", json_value_to_graphql(v))).collect();
155            format!("{{{}}}", pairs.join(", "))
156        },
157        serde_json::Value::Array(arr) => {
158            let items: Vec<String> = arr.iter().map(json_value_to_graphql).collect();
159            format!("[{}]", items.join(", "))
160        },
161        serde_json::Value::String(s) => format!("\"{s}\""),
162        serde_json::Value::Number(n) => n.to_string(),
163        serde_json::Value::Bool(b) => b.to_string(),
164        serde_json::Value::Null => "null".to_string(),
165    }
166}

// ── Parallel execution ────────────────────────────────────────────────────────

170impl<A: DatabaseAdapter> Executor<A> {
171    /// Execute all root fields of a multi-root query concurrently.
172    ///
173    /// Each root field is dispatched as an independent single-root query.
174    /// Results are awaited with [`futures::future::try_join_all`] and merged
175    /// into a `PipelineResult`.
176    ///
177    /// # Errors
178    ///
179    /// Returns the first error encountered across all concurrent sub-queries.
180    pub async fn execute_parallel(
181        &self,
182        parsed: &ParsedQuery,
183        variables: Option<&serde_json::Value>,
184    ) -> Result<PipelineResult> {
185        MULTI_ROOT_QUERIES_TOTAL.fetch_add(1, Ordering::Relaxed);
186
187        // Pre-compute synthetic single-root query strings (owned — avoids borrow
188        // lifetime entanglement between iterations and the final zip).
189        let field_queries: Vec<(String, String)> = parsed
190            .selections
191            .iter()
192            .map(|f| (f.response_key().to_string(), field_selection_to_query(f)))
193            .collect();
194
195        // Create all futures in a Vec; each borrows `self` and a slice of `field_queries`.
196        // Both borrows are valid for the lifetime of `execute_parallel`.
197        let futs: Vec<_> = field_queries
198            .iter()
199            .map(|(_, query)| self.execute_regular_query(query.as_str(), variables))
200            .collect();
201
202        // Drive all futures concurrently (single-threaded cooperative multitasking).
203        let results = futures::future::try_join_all(futs).await?;
204
205        // Extract the per-field `data` from each `{"data":{"field":[...]}}` response.
206        let fields = results
207            .into_iter()
208            .zip(field_queries.iter())
209            .map(|(json_str, (field_name, _))| {
210                let response: serde_json::Value =
211                    serde_json::from_str(&json_str).map_err(|e| FraiseQLError::Internal {
212                        message: e.to_string(),
213                        source:  None,
214                    })?;
215                let data = response["data"][field_name.as_str()].clone();
216                Ok(RootFieldResult {
217                    field_name: field_name.clone(),
218                    data,
219                })
220            })
221            .collect::<Result<Vec<_>>>()?;
222
223        Ok(PipelineResult {
224            fields,
225            parallel: true,
226        })
227    }
228}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use std::{collections::HashMap, sync::Arc};

    use async_trait::async_trait;

    use super::*;
    use crate::{
        db::{
            WhereClause,
            types::{DatabaseType, JsonbValue, OrderByClause, PoolMetrics},
        },
        graphql::parse_query,
        runtime::Executor,
        schema::{CompiledSchema, QueryDefinition, SqlProjectionHint},
    };

    // ── fixtures ──────────────────────────────────────────────────────────────

    /// Parses `src`, panicking if it is not valid GraphQL.
    fn parse_ok(src: &str) -> ParsedQuery {
        parse_query(src).expect("valid query")
    }

    /// Builds a schema with one list-returning query per `(query_name, sql_source)`
    /// pair, all typed as `SomeType`.
    fn schema_for(queries: &[(&str, &str)]) -> CompiledSchema {
        let mut schema = CompiledSchema::default();
        for &(query_name, view) in queries {
            let mut definition = QueryDefinition::new(query_name, "SomeType");
            definition.sql_source = Some(view.to_string());
            definition.returns_list = true;
            schema.queries.push(definition);
        }
        schema
    }

    /// Executor over `schema_for(queries)`, backed by [`StubAdapter`].
    fn executor_for(queries: &[(&str, &str)]) -> Executor<StubAdapter> {
        Executor::new(schema_for(queries), Arc::new(StubAdapter))
    }

    /// Executor exposing the `users` and `posts` queries used by the parallel tests.
    fn users_and_posts_executor() -> Executor<StubAdapter> {
        executor_for(&[("users", "v_users"), ("posts", "v_posts")])
    }

    /// In-memory adapter: projection queries yield a single `{"id": 1}` row,
    /// every other query yields no rows, and health checks always succeed.
    struct StubAdapter;

    // Reason: DatabaseAdapter is defined with #[async_trait]; all implementations must match
    // its transformed method signatures to satisfy the trait contract
    // async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
    #[async_trait]
    impl crate::db::traits::DatabaseAdapter for StubAdapter {
        async fn execute_where_query(
            &self,
            _view: &str,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            Ok(Vec::new())
        }

        async fn execute_with_projection(
            &self,
            _view: &str,
            _projection: Option<&SqlProjectionHint>,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            let row = JsonbValue::new(serde_json::json!({"id": 1}));
            Ok(vec![row])
        }

        fn database_type(&self) -> DatabaseType {
            DatabaseType::SQLite
        }

        async fn health_check(&self) -> crate::error::Result<()> {
            Ok(())
        }

        fn pool_metrics(&self) -> PoolMetrics {
            PoolMetrics {
                total_connections:  1,
                idle_connections:   1,
                active_connections: 0,
                waiting_requests:   0,
            }
        }

        async fn execute_raw_query(
            &self,
            _sql: &str,
        ) -> crate::error::Result<Vec<HashMap<String, serde_json::Value>>> {
            Ok(Vec::new())
        }

        async fn execute_parameterized_aggregate(
            &self,
            _sql: &str,
            _params: &[serde_json::Value],
        ) -> crate::error::Result<Vec<HashMap<String, serde_json::Value>>> {
            Ok(Vec::new())
        }
    }

    // ── detection tests ───────────────────────────────────────────────────────

    #[test]
    fn test_is_multi_root_single() {
        let query = parse_ok("{ users { id } }");
        assert!(!is_multi_root(&query));
    }

    #[test]
    fn test_is_multi_root_two_roots() {
        let query = parse_ok("{ users { id } posts { id } }");
        assert!(is_multi_root(&query));
    }

    #[test]
    fn test_is_multi_root_three_roots() {
        let query = parse_ok("{ users { id } posts { id } orders { id } }");
        assert!(is_multi_root(&query));
    }

    #[test]
    fn test_extract_root_field_names_single() {
        let query = parse_ok("{ users { id } }");
        assert_eq!(extract_root_field_names(&query), ["users"]);
    }

    #[test]
    fn test_extract_root_field_names_two() {
        let query = parse_ok("{ users { id } posts { id } }");
        assert_eq!(extract_root_field_names(&query), ["users", "posts"]);
    }

    // ── serializer tests ──────────────────────────────────────────────────────

    #[test]
    fn test_serializer_simple_field() {
        let query = parse_ok("{ users { id name } }");
        let q = field_selection_to_query(query.selections.first().unwrap());
        assert!(q.contains("users"), "missing field name: {q}");
        assert!(q.contains("id"), "missing subfield: {q}");
        assert!(q.contains("name"), "missing subfield: {q}");
    }

    #[test]
    fn test_serializer_scalar_arg() {
        let query = parse_ok("{ users(limit: 10) { id } }");
        let q = field_selection_to_query(query.selections.first().unwrap());
        assert!(q.contains("limit"), "missing arg: {q}");
        assert!(q.contains("10"), "missing value: {q}");
    }

    #[test]
    fn test_serializer_roundtrip_is_parseable() {
        let root = parse_ok("{ users { id name } }");
        let synthetic = field_selection_to_query(root.selections.first().unwrap());
        // Re-parsing proves the serializer emits valid GraphQL.
        parse_query(&synthetic).expect("synthetic query must be valid GraphQL");
    }

    // ── parallel execution tests ──────────────────────────────────────────────

    #[tokio::test]
    async fn test_execute_parallel_returns_all_fields() {
        let exec = users_and_posts_executor();
        let query = parse_ok("{ users { id } posts { id } }");
        let result = exec.execute_parallel(&query, None).await.unwrap();

        assert_eq!(result.fields.len(), 2);
        let has_key = |key: &str| result.fields.iter().any(|f| f.field_name == key);
        assert!(has_key("users"));
        assert!(has_key("posts"));
        assert!(result.parallel);
    }

    #[tokio::test]
    async fn test_execute_parallel_merges_data_correctly() {
        let exec = users_and_posts_executor();
        let query = parse_ok("{ users { id } posts { id } }");
        let merged = exec.execute_parallel(&query, None).await.unwrap().merge_into_data_map();
        assert!(merged.contains_key("users"), "missing users key");
        assert!(merged.contains_key("posts"), "missing posts key");
    }

    #[tokio::test]
    async fn test_single_root_unaffected() {
        let exec = executor_for(&[("users", "v_users")]);
        let body = exec.execute("{ users { id } }", None).await.unwrap();
        let json: serde_json::Value = serde_json::from_str(&body).unwrap();
        assert!(json["data"]["users"].is_array());
    }

    #[tokio::test]
    async fn test_multi_root_counter_increments() {
        let baseline = multi_root_queries_total();
        let exec = users_and_posts_executor();
        let query = parse_ok("{ users { id } posts { id } }");
        exec.execute_parallel(&query, None).await.unwrap();
        assert!(multi_root_queries_total() > baseline);
    }
}