hive_router_plan_executor/execution/plan.rs

1use std::collections::HashMap;
2
3use bytes::{BufMut, Bytes};
4use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
5use hive_router_query_planner::planner::plan_nodes::{
6    ConditionNode, FetchNode, FetchRewrite, FlattenNode, FlattenNodePath, ParallelNode, PlanNode,
7    QueryPlan, SequenceNode,
8};
9use serde::Deserialize;
10use sonic_rs::ValueRef;
11
12use crate::{
13    context::ExecutionContext,
14    execution::{error::PlanExecutionError, rewrites::FetchRewriteExt},
15    executors::{common::HttpExecutionRequest, map::SubgraphExecutorMap},
16    introspection::{
17        resolve::{resolve_introspection, IntrospectionContext},
18        schema::SchemaMetadata,
19    },
20    projection::{
21        plan::FieldProjectionPlan,
22        request::{project_requires, RequestProjectionContext},
23        response::project_by_operation,
24    },
25    response::{
26        graphql_error::GraphQLError, merge::deep_merge, subgraph_response::SubgraphResponse,
27        value::Value,
28    },
29    utils::{
30        consts::{CLOSE_BRACKET, OPEN_BRACKET},
31        traverse::{traverse_and_callback, traverse_and_callback_mut},
32    },
33};
34
/// Everything needed to execute a single query plan end-to-end: the plan
/// itself, projection and variable data, introspection state, and the
/// per-subgraph executors.
pub struct QueryPlanExecutionContext<'exec> {
    /// The planner output to execute.
    pub query_plan: &'exec QueryPlan,
    /// Projection plan used to shape the final response for this operation.
    pub projection_plan: &'exec Vec<FieldProjectionPlan>,
    /// Variable values provided with the operation, if any.
    pub variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
    /// Optional `extensions` map forwarded into the projected response.
    pub extensions: Option<HashMap<String, sonic_rs::Value>>,
    /// Introspection query/metadata; when a query is present it is resolved
    /// locally and used to seed the response before plan execution.
    pub introspection_context: &'exec IntrospectionContext<'exec, 'static>,
    /// Root operation type name (e.g. "Query"); also gates request dedupe.
    pub operation_type_name: &'exec str,
    /// Map of subgraph name -> executor used to perform fetches.
    pub executors: &'exec SubgraphExecutorMap,
}
44
45pub async fn execute_query_plan<'exec>(
46    ctx: QueryPlanExecutionContext<'exec>,
47) -> Result<Vec<u8>, PlanExecutionError> {
48    let init_value = if let Some(introspection_query) = ctx.introspection_context.query {
49        resolve_introspection(introspection_query, ctx.introspection_context)
50    } else {
51        Value::Null
52    };
53
54    let mut exec_ctx = ExecutionContext::new(ctx.query_plan, init_value);
55
56    if ctx.query_plan.node.is_some() {
57        let executor = Executor::new(
58            ctx.variable_values,
59            ctx.executors,
60            ctx.introspection_context.metadata,
61            // Deduplicate subgraph requests only if the operation type is a query
62            ctx.operation_type_name == "Query",
63        );
64        executor
65            .execute(&mut exec_ctx, ctx.query_plan.node.as_ref())
66            .await;
67    }
68
69    let final_response = &exec_ctx.final_response;
70    project_by_operation(
71        final_response,
72        exec_ctx.errors,
73        &ctx.extensions,
74        ctx.operation_type_name,
75        ctx.projection_plan,
76        ctx.variable_values,
77        exec_ctx.response_storage.estimate_final_response_size(),
78    )
79    .map_err(|e| e.into())
80}
81
/// Drives execution of individual plan nodes: issues subgraph fetches and
/// merges their responses into the execution context.
pub struct Executor<'exec> {
    // Operation variables, used to evaluate ConditionNode branches.
    variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
    // Schema metadata (possible types etc.) used for rewrites and traversal.
    schema_metadata: &'exec SchemaMetadata,
    // Per-subgraph executors that perform the actual HTTP requests.
    executors: &'exec SubgraphExecutorMap,
    // When true, subgraph requests may be deduplicated (queries only).
    dedupe_subgraph_requests: bool,
}
88
/// A scoped collection of boxed futures that are polled concurrently via
/// `FuturesUnordered` and then joined.
struct ConcurrencyScope<'exec, T> {
    jobs: FuturesUnordered<BoxFuture<'exec, T>>,
}
92
93impl<'exec, T> ConcurrencyScope<'exec, T> {
94    fn new() -> Self {
95        Self {
96            jobs: FuturesUnordered::new(),
97        }
98    }
99
100    fn spawn(&mut self, future: BoxFuture<'exec, T>) {
101        self.jobs.push(future);
102    }
103
104    async fn join_all(mut self) -> Vec<T> {
105        let mut results = Vec::with_capacity(self.jobs.len());
106        while let Some(result) = self.jobs.next().await {
107            results.push(result);
108        }
109        results
110    }
111}
112
/// Result of a root-level fetch against a single subgraph.
struct FetchJob {
    // Id of the FetchNode that produced this response; used to look up the
    // node's output rewrites when merging.
    fetch_node_id: i64,
    // Raw response bytes as received from the subgraph.
    response: Bytes,
}

