hive_router_plan_executor/response/
graphql_error.rs

1use core::fmt;
2use graphql_tools::parser::Pos;
3use graphql_tools::validation::utils::ValidationError;
4use serde::{de, Deserialize, Deserializer, Serialize};
5use sonic_rs::Value;
6use std::collections::HashMap;
7
8#[derive(Clone, Debug, Deserialize, Serialize)]
9#[serde(rename_all = "camelCase")]
10pub struct GraphQLError {
11    pub message: String,
12    #[serde(default, skip_serializing_if = "is_none_or_empty")]
13    pub locations: Option<Vec<GraphQLErrorLocation>>,
14    #[serde(default, skip_serializing_if = "Option::is_none")]
15    pub path: Option<GraphQLErrorPath>,
16    #[serde(default, skip_serializing_if = "GraphQLErrorExtensions::is_empty")]
17    pub extensions: GraphQLErrorExtensions,
18}
19
/// `skip_serializing_if` predicate: `true` when the option is `None` or holds an empty
/// vector, so neither form is written to the output.
fn is_none_or_empty<T>(opt: &Option<Vec<T>>) -> bool {
    match opt {
        Some(items) => items.is_empty(),
        None => true,
    }
}
23
24impl From<String> for GraphQLError {
25    fn from(message: String) -> Self {
26        GraphQLError {
27            message,
28            locations: None,
29            path: None,
30            extensions: GraphQLErrorExtensions::default(),
31        }
32    }
33}
34
35impl From<&str> for GraphQLError {
36    fn from(message: &str) -> Self {
37        GraphQLError {
38            message: message.to_string(),
39            locations: None,
40            path: None,
41            extensions: GraphQLErrorExtensions::default(),
42        }
43    }
44}
45
46impl From<&ValidationError> for GraphQLError {
47    fn from(val: &ValidationError) -> Self {
48        GraphQLError {
49            message: val.message.to_string(),
50            locations: Some(val.locations.iter().map(|pos| pos.into()).collect()),
51            path: None,
52            extensions: GraphQLErrorExtensions::new_from_code("GRAPHQL_VALIDATION_FAILED"),
53        }
54    }
55}
56
57impl From<&Pos> for GraphQLErrorLocation {
58    fn from(val: &Pos) -> Self {
59        GraphQLErrorLocation {
60            line: val.line,
61            column: val.column,
62        }
63    }
64}
65
66impl GraphQLError {
67    pub fn entity_index_and_path<'a>(&'a self) -> Option<EntityIndexAndPath<'a>> {
68        self.path.as_ref().and_then(|p| p.entity_index_and_path())
69    }
70
71    pub fn normalize_entity_error(
72        self,
73        entity_index_error_map: &HashMap<&usize, Vec<GraphQLErrorPath>>,
74    ) -> Vec<GraphQLError> {
75        if let Some(entity_index_and_path) = &self.entity_index_and_path() {
76            if let Some(entity_error_paths) =
77                entity_index_error_map.get(&entity_index_and_path.entity_index)
78            {
79                return entity_error_paths
80                    .iter()
81                    .map(|error_path| {
82                        let mut new_error_path = error_path.clone();
83                        new_error_path.extend_from_slice(entity_index_and_path.rest_of_path);
84                        GraphQLError {
85                            path: Some(new_error_path),
86                            ..self.clone()
87                        }
88                    })
89                    .collect();
90            }
91        }
92        vec![self]
93    }
94    /// Creates a GraphQLError with the given message and extensions.
95    /// Example:
96    /// ```rust
97    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
98    /// use hive_router_plan_executor::response::graphql_error::GraphQLErrorExtensions;
99    /// use sonic_rs::json;
100    ///
101    /// let extensions = GraphQLErrorExtensions {
102    ///     code: Some("SOME_ERROR_CODE".to_string()),
103    ///     service_name: None,
104    ///     affected_path: None,
105    ///     extensions: std::collections::HashMap::new(),
106    /// };
107    ///
108    /// let error = GraphQLError::from_message_and_extensions("An error occurred", extensions);
109    ///
110    /// assert_eq!(json!(error), json!({
111    ///     "message": "An error occurred",
112    ///     "extensions": {
113    ///         "code": "SOME_ERROR_CODE"
114    ///     }
115    /// }));
116    /// ```
117    pub fn from_message_and_extensions<TMessage: Into<String>>(
118        message: TMessage,
119        extensions: GraphQLErrorExtensions,
120    ) -> Self {
121        GraphQLError {
122            message: message.into(),
123            locations: None,
124            path: None,
125            extensions,
126        }
127    }
128    /// Creates a GraphQLError with the given message and code in extensions.
129    /// Example:
130    /// ```rust
131    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
132    /// use sonic_rs::json;
133    ///
134    /// let error = GraphQLError::from_message_and_code("An error occurred", "SOME_ERROR_CODE");
135    ///
136    /// assert_eq!(json!(error), json!({
137    ///     "message": "An error occurred",
138    ///     "extensions": {
139    ///         "code": "SOME_ERROR_CODE"
140    ///     }
141    /// }));
142    /// ```
143    pub fn from_message_and_code<TMessage: Into<String>, TCode: Into<String>>(
144        message: TMessage,
145        code: TCode,
146    ) -> Self {
147        GraphQLError {
148            message: message.into(),
149            locations: None,
150            path: None,
151            extensions: GraphQLErrorExtensions::new_from_code(code),
152        }
153    }
154
155    /// Adds subgraph name and error code `DOWNSTREAM_SERVICE_ERROR` to the extensions.
156    /// Example:
157    /// ```rust
158    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
159    /// use sonic_rs::json;
160    ///
161    /// let error = GraphQLError::from("An error occurred")
162    ///     .add_subgraph_name("users");
163    ///
164    /// assert_eq!(json!(error), json!({
165    ///     "message": "An error occurred",
166    ///     "extensions": {
167    ///         "serviceName": "users",
168    ///         "code": "DOWNSTREAM_SERVICE_ERROR"
169    ///     }
170    /// }));
171    /// ```
172    pub fn add_subgraph_name<TStr: Into<String>>(mut self, subgraph_name: TStr) -> Self {
173        self.extensions
174            .service_name
175            .get_or_insert(subgraph_name.into());
176        self.extensions
177            .code
178            .get_or_insert("DOWNSTREAM_SERVICE_ERROR".to_string());
179        self
180    }
181
182    /// Adds affected path to the extensions.
183    /// Example:
184    /// ```rust
185    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
186    /// use sonic_rs::json;
187    ///
188    /// let error = GraphQLError::from("An error occurred")
189    ///     .add_affected_path("user.friends[0].name");
190    ///
191    /// assert_eq!(json!(error), json!({
192    ///     "message": "An error occurred",
193    ///     "extensions": {
194    ///         "affectedPath": "user.friends[0].name"
195    ///     }
196    /// }));
197    /// ```
198    pub fn add_affected_path<TStr: Into<String>>(mut self, affected_path: TStr) -> Self {
199        self.extensions.affected_path = Some(affected_path.into());
200        self
201    }
202}
203
204#[derive(Clone, Debug, Deserialize, Serialize)]
205pub struct GraphQLErrorLocation {
206    pub line: usize,
207    pub column: usize,
208}
209
210#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
211#[serde(untagged)]
212pub enum GraphQLErrorPathSegment {
213    String(String),
214    Index(usize),
215}
216
217#[derive(Clone, Debug, Default, Deserialize, Serialize)]
218#[serde(transparent)]
219pub struct GraphQLErrorPath {
220    pub segments: Vec<GraphQLErrorPathSegment>,
221}
222
223pub struct EntityIndexAndPath<'a> {
224    pub entity_index: usize,
225    pub rest_of_path: &'a [GraphQLErrorPathSegment],
226}
227
228impl GraphQLErrorPath {
229    pub fn with_capacity(capacity: usize) -> Self {
230        GraphQLErrorPath {
231            segments: Vec::with_capacity(capacity),
232        }
233    }
234    pub fn concat(&self, segment: GraphQLErrorPathSegment) -> Self {
235        let mut new_path = self.segments.clone();
236        new_path.push(segment);
237        GraphQLErrorPath { segments: new_path }
238    }
239
240    pub fn concat_index(&self, index: usize) -> Self {
241        self.concat(GraphQLErrorPathSegment::Index(index))
242    }
243
244    pub fn concat_str(&self, field: String) -> Self {
245        self.concat(GraphQLErrorPathSegment::String(field))
246    }
247
248    pub fn extend_from_slice(&mut self, other: &[GraphQLErrorPathSegment]) {
249        self.segments.extend_from_slice(other);
250    }
251
252    pub fn entity_index_and_path<'a>(&'a self) -> Option<EntityIndexAndPath<'a>> {
253        match &self.segments.as_slice() {
254            [GraphQLErrorPathSegment::String(maybe_entities), GraphQLErrorPathSegment::Index(entity_index), rest_of_path @ ..]
255                if maybe_entities == "_entities" =>
256            {
257                Some(EntityIndexAndPath {
258                    entity_index: *entity_index,
259                    rest_of_path,
260                })
261            }
262            _ => None,
263        }
264    }
265}
266
267#[derive(Clone, Debug, Serialize, Default)]
268#[serde(rename_all = "camelCase")]
269pub struct GraphQLErrorExtensions {
270    #[serde(skip_serializing_if = "Option::is_none")]
271    pub code: Option<String>,
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub service_name: Option<String>,
274    /// Corresponds to a path of a Flatten(Fetch) node that caused the error.
275    #[serde(default, skip_serializing_if = "Option::is_none")]
276    pub affected_path: Option<String>,
277    #[serde(flatten)]
278    pub extensions: HashMap<String, Value>,
279}
280
281// Workaround for https://github.com/cloudwego/sonic-rs/issues/114
282
283impl<'de> Deserialize<'de> for GraphQLErrorExtensions {
284    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
285    where
286        D: Deserializer<'de>,
287    {
288        struct GraphQLErrorExtensionsVisitor;
289
290        impl<'de> de::Visitor<'de> for GraphQLErrorExtensionsVisitor {
291            type Value = GraphQLErrorExtensions;
292
293            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
294                formatter.write_str("a map for GraphQLErrorExtensions")
295            }
296
297            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
298            where
299                A: de::MapAccess<'de>,
300            {
301                let mut code = None;
302                let mut service_name = None;
303                let mut affected_path = None;
304                let mut extensions = HashMap::new();
305
306                while let Some(key) = map.next_key::<String>()? {
307                    match key.as_str() {
308                        "code" => {
309                            if code.is_some() {
310                                return Err(de::Error::duplicate_field("code"));
311                            }
312                            code = Some(map.next_value()?);
313                        }
314                        "serviceName" => {
315                            if service_name.is_some() {
316                                return Err(de::Error::duplicate_field("serviceName"));
317                            }
318                            service_name = Some(map.next_value()?);
319                        }
320                        "affectedPath" => {
321                            if affected_path.is_some() {
322                                return Err(de::Error::duplicate_field("affectedPath"));
323                            }
324                            affected_path = map.next_value()?;
325                        }
326                        other_key => {
327                            let value: Value = map.next_value()?;
328                            extensions.insert(other_key.to_string(), value);
329                        }
330                    }
331                }
332
333                Ok(GraphQLErrorExtensions {
334                    code,
335                    service_name,
336                    affected_path,
337                    extensions,
338                })
339            }
340        }
341
342        deserializer.deserialize_map(GraphQLErrorExtensionsVisitor)
343    }
344}
345
346impl GraphQLErrorExtensions {
347    pub fn new_from_code<TCode: Into<String>>(code: TCode) -> Self {
348        GraphQLErrorExtensions {
349            code: Some(code.into()),
350            service_name: None,
351            affected_path: None,
352            extensions: HashMap::new(),
353        }
354    }
355
356    pub fn new_from_code_and_service_name<TCode: Into<String>, TServiceName: Into<String>>(
357        code: TCode,
358        service_name: TServiceName,
359    ) -> Self {
360        GraphQLErrorExtensions {
361            code: Some(code.into()),
362            service_name: Some(service_name.into()),
363            affected_path: None,
364            extensions: HashMap::new(),
365        }
366    }
367
368    pub fn get(&self, key: &str) -> Option<&Value> {
369        self.extensions.get(key)
370    }
371
372    pub fn set(&mut self, key: String, value: Value) {
373        self.extensions.insert(key, value);
374    }
375
376    pub fn is_empty(&self) -> bool {
377        self.code.is_none()
378            && self.service_name.is_none()
379            && self.affected_path.is_none()
380            && self.extensions.is_empty()
381    }
382}