grafbase_sdk/test/request.rs

1use std::{
2    borrow::Cow,
3    ops::{Deref, DerefMut},
4};
5
6use bytes::Bytes;
7use futures_util::{StreamExt as _, stream::BoxStream};
8use http::HeaderValue;
9use http_body_util::BodyExt as _;
10use serde::de::DeserializeOwned;
11
12/// Represents a GraphQL request.
13pub struct GraphqlRequest {
14    pub(super) builder: reqwest::RequestBuilder,
15    pub(super) body: Body,
16}
17
18impl GraphqlRequest {
19    /// Add a header to the request.
20    pub fn header<Name, Value>(mut self, name: Name, value: Value) -> Self
21    where
22        Name: TryInto<http::HeaderName, Error: std::fmt::Debug>,
23        Value: TryInto<http::HeaderValue, Error: std::fmt::Debug>,
24    {
25        self.builder = self.builder.header(name.try_into().unwrap(), value.try_into().unwrap());
26        self
27    }
28
29    /// Add the GraphQL variables to the request.
30    pub fn variables(mut self, variables: impl serde::Serialize) -> Self {
31        self.body.variables = Some(serde_json::to_value(variables).expect("variables to be serializable"));
32        self
33    }
34
35    /// Send the GraphQL request to the gateway
36    pub async fn send(self) -> GraphqlResponse {
37        let response = self
38            .builder
39            .header(http::header::ACCEPT, "application/json")
40            .json(&self.body)
41            .send()
42            .await
43            .expect("Request suceeded");
44        let (parts, body) = http::Response::from(response).into_parts();
45        let bytes = body.collect().await.expect("Could retrieve response body").to_bytes();
46        http::Response::from_parts(parts, bytes).try_into().unwrap()
47    }
48
49    /// Send the GraphQL request to the gateway and return a streaming response through a
50    /// websocket.
51    pub async fn ws_stream(self) -> GraphqlStreamingResponse {
52        use async_tungstenite::tungstenite::client::IntoClientRequest as _;
53        use futures_util::StreamExt;
54
55        let mut req = self.builder.build().expect("Valid request");
56        req.url_mut().set_scheme("ws").expect("Valid URL scheme");
57        req.url_mut().set_path("/ws");
58        let (parts, _) = http::Request::try_from(req).expect("Valid HTTP request").into_parts();
59
60        let mut request = parts.uri.into_client_request().unwrap();
61
62        request.headers_mut().extend(parts.headers);
63        request.headers_mut().insert(
64            http::header::SEC_WEBSOCKET_PROTOCOL,
65            HeaderValue::from_str("graphql-transport-ws").unwrap(),
66        );
67
68        let (connection, response) = async_tungstenite::tokio::connect_async(request)
69            .await
70            .expect("Request suceeded");
71        let (parts, _) = response.into_parts();
72
73        let (client, actor) = graphql_ws_client::Client::build(connection)
74            .await
75            .expect("Client build succeeded");
76
77        tokio::spawn(actor.into_future());
78
79        let stream: BoxStream<'_, _> = Box::pin(
80            client
81                .subscribe(self.body)
82                .await
83                .expect("Subscription succeeded")
84                .map(move |item| item.unwrap()),
85        );
86
87        GraphqlStreamingResponse {
88            status: parts.status,
89            headers: parts.headers,
90            stream,
91        }
92    }
93}
94
95/// Represents the body of a GraphQL request.
96#[derive(serde::Serialize)]
97pub struct Body {
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub(super) query: Option<String>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub(super) variables: Option<serde_json::Value>,
102}
103
104impl<'a> From<&'a str> for Body {
105    fn from(value: &'a str) -> Self {
106        value.to_string().into()
107    }
108}
109
110impl<'a> From<&'a String> for Body {
111    fn from(value: &'a String) -> Self {
112        value.clone().into()
113    }
114}
115
116impl From<String> for Body {
117    fn from(query: String) -> Self {
118        Body {
119            query: Some(query),
120            variables: None,
121        }
122    }
123}
124
125/// Represents a GraphQL response.
126#[derive(serde::Serialize, Debug, serde::Deserialize)]
127pub struct GraphqlResponse {
128    /// The HTTP status code of the response.
129    #[serde(skip)]
130    status: http::StatusCode,
131    /// The HTTP headers of the response.
132    #[serde(skip)]
133    headers: http::HeaderMap,
134    /// The body of the response, which contains the GraphQL data.
135    #[serde(flatten)]
136    body: serde_json::Value,
137}
138
139impl TryFrom<http::Response<Bytes>> for GraphqlResponse {
140    type Error = serde_json::Error;
141
142    fn try_from(response: http::Response<Bytes>) -> Result<Self, Self::Error> {
143        let (parts, body) = response.into_parts();
144
145        Ok(GraphqlResponse {
146            status: parts.status,
147            body: serde_json::from_slice(body.as_ref())
148                .unwrap_or_else(|err| serde_json::Value::String(format!("Could not deserialize JSON data: {err}"))),
149            headers: parts.headers,
150        })
151    }
152}
153
154impl std::fmt::Display for GraphqlResponse {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(f, "{}", serde_json::to_string_pretty(&self.body).unwrap())
157    }
158}
159
160impl Deref for GraphqlResponse {
161    type Target = serde_json::Value;
162
163    fn deref(&self) -> &Self::Target {
164        &self.body
165    }
166}
167
168impl DerefMut for GraphqlResponse {
169    fn deref_mut(&mut self) -> &mut Self::Target {
170        &mut self.body
171    }
172}
173
174impl GraphqlResponse {
175    /// Returns the HTTP status code of the response.
176    pub fn status(&self) -> http::StatusCode {
177        self.status
178    }
179
180    /// Returns the HTTP headers of the response.
181    pub fn headers(&self) -> &http::HeaderMap {
182        &self.headers
183    }
184
185    /// Consumes the response and returns the body as a JSON value.
186    pub fn into_body(self) -> serde_json::Value {
187        self.body
188    }
189
190    /// Deserializes the response body
191    pub fn deserialize<T: DeserializeOwned>(self) -> anyhow::Result<T> {
192        serde_json::from_value(self.body).map_err(Into::into)
193    }
194
195    /// Extracts the `data` field from the response body, if it exists.
196    #[track_caller]
197    pub fn into_data(self) -> serde_json::Value {
198        assert!(self.errors().is_empty(), "{self:#?}");
199
200        match self.body {
201            serde_json::Value::Object(mut value) => value.remove("data"),
202            _ => None,
203        }
204        .unwrap_or_default()
205    }
206
207    /// Returns the `errors` field from the response body, if it exists.
208    pub fn errors(&self) -> Cow<'_, Vec<serde_json::Value>> {
209        self.body["errors"]
210            .as_array()
211            .map(Cow::Borrowed)
212            .unwrap_or_else(|| Cow::Owned(Vec::new()))
213    }
214}
215
216/// Represents a GraphQL subscription response.
217pub struct GraphqlStreamingResponse {
218    /// The HTTP status code of the response.
219    status: http::StatusCode,
220    /// The HTTP headers of the response.
221    headers: http::HeaderMap,
222    /// The stream of messages from the subscription.
223    stream: BoxStream<'static, serde_json::Value>,
224}
225
226impl std::ops::Deref for GraphqlStreamingResponse {
227    type Target = BoxStream<'static, serde_json::Value>;
228    fn deref(&self) -> &Self::Target {
229        &self.stream
230    }
231}
232
233impl std::ops::DerefMut for GraphqlStreamingResponse {
234    fn deref_mut(&mut self) -> &mut Self::Target {
235        &mut self.stream
236    }
237}
238
239impl GraphqlStreamingResponse {
240    /// Returns the HTTP status code of the response.
241    pub fn status(&self) -> http::StatusCode {
242        self.status
243    }
244
245    /// Returns the HTTP headers of the response.
246    pub fn headers(&self) -> &http::HeaderMap {
247        &self.headers
248    }
249
250    /// Consumes the response and returns the underlying stream.
251    pub fn into_stream(self) -> BoxStream<'static, serde_json::Value> {
252        self.stream
253    }
254
255    /// Consumes the response and returns the first `n` messages.
256    pub async fn take(self, n: usize) -> GraphqlCollectedStreamingResponse {
257        let messages = self.stream.take(n).collect().await;
258        GraphqlCollectedStreamingResponse {
259            status: self.status,
260            headers: self.headers,
261            messages,
262        }
263    }
264
265    /// Collect all messages from the subscription stream.
266    pub async fn collect(self) -> GraphqlCollectedStreamingResponse {
267        let messages = self.stream.collect().await;
268        GraphqlCollectedStreamingResponse {
269            status: self.status,
270            headers: self.headers,
271            messages,
272        }
273    }
274}
275
276/// Represents a collected GraphQL subscription response.
277#[derive(Debug)]
278pub struct GraphqlCollectedStreamingResponse {
279    /// The HTTP status code of the response.
280    status: http::StatusCode,
281    /// The HTTP headers of the response.
282    headers: http::HeaderMap,
283    /// The collected messages from the subscription.
284    messages: Vec<serde_json::Value>,
285}
286
287impl GraphqlCollectedStreamingResponse {
288    /// Returns the HTTP status code of the response.
289    pub fn status(&self) -> http::StatusCode {
290        self.status
291    }
292    /// Returns the HTTP headers of the response.
293    pub fn headers(&self) -> &http::HeaderMap {
294        &self.headers
295    }
296    /// Returns the collected messages from the subscription.
297    pub fn messages(&self) -> &Vec<serde_json::Value> {
298        &self.messages
299    }
300    /// Consumes the response and returns the collected messages.
301    pub fn into_messages(self) -> Vec<serde_json::Value> {
302        self.messages
303    }
304}
305
306impl graphql_ws_client::graphql::GraphqlOperation for Body {
307    type Response = serde_json::Value;
308    type Error = serde_json::Error;
309
310    fn decode(&self, data: serde_json::Value) -> Result<Self::Response, Self::Error> {
311        Ok(data)
312    }
313}
314
315impl serde::Serialize for GraphqlCollectedStreamingResponse {
316    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
317    where
318        S: serde::Serializer,
319    {
320        self.messages.serialize(serializer)
321    }
322}