/// Result of a flatten (entity) fetch, carrying the data needed to merge the
/// returned entities back into the response tree.
struct FlattenFetchJob {
    // Path in the response at which the fetched entities must be merged.
    flatten_node_path: FlattenNodePath,
    // Raw response bytes as received from the subgraph.
    response: Bytes,
    // Id of the inner FetchNode; used to look up output rewrites.
    fetch_node_id: i64,
    // Hash of each traversed entity, in traversal order (recorded while the
    // representations were being prepared).
    representation_hashes: Vec<u64>,
    // Maps an entity hash to the index of its (deduplicated) representation.
    representation_hash_to_index: HashMap<u64, usize>,
}

/// A unit of work produced by executing one plan node.
enum ExecutionJob {
    Fetch(FetchJob),
    FlattenFetch(FlattenFetchJob),
    // Nothing was executed (e.g. an unmet condition or nothing to fetch).
    None,
}

impl From<ExecutionJob> for Bytes {
    // Extracts the raw response bytes; a `None` job yields empty bytes.
    fn from(value: ExecutionJob) -> Self {
        match value {
            ExecutionJob::Fetch(j) => j.response,
            ExecutionJob::FlattenFetch(j) => j.response,
            ExecutionJob::None => Bytes::new(),
        }
    }
}

