Skip to main content

hive_router_plan_executor/response/
graphql_error.rs

1use core::fmt;
2use graphql_tools::parser::Pos;
3use graphql_tools::validation::utils::ValidationError;
4use hive_router_internal::graphql::{ObservedError, PathSegment};
5use serde::{de, Deserialize, Deserializer, Serialize};
6use sonic_rs::Value;
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Deserialize, Serialize)]
10#[serde(rename_all = "camelCase")]
11pub struct GraphQLError {
12    pub message: String,
13    #[serde(default, skip_serializing_if = "is_none_or_empty")]
14    pub locations: Option<Vec<GraphQLErrorLocation>>,
15    #[serde(default, skip_serializing_if = "Option::is_none")]
16    pub path: Option<GraphQLErrorPath>,
17    #[serde(default, skip_serializing_if = "GraphQLErrorExtensions::is_empty")]
18    pub extensions: GraphQLErrorExtensions,
19}
20
/// Serde `skip_serializing_if` helper: `true` when the optional list is either
/// absent or present but without any elements.
fn is_none_or_empty<T>(opt: &Option<Vec<T>>) -> bool {
    match opt {
        Some(items) => items.is_empty(),
        None => true,
    }
}
24
25impl From<String> for GraphQLError {
26    fn from(message: String) -> Self {
27        GraphQLError {
28            message,
29            locations: None,
30            path: None,
31            extensions: GraphQLErrorExtensions::default(),
32        }
33    }
34}
35
36impl From<&str> for GraphQLError {
37    fn from(message: &str) -> Self {
38        GraphQLError {
39            message: message.to_string(),
40            locations: None,
41            path: None,
42            extensions: GraphQLErrorExtensions::default(),
43        }
44    }
45}
46
47impl From<&ValidationError> for GraphQLError {
48    fn from(val: &ValidationError) -> Self {
49        GraphQLError {
50            message: val.message.to_string(),
51            locations: Some(val.locations.iter().map(|pos| pos.into()).collect()),
52            path: None,
53            extensions: GraphQLErrorExtensions::new_from_code(val.error_code),
54        }
55    }
56}
57
58impl From<&Pos> for GraphQLErrorLocation {
59    fn from(val: &Pos) -> Self {
60        GraphQLErrorLocation {
61            line: val.line,
62            column: val.column,
63        }
64    }
65}
66
67impl GraphQLError {
68    pub fn entity_index_and_path<'a>(&'a self) -> Option<EntityIndexAndPath<'a>> {
69        self.path.as_ref().and_then(|p| p.entity_index_and_path())
70    }
71
72    pub fn normalize_entity_error(
73        self,
74        entity_index_error_map: &HashMap<&usize, Vec<GraphQLErrorPath>>,
75    ) -> Vec<GraphQLError> {
76        if let Some(entity_index_and_path) = &self.entity_index_and_path() {
77            if let Some(entity_error_paths) =
78                entity_index_error_map.get(&entity_index_and_path.entity_index)
79            {
80                return entity_error_paths
81                    .iter()
82                    .map(|error_path| {
83                        let mut new_error_path = error_path.clone();
84                        new_error_path.extend_from_slice(entity_index_and_path.rest_of_path);
85                        GraphQLError {
86                            path: Some(new_error_path),
87                            ..self.clone()
88                        }
89                    })
90                    .collect();
91            }
92        }
93        vec![self]
94    }
95    /// Creates a GraphQLError with the given message and extensions.
96    /// Example:
97    /// ```rust
98    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
99    /// use hive_router_plan_executor::response::graphql_error::GraphQLErrorExtensions;
100    /// use sonic_rs::json;
101    ///
102    /// let extensions = GraphQLErrorExtensions {
103    ///     code: Some("SOME_ERROR_CODE".to_string()),
104    ///     service_name: None,
105    ///     affected_path: None,
106    ///     extensions: std::collections::HashMap::new(),
107    /// };
108    ///
109    /// let error = GraphQLError::from_message_and_extensions("An error occurred", extensions);
110    ///
111    /// assert_eq!(json!(error), json!({
112    ///     "message": "An error occurred",
113    ///     "extensions": {
114    ///         "code": "SOME_ERROR_CODE"
115    ///     }
116    /// }));
117    /// ```
118    pub fn from_message_and_extensions<TMessage: Into<String>>(
119        message: TMessage,
120        extensions: GraphQLErrorExtensions,
121    ) -> Self {
122        GraphQLError {
123            message: message.into(),
124            locations: None,
125            path: None,
126            extensions,
127        }
128    }
129    /// Creates a GraphQLError with the given message and code in extensions.
130    /// Example:
131    /// ```rust
132    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
133    /// use sonic_rs::json;
134    ///
135    /// let error = GraphQLError::from_message_and_code("An error occurred", "SOME_ERROR_CODE");
136    ///
137    /// assert_eq!(json!(error), json!({
138    ///     "message": "An error occurred",
139    ///     "extensions": {
140    ///         "code": "SOME_ERROR_CODE"
141    ///     }
142    /// }));
143    /// ```
144    pub fn from_message_and_code<TMessage: Into<String>, TCode: Into<String>>(
145        message: TMessage,
146        code: TCode,
147    ) -> Self {
148        GraphQLError {
149            message: message.into(),
150            locations: None,
151            path: None,
152            extensions: GraphQLErrorExtensions::new_from_code(code),
153        }
154    }
155
156    /// Adds subgraph name and error code `DOWNSTREAM_SERVICE_ERROR` to the extensions.
157    /// Example:
158    /// ```rust
159    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
160    /// use sonic_rs::json;
161    ///
162    /// let error = GraphQLError::from("An error occurred")
163    ///     .add_subgraph_name("users");
164    ///
165    /// assert_eq!(json!(error), json!({
166    ///     "message": "An error occurred",
167    ///     "extensions": {
168    ///         "serviceName": "users",
169    ///         "code": "DOWNSTREAM_SERVICE_ERROR"
170    ///     }
171    /// }));
172    /// ```
173    pub fn add_subgraph_name<TStr: Into<String>>(mut self, subgraph_name: TStr) -> Self {
174        self.extensions
175            .service_name
176            .get_or_insert(subgraph_name.into());
177        self.extensions
178            .code
179            .get_or_insert("DOWNSTREAM_SERVICE_ERROR".to_string());
180        self
181    }
182
183    /// Adds affected path to the extensions.
184    /// Example:
185    /// ```rust
186    /// use hive_router_plan_executor::response::graphql_error::GraphQLError;
187    /// use sonic_rs::json;
188    ///
189    /// let error = GraphQLError::from("An error occurred")
190    ///     .add_affected_path("user.friends[0].name");
191    ///
192    /// assert_eq!(json!(error), json!({
193    ///     "message": "An error occurred",
194    ///     "extensions": {
195    ///         "affectedPath": "user.friends[0].name"
196    ///     }
197    /// }));
198    /// ```
199    pub fn add_affected_path<TStr: Into<String>>(mut self, affected_path: TStr) -> Self {
200        self.extensions.affected_path = Some(affected_path.into());
201        self
202    }
203}
204
205impl From<&GraphQLError> for ObservedError {
206    fn from(value: &GraphQLError) -> Self {
207        Self {
208            code: value.extensions.code.clone(),
209            message: value.message.clone(),
210            path: value.path.as_ref().map(|p| {
211                ObservedError::format_path(p.segments.iter().map(|segment| match segment {
212                    GraphQLErrorPathSegment::String(s) => PathSegment::Field(s.as_str()),
213                    GraphQLErrorPathSegment::Index(i) => PathSegment::Index(*i),
214                }))
215            }),
216            service_name: value.extensions.service_name.clone(),
217            affected_path: value.extensions.affected_path.clone(),
218        }
219    }
220}
221
222#[derive(Clone, Debug, Deserialize, Serialize)]
223pub struct GraphQLErrorLocation {
224    pub line: usize,
225    pub column: usize,
226}
227
228#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
229#[serde(untagged)]
230pub enum GraphQLErrorPathSegment {
231    String(String),
232    Index(usize),
233}
234
235#[derive(Clone, Debug, Default, Deserialize, Serialize)]
236#[serde(transparent)]
237pub struct GraphQLErrorPath {
238    pub segments: Vec<GraphQLErrorPathSegment>,
239}
240
241pub struct EntityIndexAndPath<'a> {
242    pub entity_index: usize,
243    pub rest_of_path: &'a [GraphQLErrorPathSegment],
244}
245
246impl GraphQLErrorPath {
247    pub fn with_capacity(capacity: usize) -> Self {
248        GraphQLErrorPath {
249            segments: Vec::with_capacity(capacity),
250        }
251    }
252    pub fn concat(&self, segment: GraphQLErrorPathSegment) -> Self {
253        let mut new_path = self.segments.clone();
254        new_path.push(segment);
255        GraphQLErrorPath { segments: new_path }
256    }
257
258    pub fn concat_index(&self, index: usize) -> Self {
259        self.concat(GraphQLErrorPathSegment::Index(index))
260    }
261
262    pub fn concat_str(&self, field: String) -> Self {
263        self.concat(GraphQLErrorPathSegment::String(field))
264    }
265
266    pub fn extend_from_slice(&mut self, other: &[GraphQLErrorPathSegment]) {
267        self.segments.extend_from_slice(other);
268    }
269
270    pub fn entity_index_and_path<'a>(&'a self) -> Option<EntityIndexAndPath<'a>> {
271        match &self.segments.as_slice() {
272            [GraphQLErrorPathSegment::String(maybe_entities), GraphQLErrorPathSegment::Index(entity_index), rest_of_path @ ..]
273                if maybe_entities == "_entities" =>
274            {
275                Some(EntityIndexAndPath {
276                    entity_index: *entity_index,
277                    rest_of_path,
278                })
279            }
280            _ => None,
281        }
282    }
283}
284
285#[derive(Clone, Debug, Serialize, Default)]
286#[serde(rename_all = "camelCase")]
287pub struct GraphQLErrorExtensions {
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub code: Option<String>,
290    #[serde(skip_serializing_if = "Option::is_none")]
291    pub service_name: Option<String>,
292    /// Corresponds to a path of a Flatten(Fetch) node that caused the error.
293    #[serde(default, skip_serializing_if = "Option::is_none")]
294    pub affected_path: Option<String>,
295    #[serde(flatten)]
296    pub extensions: HashMap<String, Value>,
297}
298
299// Workaround for https://github.com/cloudwego/sonic-rs/issues/114
300
301impl<'de> Deserialize<'de> for GraphQLErrorExtensions {
302    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
303    where
304        D: Deserializer<'de>,
305    {
306        struct GraphQLErrorExtensionsVisitor;
307
308        impl<'de> de::Visitor<'de> for GraphQLErrorExtensionsVisitor {
309            type Value = GraphQLErrorExtensions;
310
311            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
312                formatter.write_str("a map for GraphQLErrorExtensions")
313            }
314
315            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
316            where
317                A: de::MapAccess<'de>,
318            {
319                let mut code = None;
320                let mut service_name = None;
321                let mut affected_path = None;
322                let mut extensions = HashMap::new();
323
324                while let Some(key) = map.next_key::<String>()? {
325                    match key.as_str() {
326                        "code" => {
327                            if code.is_some() {
328                                return Err(de::Error::duplicate_field("code"));
329                            }
330                            code = Some(map.next_value()?);
331                        }
332                        "serviceName" => {
333                            if service_name.is_some() {
334                                return Err(de::Error::duplicate_field("serviceName"));
335                            }
336                            service_name = Some(map.next_value()?);
337                        }
338                        "affectedPath" => {
339                            if affected_path.is_some() {
340                                return Err(de::Error::duplicate_field("affectedPath"));
341                            }
342                            affected_path = map.next_value()?;
343                        }
344                        other_key => {
345                            let value: Value = map.next_value()?;
346                            extensions.insert(other_key.to_string(), value);
347                        }
348                    }
349                }
350
351                Ok(GraphQLErrorExtensions {
352                    code,
353                    service_name,
354                    affected_path,
355                    extensions,
356                })
357            }
358        }
359
360        deserializer.deserialize_map(GraphQLErrorExtensionsVisitor)
361    }
362}
363
364impl GraphQLErrorExtensions {
365    pub fn new_from_code<TCode: Into<String>>(code: TCode) -> Self {
366        GraphQLErrorExtensions {
367            code: Some(code.into()),
368            service_name: None,
369            affected_path: None,
370            extensions: HashMap::new(),
371        }
372    }
373
374    pub fn new_from_code_and_service_name<TCode: Into<String>, TServiceName: Into<String>>(
375        code: TCode,
376        service_name: TServiceName,
377    ) -> Self {
378        GraphQLErrorExtensions {
379            code: Some(code.into()),
380            service_name: Some(service_name.into()),
381            affected_path: None,
382            extensions: HashMap::new(),
383        }
384    }
385
386    pub fn get(&self, key: &str) -> Option<&Value> {
387        self.extensions.get(key)
388    }
389
390    pub fn set(&mut self, key: String, value: Value) {
391        self.extensions.insert(key, value);
392    }
393
394    pub fn is_empty(&self) -> bool {
395        self.code.is_none()
396            && self.service_name.is_none()
397            && self.affected_path.is_none()
398            && self.extensions.is_empty()
399    }
400}