fraiseql_core/runtime/executor/
pipeline.rs1use std::sync::atomic::{AtomicU64, Ordering};
17
18use super::Executor;
19use crate::{
20 db::traits::DatabaseAdapter,
21 error::{FraiseQLError, Result},
22 graphql::{FieldSelection, GraphQLArgument, ParsedQuery},
23};
24
/// Counter of multi-root queries routed through the pipeline.
///
/// Incremented once at the start of every `Executor::execute_parallel` call,
/// whether or not the call later succeeds.
static MULTI_ROOT_QUERIES_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Current value of the multi-root query counter.
///
/// Uses a `Relaxed` load: the value is a monotonically increasing metric and
/// is not used to synchronize any other memory.
pub fn multi_root_queries_total() -> u64 {
    MULTI_ROOT_QUERIES_TOTAL.load(Ordering::Relaxed)
}
33
34#[derive(Debug)]
38pub struct RootFieldResult {
39 pub field_name: String,
41 pub data: serde_json::Value,
43}
44
45#[derive(Debug)]
47pub struct PipelineResult {
48 pub fields: Vec<RootFieldResult>,
50 pub parallel: bool,
52}
53
54impl PipelineResult {
55 #[must_use]
59 pub fn merge_into_data_map(&self) -> serde_json::Map<String, serde_json::Value> {
60 self.fields.iter().map(|f| (f.field_name.clone(), f.data.clone())).collect()
61 }
62}
63
64#[must_use]
71pub const fn is_multi_root(parsed: &ParsedQuery) -> bool {
72 parsed.selections.len() > 1
73}
74
75#[must_use]
77pub fn extract_root_field_names(parsed: &ParsedQuery) -> Vec<&str> {
78 parsed.selections.iter().map(|s| s.response_key()).collect()
79}
80
81pub(super) fn field_selection_to_query(field: &FieldSelection) -> String {
89 format!("{{ {} }}", serialize_field(field))
90}
91
92fn serialize_field(field: &FieldSelection) -> String {
93 let mut s = String::new();
94
95 if let Some(alias) = &field.alias {
97 s.push_str(alias);
98 s.push_str(": ");
99 }
100 s.push_str(&field.name);
101
102 if !field.arguments.is_empty() {
104 s.push('(');
105 let args: Vec<String> = field.arguments.iter().map(serialize_arg).collect();
106 s.push_str(&args.join(", "));
107 s.push(')');
108 }
109
110 if !field.nested_fields.is_empty() {
112 s.push_str(" { ");
113 let sub: Vec<String> = field.nested_fields.iter().map(serialize_field).collect();
114 s.push_str(&sub.join(" "));
115 s.push_str(" }");
116 }
117
118 s
119}
120
121fn serialize_arg(arg: &GraphQLArgument) -> String {
122 format!("{}: {}", arg.name, arg_value_to_graphql(arg))
123}
124
125fn arg_value_to_graphql(arg: &GraphQLArgument) -> String {
127 match arg.value_type.as_str() {
128 "variable" => {
129 serde_json::from_str::<String>(&arg.value_json)
132 .unwrap_or_else(|_| arg.value_json.clone())
133 },
134 "object" => {
135 serde_json::from_str::<serde_json::Value>(&arg.value_json)
137 .map_or_else(|_| arg.value_json.clone(), |v| json_value_to_graphql(&v))
138 },
139 "enum" => {
140 serde_json::from_str::<String>(&arg.value_json)
142 .unwrap_or_else(|_| arg.value_json.clone())
143 },
144 _ => arg.value_json.clone(),
146 }
147}
148
149fn json_value_to_graphql(val: &serde_json::Value) -> String {
151 match val {
152 serde_json::Value::Object(map) => {
153 let pairs: Vec<String> =
154 map.iter().map(|(k, v)| format!("{k}: {}", json_value_to_graphql(v))).collect();
155 format!("{{{}}}", pairs.join(", "))
156 },
157 serde_json::Value::Array(arr) => {
158 let items: Vec<String> = arr.iter().map(json_value_to_graphql).collect();
159 format!("[{}]", items.join(", "))
160 },
161 serde_json::Value::String(s) => format!("\"{s}\""),
162 serde_json::Value::Number(n) => n.to_string(),
163 serde_json::Value::Bool(b) => b.to_string(),
164 serde_json::Value::Null => "null".to_string(),
165 }
166}
167
168impl<A: DatabaseAdapter> Executor<A> {
171 pub async fn execute_parallel(
181 &self,
182 parsed: &ParsedQuery,
183 variables: Option<&serde_json::Value>,
184 ) -> Result<PipelineResult> {
185 MULTI_ROOT_QUERIES_TOTAL.fetch_add(1, Ordering::Relaxed);
186
187 let field_queries: Vec<(String, String)> = parsed
190 .selections
191 .iter()
192 .map(|f| (f.response_key().to_string(), field_selection_to_query(f)))
193 .collect();
194
195 let futs: Vec<_> = field_queries
198 .iter()
199 .map(|(_, query)| self.execute_regular_query(query.as_str(), variables))
200 .collect();
201
202 let results = futures::future::try_join_all(futs).await?;
204
205 let fields = results
207 .into_iter()
208 .zip(field_queries.iter())
209 .map(|(json_str, (field_name, _))| {
210 let response: serde_json::Value =
211 serde_json::from_str(&json_str).map_err(|e| FraiseQLError::Internal {
212 message: e.to_string(),
213 source: None,
214 })?;
215 let data = response["data"][field_name.as_str()].clone();
216 Ok(RootFieldResult {
217 field_name: field_name.clone(),
218 data,
219 })
220 })
221 .collect::<Result<Vec<_>>>()?;
222
223 Ok(PipelineResult {
224 fields,
225 parallel: true,
226 })
227 }
228}
229
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Tests should panic loudly on unexpected errors.

    use std::sync::Arc;

    use async_trait::async_trait;

    use super::*;
    use crate::{
        db::{
            WhereClause,
            types::{DatabaseType, JsonbValue, OrderByClause, PoolMetrics},
        },
        graphql::parse_query,
        runtime::Executor,
        schema::{CompiledSchema, QueryDefinition, SqlProjectionHint},
    };

    /// Parse a GraphQL document, panicking if it is invalid.
    fn parsed(query: &str) -> ParsedQuery {
        parse_query(query).expect("valid query")
    }

    /// Schema with one list-returning `SomeType` query per `(name, sql_source)` pair.
    fn make_schema_with_queries(names: &[(&str, &str)]) -> CompiledSchema {
        let mut schema = CompiledSchema::default();
        for &(query_name, view) in names {
            let mut definition = QueryDefinition::new(query_name, "SomeType");
            definition.returns_list = true;
            definition.sql_source = Some(view.to_string());
            schema.queries.push(definition);
        }
        schema
    }

    /// Adapter stub: projection queries return a single `{"id": 1}` row;
    /// every other query returns no rows.
    struct MockAdapter;

    #[async_trait]
    impl crate::db::traits::DatabaseAdapter for MockAdapter {
        async fn execute_where_query(
            &self,
            _view: &str,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            Ok(Vec::new())
        }

        async fn execute_with_projection(
            &self,
            _view: &str,
            _projection: Option<&SqlProjectionHint>,
            _where_clause: Option<&WhereClause>,
            _limit: Option<u32>,
            _offset: Option<u32>,
            _order_by: Option<&[OrderByClause]>,
        ) -> crate::error::Result<Vec<JsonbValue>> {
            let row = JsonbValue::new(serde_json::json!({"id": 1}));
            Ok(vec![row])
        }

        fn database_type(&self) -> DatabaseType {
            DatabaseType::SQLite
        }

        async fn health_check(&self) -> crate::error::Result<()> {
            Ok(())
        }

        fn pool_metrics(&self) -> PoolMetrics {
            PoolMetrics {
                total_connections: 1,
                idle_connections: 1,
                active_connections: 0,
                waiting_requests: 0,
            }
        }

        async fn execute_raw_query(
            &self,
            _sql: &str,
        ) -> crate::error::Result<Vec<std::collections::HashMap<String, serde_json::Value>>>
        {
            Ok(Vec::new())
        }

        async fn execute_parameterized_aggregate(
            &self,
            _sql: &str,
            _params: &[serde_json::Value],
        ) -> crate::error::Result<Vec<std::collections::HashMap<String, serde_json::Value>>>
        {
            Ok(Vec::new())
        }
    }

    /// Executor over `MockAdapter` with the given `(query, view)` pairs.
    fn make_executor(names: &[(&str, &str)]) -> Executor<MockAdapter> {
        Executor::new(make_schema_with_queries(names), Arc::new(MockAdapter))
    }

    #[test]
    fn test_is_multi_root_single() {
        let query = parsed("{ users { id } }");
        assert!(!is_multi_root(&query));
    }

    #[test]
    fn test_is_multi_root_two_roots() {
        let query = parsed("{ users { id } posts { id } }");
        assert!(is_multi_root(&query));
    }

    #[test]
    fn test_is_multi_root_three_roots() {
        let query = parsed("{ users { id } posts { id } orders { id } }");
        assert!(is_multi_root(&query));
    }

    #[test]
    fn test_extract_root_field_names_single() {
        let query = parsed("{ users { id } }");
        let names = extract_root_field_names(&query);
        assert_eq!(names, ["users"]);
    }

    #[test]
    fn test_extract_root_field_names_two() {
        let query = parsed("{ users { id } posts { id } }");
        let names = extract_root_field_names(&query);
        assert_eq!(names, ["users", "posts"]);
    }

    #[test]
    fn test_serializer_simple_field() {
        let query = parsed("{ users { id name } }");
        let q = field_selection_to_query(&query.selections[0]);
        assert!(q.contains("users"), "missing field name: {q}");
        assert!(q.contains("id"), "missing subfield: {q}");
        assert!(q.contains("name"), "missing subfield: {q}");
    }

    #[test]
    fn test_serializer_scalar_arg() {
        let query = parsed("{ users(limit: 10) { id } }");
        let q = field_selection_to_query(&query.selections[0]);
        assert!(q.contains("limit"), "missing arg: {q}");
        assert!(q.contains("10"), "missing value: {q}");
    }

    #[test]
    fn test_serializer_roundtrip_is_parseable() {
        let query = parsed("{ users { id name } }");
        let synthetic = field_selection_to_query(&query.selections[0]);
        parse_query(&synthetic).expect("synthetic query must be valid GraphQL");
    }

    #[tokio::test]
    async fn test_execute_parallel_returns_all_fields() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let query = parsed("{ users { id } posts { id } }");
        let result = exec.execute_parallel(&query, None).await.unwrap();

        assert_eq!(result.fields.len(), 2);
        let names: Vec<&str> = result.fields.iter().map(|f| f.field_name.as_str()).collect();
        assert!(names.contains(&"users"));
        assert!(names.contains(&"posts"));
        assert!(result.parallel);
    }

    #[tokio::test]
    async fn test_execute_parallel_merges_data_correctly() {
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let query = parsed("{ users { id } posts { id } }");
        let merged = exec.execute_parallel(&query, None).await.unwrap().merge_into_data_map();
        assert!(merged.contains_key("users"), "missing users key");
        assert!(merged.contains_key("posts"), "missing posts key");
    }

    #[tokio::test]
    async fn test_single_root_unaffected() {
        let exec = make_executor(&[("users", "v_users")]);
        let body = exec.execute("{ users { id } }", None).await.unwrap();
        let json: serde_json::Value = serde_json::from_str(&body).unwrap();
        assert!(json["data"]["users"].is_array());
    }

    #[tokio::test]
    async fn test_multi_root_counter_increments() {
        let before = multi_root_queries_total();
        let exec = make_executor(&[("users", "v_users"), ("posts", "v_posts")]);
        let query = parsed("{ users { id } posts { id } }");
        exec.execute_parallel(&query, None).await.unwrap();
        let after = multi_root_queries_total();
        assert!(after > before);
    }
}