/// Output of `prepare_flatten_data`: the serialized representations payload
/// plus the bookkeeping needed to merge the fetched entities back.
struct PreparedFlattenData {
    // Serialized JSON array of entity representations to send to the subgraph.
    representations: Vec<u8>,
    // Hash of each traversed entity, in traversal order.
    representation_hashes: Vec<u64>,
    // Maps an entity hash to its index in the deduplicated representations.
    representation_hash_to_index: HashMap<u64, usize>,
}
147
148impl<'exec> Executor<'exec> {
149    pub fn new(
150        variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
151        executors: &'exec SubgraphExecutorMap,
152        schema_metadata: &'exec SchemaMetadata,
153        dedupe_subgraph_requests: bool,
154    ) -> Self {
155        Executor {
156            variable_values,
157            executors,
158            schema_metadata,
159            dedupe_subgraph_requests,
160        }
161    }
162
163    pub async fn execute(&self, ctx: &mut ExecutionContext<'exec>, plan: Option<&PlanNode>) {
164        match plan {
165            Some(PlanNode::Fetch(node)) => self.execute_fetch_wave(ctx, node).await,
166            Some(PlanNode::Parallel(node)) => self.execute_parallel_wave(ctx, node).await,
167            Some(PlanNode::Sequence(node)) => self.execute_sequence_wave(ctx, node).await,
168            // Plans produced by our Query Planner can only start with: Fetch, Sequence or Parallel.
169            // Any other node type at the root is not supported, do nothing
170            Some(_) => (),
171            // An empty plan is valid, just do nothing
172            None => (),
173        }
174    }
175
176    async fn execute_fetch_wave(&self, ctx: &mut ExecutionContext<'exec>, node: &FetchNode) {
177        match self.execute_fetch_node(node, None).await {
178            Ok(result) => self.process_job_result(ctx, result),
179            Err(err) => ctx.errors.push(GraphQLError {
180                message: err.to_string(),
181                locations: None,
182                path: None,
183                extensions: None,
184            }),
185        }
186    }
187
188    async fn execute_sequence_wave(&self, ctx: &mut ExecutionContext<'exec>, node: &SequenceNode) {
189        for child in &node.nodes {
190            Box::pin(self.execute_plan_node(ctx, child)).await;
191        }
192    }
193
194    async fn execute_parallel_wave(&self, ctx: &mut ExecutionContext<'exec>, node: &ParallelNode) {
195        let mut scope = ConcurrencyScope::new();
196
197        for child in &node.nodes {
198            let job_future = self.prepare_job_future(child, &ctx.final_response);
199            scope.spawn(job_future);
200        }
201
202        let results = scope.join_all().await;
203
204        for result in results {
205            match result {
206                Ok(job) => {
207                    self.process_job_result(ctx, job);
208                }
209                Err(err) => ctx.errors.push(GraphQLError {
210                    message: err.to_string(),
211                    locations: None,
212                    path: None,
213                    extensions: None,
214                }),
215            }
216        }
217    }
218
    /// Executes a single plan node in sequence context: dispatches on the
    /// node type, merges successful results into `ctx.final_response` and
    /// records failures as GraphQL errors.
    async fn execute_plan_node(&self, ctx: &mut ExecutionContext<'exec>, node: &PlanNode) {
        match node {
            PlanNode::Fetch(fetch_node) => match self.execute_fetch_node(fetch_node, None).await {
                Ok(job) => {
                    self.process_job_result(ctx, job);
                }
                Err(err) => ctx.errors.push(GraphQLError {
                    message: err.to_string(),
                    locations: None,
                    path: None,
                    extensions: None,
                }),
            },
            PlanNode::Parallel(parallel_node) => {
                self.execute_parallel_wave(ctx, parallel_node).await;
            }
            PlanNode::Flatten(flatten_node) => {
                // Project the representations the nested fetch requires from
                // the current response; `Ok(None)` means there is nothing to
                // fetch for this flatten node.
                match self.prepare_flatten_data(&ctx.final_response, flatten_node) {
                    Ok(Some(p)) => {
                        match self
                            .execute_flatten_fetch_node(
                                flatten_node,
                                Some(p.representations),
                                Some(p.representation_hashes),
                                Some(p.representation_hash_to_index),
                            )
                            .await
                        {
                            Ok(job) => {
                                self.process_job_result(ctx, job);
                            }
                            Err(err) => {
                                ctx.errors.push(GraphQLError {
                                    message: err.to_string(),
                                    locations: None,
                                    path: None,
                                    extensions: None,
                                });
                            }
                        }
                    }
                    Ok(None) => { /* do nothing */ }
                    Err(e) => {
                        ctx.errors.push(GraphQLError {
                            message: e.to_string(),
                            locations: None,
                            path: None,
                            extensions: None,
                        });
                    }
                }
            }
            PlanNode::Sequence(sequence_node) => {
                self.execute_sequence_wave(ctx, sequence_node).await;
            }
            PlanNode::Condition(condition_node) => {
                // Conditions select a branch based on a boolean variable; a
                // missing branch (or unmet condition) is a no-op.
                if let Some(node) =
                    condition_node_by_variables(condition_node, self.variable_values)
                {
                    // Boxed to keep the recursive async future's size finite.
                    Box::pin(self.execute_plan_node(ctx, node)).await;
                }
            }
            // An unsupported plan node was found, do nothing.
            _ => {}
        }
    }
285
    /// Builds a boxed future for one child of a parallel node without
    /// executing it yet, so all children can be spawned into a
    /// `ConcurrencyScope` and driven concurrently.
    ///
    /// Flatten nodes have their representations projected *now*, against the
    /// response as it currently stands; the returned future only performs
    /// the network fetch.
    fn prepare_job_future<'wave>(
        &'wave self,
        node: &'wave PlanNode,
        final_response: &Value<'exec>,
    ) -> BoxFuture<'wave, Result<ExecutionJob, PlanExecutionError>> {
        match node {
            PlanNode::Fetch(fetch_node) => Box::pin(self.execute_fetch_node(fetch_node, None)),
            PlanNode::Flatten(flatten_node) => {
                match self.prepare_flatten_data(final_response, flatten_node) {
                    Ok(Some(p)) => Box::pin(self.execute_flatten_fetch_node(
                        flatten_node,
                        Some(p.representations),
                        Some(p.representation_hashes),
                        Some(p.representation_hash_to_index),
                    )),
                    // Nothing to fetch: resolve to a no-op job.
                    Ok(None) => Box::pin(async { Ok(ExecutionJob::None) }),
                    // Projection failed: surface the error when the job joins.
                    Err(e) => Box::pin(async move { Err(e) }),
                }
            }
            PlanNode::Condition(node) => {
                match condition_node_by_variables(node, self.variable_values) {
                    Some(node) => Box::pin(self.prepare_job_future(node, final_response)), // This is already clean.
                    None => Box::pin(async { Ok(ExecutionJob::None) }),
                }
            }
            // Our Query Planner does not produce any other plan node types in ParallelNode
            _ => Box::pin(async { Ok(ExecutionJob::None) }),
        }
    }
