Skip to main content

fraiseql_core/federation/
http_resolver.rs

1//! HTTP entity resolution for federated subgraphs.
2//!
3//! Resolves entities from remote GraphQL subgraphs via HTTP POST requests
4//! to their `_entities` endpoint. Includes retry logic, timeout handling,
5//! and error recovery.
6
7use std::time::Duration;
8
9use serde_json::{Value, json};
10
11use crate::{
12    error::Result,
13    federation::{
14        selection_parser::FieldSelection, tracing::FederationTraceContext,
15        types::EntityRepresentation,
16    },
17};
18
/// Configuration for HTTP client behavior used when resolving entities from
/// remote subgraphs.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Per-request timeout in milliseconds, applied to the underlying HTTP client.
    pub timeout_ms:     u64,
    /// Total number of HTTP attempts made for one `_entities` resolution.
    ///
    /// Despite the field name, the initial request counts toward this budget.
    /// With the default of `3`, a request is sent at most three times: one try
    /// plus two retries.
    pub max_retries:    u32,
    /// Base delay between attempts in milliseconds. The delay doubles after each
    /// failed attempt (exponential backoff).
    pub retry_delay_ms: u64,
}

impl Default for HttpClientConfig {
    /// Defaults: 5 s timeout, 3 total attempts, 100 ms base backoff.
    fn default() -> Self {
        Self {
            timeout_ms:     5000,
            max_retries:    3,
            retry_delay_ms: 100,
        }
    }
}
39
40/// HTTP entity resolver
41#[derive(Clone)]
42pub struct HttpEntityResolver {
43    client: reqwest::Client,
44    config: HttpClientConfig,
45}
46
/// JSON body POSTed to a subgraph: the query text plus its variables object.
#[derive(serde::Serialize)]
struct GraphQLRequest {
    /// GraphQL document text (the generated `_entities` query).
    query:     String,
    /// Variables object; `build_entities_query` stores the entity
    /// representations under the `"representations"` key.
    variables: Value,
}
52
/// The parts of a GraphQL response body that the resolver reads.
/// Other top-level keys are ignored during deserialization.
#[derive(serde::Deserialize, Debug)]
struct GraphQLResponse {
    /// `data` payload; the resolver looks for an `_entities` array inside it.
    data:   Option<Value>,
    /// `errors` list, present when the subgraph reported errors.
    errors: Option<Vec<GraphQLError>>,
}
58
/// One entry of a GraphQL response's `errors` list; only `message` is retained.
#[derive(serde::Deserialize, Debug)]
struct GraphQLError {
    /// Error message as reported by the subgraph.
    message: String,
}
63
64impl HttpEntityResolver {
65    /// Create a new HTTP entity resolver
66    pub fn new(config: HttpClientConfig) -> Self {
67        let client = reqwest::Client::builder()
68            .timeout(Duration::from_millis(config.timeout_ms))
69            .build()
70            .unwrap_or_default();
71
72        Self { client, config }
73    }
74
75    /// Resolve entities via HTTP _entities query
76    pub async fn resolve_entities(
77        &self,
78        subgraph_url: &str,
79        representations: &[EntityRepresentation],
80        selection: &FieldSelection,
81    ) -> Result<Vec<Option<Value>>> {
82        self.resolve_entities_with_tracing(subgraph_url, representations, selection, None)
83            .await
84    }
85
86    /// Resolve entities via HTTP _entities query with optional distributed tracing.
87    pub async fn resolve_entities_with_tracing(
88        &self,
89        subgraph_url: &str,
90        representations: &[EntityRepresentation],
91        selection: &FieldSelection,
92        _trace_context: Option<FederationTraceContext>,
93    ) -> Result<Vec<Option<Value>>> {
94        if representations.is_empty() {
95            return Ok(Vec::new());
96        }
97
98        // Build GraphQL _entities query
99        let query = self.build_entities_query(representations, selection)?;
100
101        // Execute with retry
102        let response = self.execute_with_retry(subgraph_url, &query).await?;
103
104        // Parse response
105        self.parse_response(&response, representations)
106    }
107
108    fn build_entities_query(
109        &self,
110        representations: &[EntityRepresentation],
111        selection: &FieldSelection,
112    ) -> Result<GraphQLRequest> {
113        // Group representations by typename
114        let mut typename_fields: std::collections::HashMap<String, Vec<String>> =
115            std::collections::HashMap::new();
116
117        for rep in representations {
118            typename_fields.entry(rep.typename.clone()).or_insert_with(Vec::new);
119        }
120
121        // Build inline fragments for each type
122        let mut inline_fragments = Vec::new();
123        for typename in typename_fields.keys() {
124            let fields = selection.fields.join(" ");
125            inline_fragments.push(format!("... on {} {{ {} }}", typename, fields));
126        }
127
128        // Build the complete query
129        let query = format!(
130            "query($representations: [_Any!]!) {{ _entities(representations: $representations) {{ {} }} }}",
131            inline_fragments.join(" ")
132        );
133
134        // Serialize representations as variables
135        let repr_values: Vec<Value> = representations
136            .iter()
137            .map(|rep| {
138                let mut obj = rep.all_fields.clone();
139                obj.insert("__typename".to_string(), Value::String(rep.typename.clone()));
140                Value::Object(obj.into_iter().collect::<serde_json::Map<_, _>>())
141            })
142            .collect();
143
144        Ok(GraphQLRequest {
145            query,
146            variables: json!({ "representations": repr_values }),
147        })
148    }
149
150    async fn execute_with_retry(
151        &self,
152        url: &str,
153        request: &GraphQLRequest,
154    ) -> Result<GraphQLResponse> {
155        let mut attempts = 0;
156        let mut last_error = None;
157
158        while attempts < self.config.max_retries {
159            attempts += 1;
160
161            match self.client.post(url).json(request).send().await {
162                Ok(response) if response.status().is_success() => {
163                    match response.json::<GraphQLResponse>().await {
164                        Ok(gql_response) => return Ok(gql_response),
165                        Err(e) => {
166                            last_error = Some(format!("Failed to parse response: {}", e));
167                        },
168                    }
169                },
170                Ok(response) => {
171                    last_error = Some(format!("HTTP {}", response.status()));
172                },
173                Err(e) => {
174                    last_error = Some(format!("Request failed: {}", e));
175                },
176            }
177
178            // Exponential backoff
179            if attempts < self.config.max_retries {
180                let delay = Duration::from_millis(
181                    self.config.retry_delay_ms * 2_u64.saturating_pow(attempts - 1),
182                );
183                tokio::time::sleep(delay).await;
184            }
185        }
186
187        Err(crate::error::FraiseQLError::Internal {
188            message: format!(
189                "HTTP resolution failed after {} attempts: {}",
190                attempts,
191                last_error.unwrap_or_else(|| "unknown error".to_string())
192            ),
193            source:  None,
194        })
195    }
196
197    fn parse_response(
198        &self,
199        response: &GraphQLResponse,
200        representations: &[EntityRepresentation],
201    ) -> Result<Vec<Option<Value>>> {
202        // Check for GraphQL errors
203        if let Some(errors) = &response.errors {
204            let error_messages: Vec<String> = errors.iter().map(|e| e.message.clone()).collect();
205            return Err(crate::error::FraiseQLError::Internal {
206                message: format!("GraphQL errors: {}", error_messages.join("; ")),
207                source:  None,
208            });
209        }
210
211        // Extract entities from response
212        let entities = response
213            .data
214            .as_ref()
215            .and_then(|d| d.get("_entities"))
216            .and_then(|e| e.as_array())
217            .cloned()
218            .unwrap_or_default();
219
220        if entities.len() != representations.len() {
221            return Err(crate::error::FraiseQLError::Internal {
222                message: format!(
223                    "Entity count mismatch: expected {}, got {}",
224                    representations.len(),
225                    entities.len()
226                ),
227                source:  None,
228            });
229        }
230
231        // Return entities in same order as representations
232        Ok(entities.into_iter().map(Some).collect())
233    }
234}
235
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Build a representation whose key and payload are `{ "id": id }`,
    /// tagged with `typename`.
    fn mock_representation(typename: &str, id: &str) -> EntityRepresentation {
        let mut key_fields = HashMap::new();
        key_fields.insert("id".to_string(), Value::String(id.to_string()));

        let mut all_fields = key_fields.clone();
        all_fields.insert("__typename".to_string(), Value::String(typename.to_string()));

        EntityRepresentation {
            typename: typename.to_string(),
            key_fields,
            all_fields,
        }
    }

    #[test]
    fn test_http_resolver_creation() {
        let config = HttpClientConfig::default();
        let _resolver = HttpEntityResolver::new(config);
        // Should not panic
    }

    #[test]
    fn test_empty_representations() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let rt = tokio::runtime::Runtime::new().unwrap();

        rt.block_on(async {
            let result = resolver
                .resolve_entities("http://example.com/graphql", &[], &FieldSelection::default())
                .await;

            assert!(result.is_ok());
            assert_eq!(result.unwrap().len(), 0);
        });
    }

    #[test]
    fn test_graphql_query_building() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let reps = vec![mock_representation("User", "123")];
        let selection = FieldSelection {
            fields: vec!["id".to_string(), "email".to_string()],
        };

        let request = resolver.build_entities_query(&reps, &selection).unwrap();

        assert!(request.query.contains("_entities"));
        assert!(request.query.contains("_Any!"));
        assert!(request.query.contains("User"));
        assert!(request.query.contains("id"));
        assert!(request.query.contains("email"));
    }

    #[test]
    fn test_multiple_types_in_query() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let reps = vec![
            mock_representation("User", "123"),
            mock_representation("Order", "456"),
        ];
        let selection = FieldSelection {
            fields: vec!["id".to_string()],
        };

        let request = resolver.build_entities_query(&reps, &selection).unwrap();

        assert!(request.query.contains("User"));
        assert!(request.query.contains("Order"));
    }

    #[test]
    fn test_query_variables_carry_representations() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let reps = vec![
            mock_representation("User", "123"),
            mock_representation("User", "456"),
        ];
        let selection = FieldSelection {
            fields: vec!["id".to_string()],
        };

        let request = resolver.build_entities_query(&reps, &selection).unwrap();

        let sent = request.variables["representations"]
            .as_array()
            .expect("representations variable must be an array");
        assert_eq!(sent.len(), 2);
        assert_eq!(sent[0]["__typename"], json!("User"));
        assert_eq!(sent[0]["id"], json!("123"));
        assert_eq!(sent[1]["id"], json!("456"));

        // Repeated typenames share a single inline fragment.
        assert_eq!(request.query.matches("... on User").count(), 1);
    }

    #[test]
    fn test_response_parsing_success() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let representations = vec![mock_representation("User", "123")];

        let response = GraphQLResponse {
            data:   Some(json!({
                "_entities": [
                    { "id": "123", "email": "user@example.com" }
                ]
            })),
            errors: None,
        };

        let result = resolver.parse_response(&response, &representations);
        assert!(result.is_ok());

        let entities = result.unwrap();
        assert_eq!(entities.len(), 1);
        assert!(entities[0].is_some());
    }

    #[test]
    fn test_response_parsing_preserves_order() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let representations = vec![
            mock_representation("User", "1"),
            mock_representation("User", "2"),
        ];

        let response = GraphQLResponse {
            data:   Some(json!({
                "_entities": [
                    { "id": "1" },
                    { "id": "2" }
                ]
            })),
            errors: None,
        };

        let entities = resolver.parse_response(&response, &representations).unwrap();
        assert_eq!(entities[0], Some(json!({ "id": "1" })));
        assert_eq!(entities[1], Some(json!({ "id": "2" })));
    }

    #[test]
    fn test_response_parsing_with_errors() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let representations = vec![mock_representation("User", "123")];

        let response = GraphQLResponse {
            data:   None,
            errors: Some(vec![GraphQLError {
                message: "Entity not found".to_string(),
            }]),
        };

        let result = resolver.parse_response(&response, &representations);
        assert!(result.is_err());
    }

    #[test]
    fn test_response_parsing_entity_count_mismatch() {
        let resolver = HttpEntityResolver::new(HttpClientConfig::default());
        let representations = vec![
            mock_representation("User", "123"),
            mock_representation("User", "456"),
        ];

        let response = GraphQLResponse {
            data:   Some(json!({
                "_entities": [
                    { "id": "123" }
                ]
            })),
            errors: None,
        };

        let result = resolver.parse_response(&response, &representations);
        assert!(result.is_err());
    }

    #[test]
    fn test_config_defaults() {
        let config = HttpClientConfig::default();
        assert_eq!(config.timeout_ms, 5000);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_delay_ms, 100);
    }

    #[test]
    fn test_config_custom() {
        let config = HttpClientConfig {
            timeout_ms:     10000,
            max_retries:    5,
            retry_delay_ms: 200,
        };
        assert_eq!(config.timeout_ms, 10000);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.retry_delay_ms, 200);
    }
}