315
    /// Stores the raw subgraph response in the context's response storage and
    /// deserializes it. Subgraph-reported GraphQL errors are forwarded to the
    /// context; a deserialization failure is recorded as an error and yields
    /// `None`. On success, returns the response `data` together with the
    /// fetch node's output rewrites (if any).
    fn process_subgraph_response(
        &self,
        ctx: &mut ExecutionContext<'exec>,
        response_bytes: Bytes,
        fetch_node_id: i64,
    ) -> Option<(Value<'exec>, Option<&'exec Vec<FetchRewrite>>)> {
        let idx = ctx.response_storage.add_response(response_bytes);
        // SAFETY: The `bytes` are transmuted to the lifetime `'exec` of the `ExecutionContext`.
        // This is safe because the `response_storage` is part of the `ExecutionContext` (`ctx`)
        // and will live as long as `'exec`. The `Bytes` are stored in an `Arc`, so they won't be
        // dropped until all references are gone. The `Value`s deserialized from this byte
        // slice will borrow from it, and they are stored in `ctx.final_response`, which also
        // lives for `'exec`.
        let bytes: &'exec [u8] =
            unsafe { std::mem::transmute(ctx.response_storage.get_bytes(idx)) };

        // SAFETY: The `output_rewrites` are transmuted to the lifetime `'exec`. This is safe
        // because `output_rewrites` is part of `OutputRewritesStorage` which is owned by
        // `ExecutionContext` and lives for `'exec`.
        let output_rewrites: Option<&'exec Vec<FetchRewrite>> =
            unsafe { std::mem::transmute(ctx.output_rewrites.get(fetch_node_id)) };

        let mut deserializer = sonic_rs::Deserializer::from_slice(bytes);
        let response = match SubgraphResponse::deserialize(&mut deserializer) {
            Ok(response) => response,
            Err(e) => {
                // Malformed subgraph payload: record the failure instead of
                // aborting the whole plan execution.
                ctx.errors
                    .push(crate::response::graphql_error::GraphQLError {
                        message: format!("Failed to deserialize subgraph response: {}", e),
                        locations: None,
                        path: None,
                        extensions: None,
                    });
                return None;
            }
        };

        ctx.handle_errors(response.errors);

        Some((response.data, output_rewrites))
    }
357
    /// Merges a finished job's response into `ctx.final_response`.
    ///
    /// `Fetch` results are deep-merged at the root after applying output
    /// rewrites. `FlattenFetch` results have their entities rewritten and
    /// merged into each object at the flatten path, matched to entities via
    /// the representation hashes captured when the request was prepared.
    fn process_job_result(&self, ctx: &mut ExecutionContext<'exec>, job: ExecutionJob) {
        match job {
            ExecutionJob::Fetch(job) => {
                if let Some((mut data, output_rewrites)) =
                    self.process_subgraph_response(ctx, job.response, job.fetch_node_id)
                {
                    if let Some(output_rewrites) = output_rewrites {
                        for output_rewrite in output_rewrites {
                            output_rewrite.rewrite(&self.schema_metadata.possible_types, &mut data);
                        }
                    }

                    deep_merge(&mut ctx.final_response, data);
                }
            }
            ExecutionJob::FlattenFetch(job) => {
                if let Some((mut data, output_rewrites)) =
                    self.process_subgraph_response(ctx, job.response, job.fetch_node_id)
                {
                    if let Some(mut entities) = data.take_entities() {
                        if let Some(output_rewrites) = output_rewrites {
                            for output_rewrite in output_rewrites {
                                for entity in &mut entities {
                                    output_rewrite
                                        .rewrite(&self.schema_metadata.possible_types, entity);
                                }
                            }
                        }

                        // NOTE(review): the indexing below assumes this
                        // traversal visits the same objects, in the same
                        // order, as the traversal in `prepare_flatten_data`
                        // that recorded `representation_hashes`; a mismatch
                        // would index out of bounds and panic — confirm.
                        let mut index = 0;
                        let normalized_path = job.flatten_node_path.as_slice();
                        traverse_and_callback_mut(
                            &mut ctx.final_response,
                            normalized_path,
                            self.schema_metadata,
                            &mut |target| {
                                let hash = job.representation_hashes[index];
                                if let Some(entity_index) =
                                    job.representation_hash_to_index.get(&hash)
                                {
                                    if let Some(entity) = entities.get(*entity_index) {
                                        // SAFETY: `new_val` is a clone of an entity that lives for `'exec`.
                                        // The transmute is to satisfy the compiler, but the lifetime
                                        // is valid.
                                        let new_val: Value<'_> =
                                            unsafe { std::mem::transmute(entity.clone()) };
                                        deep_merge(target, new_val);
                                    }
                                }
                                index += 1;
                            },
                        );
                    }
                }
            }
            ExecutionJob::None => {
                // nothing to do
            }
        }
    }
418
    /// Walks `final_response` at the flatten node's path and builds the
    /// serialized JSON array of representations for the nested fetch's
    /// `requires` selection, deduplicating identical entities by hash.
    /// Returns `Ok(None)` when the inner node is not a fetch, has no
    /// `requires`, or no entity could be projected.
    fn prepare_flatten_data(
        &self,
        final_response: &Value<'exec>,
        flatten_node: &FlattenNode,
    ) -> Result<Option<PreparedFlattenData>, PlanExecutionError> {
        let fetch_node = match flatten_node.node.as_ref() {
            PlanNode::Fetch(fetch_node) => fetch_node,
            _ => return Ok(None),
        };
        let requires_nodes = match fetch_node.requires.as_ref() {
            Some(nodes) => nodes,
            None => return Ok(None),
        };

        let mut index = 0;
        let normalized_path = flatten_node.path.as_slice();
        let mut filtered_representations = Vec::new();
        filtered_representations.put(OPEN_BRACKET);
        let proj_ctx = RequestProjectionContext::new(&self.schema_metadata.possible_types);
        // Hash of every traversed (non-null) entity, in traversal order.
        let mut representation_hashes: Vec<u64> = Vec::new();
        // Maps an entity hash to the `index` value at which it was first
        // projected, so duplicate entities are only serialized once.
        let mut filtered_representations_hashes: HashMap<u64, usize> = HashMap::new();
        // Scratch arena for entity clones that need input rewrites applied.
        let arena = bumpalo::Bump::new();

        traverse_and_callback(
            final_response,
            normalized_path,
            self.schema_metadata,
            &mut |entity| {
                let hash = entity.to_hash(&requires_nodes.items, proj_ctx.possible_types);

                // NOTE(review): hashes are recorded only for non-null
                // entities, while `index` advances for every non-duplicate
                // entity — confirm the merge-side traversal cannot observe a
                // null here, otherwise the two sides would desynchronize.
                if !entity.is_null() {
                    representation_hashes.push(hash);
                }

                // Entity already serialized earlier: skip the duplicate.
                if filtered_representations_hashes.contains_key(&hash) {
                    return Ok::<(), PlanExecutionError>(());
                }

                // Input rewrites must not mutate the live response tree, so
                // they are applied to an arena-allocated clone.
                let entity = if let Some(input_rewrites) = &fetch_node.input_rewrites {
                    let new_entity = arena.alloc(entity.clone());
                    for input_rewrite in input_rewrites {
                        input_rewrite.rewrite(&self.schema_metadata.possible_types, new_entity);
                    }
                    new_entity
                } else {
                    entity
                };

                // `is_empty()` presumably signals "first serialized element"
                // for comma handling — verify against `project_requires`.
                let is_projected = project_requires(
                    &proj_ctx,
                    &requires_nodes.items,
                    entity,
                    &mut filtered_representations,
                    filtered_representations_hashes.is_empty(),
                    None,
                )?;

                if is_projected {
                    filtered_representations_hashes.insert(hash, index);
                }

                index += 1;

                Ok(())
            },
        )?;
        filtered_representations.put(CLOSE_BRACKET);

        // Nothing was projected: there is no entity to fetch for.
        if filtered_representations_hashes.is_empty() {
            return Ok(None);
        }

        Ok(Some(PreparedFlattenData {
            representations: filtered_representations,
            representation_hashes,
            representation_hash_to_index: filtered_representations_hashes,
        }))
    }
497
498    async fn execute_flatten_fetch_node(
499        &self,
500        node: &FlattenNode,
501        representations: Option<Vec<u8>>,
502        representation_hashes: Option<Vec<u64>>,
503        filtered_representations_hashes: Option<HashMap<u64, usize>>,
504    ) -> Result<ExecutionJob, PlanExecutionError> {
505        Ok(match node.node.as_ref() {
506            PlanNode::Fetch(fetch_node) => ExecutionJob::FlattenFetch(FlattenFetchJob {
507                flatten_node_path: node.path.clone(),
508                response: self
509                    .execute_fetch_node(fetch_node, representations)
510                    .await?
511                    .into(),
512                fetch_node_id: fetch_node.id,
513                representation_hashes: representation_hashes.unwrap_or_default(),
514                representation_hash_to_index: filtered_representations_hashes.unwrap_or_default(),
515            }),
516            _ => ExecutionJob::None,
517        })
518    }
519
520    async fn execute_fetch_node(
521        &self,
522        node: &FetchNode,
523        representations: Option<Vec<u8>>,
524    ) -> Result<ExecutionJob, PlanExecutionError> {
525        Ok(ExecutionJob::Fetch(FetchJob {
526            fetch_node_id: node.id,
527            response: self
528                .executors
529                .execute(
530                    &node.service_name,
531                    HttpExecutionRequest {
532                        query: node.operation.document_str.as_str(),
533                        dedupe: self.dedupe_subgraph_requests,
534                        operation_name: node.operation_name.as_deref(),
535                        variables: None,
536                        representations,
537                    },
538                )
539                .await,
540        }))
541    }
542}
543
544fn condition_node_by_variables<'a>(
545    condition_node: &'a ConditionNode,
546    variable_values: &'a Option<HashMap<String, sonic_rs::Value>>,
547) -> Option<&'a PlanNode> {
548    let vars = variable_values.as_ref()?;
549    let value = vars.get(&condition_node.condition)?;
550    let condition_met = matches!(value.as_ref(), ValueRef::Bool(true));
551
552    if condition_met {
553        condition_node.if_clause.as_deref()
554    } else {
555        condition_node.else_clause.as_deref()
556    }
557}