grafbase_sdk/test/
runner.rs

1use std::{
2    future::IntoFuture,
3    marker::PhantomData,
4    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
5    path::Path,
6    time::Duration,
7};
8
9use super::TestConfig;
10use async_tungstenite::tungstenite::handshake::client::Request;
11use futures_util::{StreamExt, stream::BoxStream};
12use grafbase_sdk_mock::{MockGraphQlServer, MockSubgraph};
13use graphql_composition::{LoadedExtension, Subgraphs};
14use graphql_ws_client::graphql::GraphqlOperation;
15use http::{
16    HeaderValue,
17    header::{IntoHeaderName, SEC_WEBSOCKET_PROTOCOL},
18};
19use serde::de::DeserializeOwned;
20use tempfile::TempDir;
21use tungstenite::client::IntoClientRequest;
22use url::Url;
23
/// A test runner that can start a gateway and execute GraphQL queries against it.
///
/// Dropping the runner kills the gateway process, if one was started.
pub struct TestRunner {
    /// Client used for gateway health checks and for GraphQL queries.
    http_client: reqwest::Client,
    /// The configuration the runner was created with (its mock subgraphs are drained on creation).
    config: TestConfig,
    /// Handle to the spawned gateway process; `None` until the gateway has been started.
    gateway_handle: Option<duct::Handle>,
    /// Local address passed to the gateway via `--listen-address`.
    gateway_listen_address: SocketAddr,
    /// `http://<listen address>/graphql`, the gateway's GraphQL endpoint.
    gateway_endpoint: Url,
    /// Temporary directory holding the generated `grafbase.toml` and federated schema files.
    test_specific_temp_dir: TempDir,
    // Never read: stored so the running mock servers are owned by, and dropped with, the runner.
    _mock_subgraphs: Vec<MockGraphQlServer>,
    /// The composed federated SDL handed to the gateway via `--schema`.
    federated_graph: String,
}
35
/// Minimal view of the extension's `extension.toml` manifest.
///
/// Only the `[extension]` table is read; other keys are ignored by serde.
#[derive(Debug, serde::Deserialize)]
struct ExtensionToml {
    /// The `[extension]` table.
    extension: ExtensionDefinition,
}

/// The subset of the `[extension]` table the test runner needs.
#[derive(Debug, serde::Deserialize)]
struct ExtensionDefinition {
    /// Extension name. It is used as the `[extensions.<name>]` key in the generated gateway
    /// config and when registering the extension for composition.
    name: String,
}
45
46#[allow(clippy::panic)]
47impl TestRunner {
48    /// Creates a new [`TestRunner`] with the given [`TestConfig`].
49    pub async fn new(mut config: TestConfig) -> anyhow::Result<Self> {
50        let test_specific_temp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
51        let gateway_listen_address = listen_address()?;
52        let gateway_endpoint = Url::parse(&format!("http://{}/graphql", gateway_listen_address))?;
53
54        let extension_toml_path = std::env::current_dir()?.join("extension.toml");
55        let extension_toml = std::fs::read_to_string(&extension_toml_path)?;
56        let extension_toml: ExtensionToml = toml::from_str(&extension_toml)?;
57        let extension_name = extension_toml.extension.name;
58
59        let mut mock_subgraphs = Vec::new();
60        let mut subgraphs = Subgraphs::default();
61
62        let extension_path = match config.extension_path {
63            Some(ref path) => path.to_path_buf(),
64            None => std::env::current_dir()?.join("build"),
65        };
66
67        subgraphs.ingest_loaded_extensions(std::iter::once(LoadedExtension::new(
68            format!("file://{}", extension_path.display()),
69            extension_name.clone(),
70        )));
71
72        for subgraph in config.mock_subgraphs.drain(..) {
73            match subgraph {
74                MockSubgraph::Dynamic(subgraph) => {
75                    let mock_graph = subgraph.start().await;
76                    subgraphs.ingest_str(mock_graph.sdl(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
77                    mock_subgraphs.push(mock_graph);
78                }
79                MockSubgraph::ExtensionOnly(subgraph) => {
80                    subgraphs.ingest_str(subgraph.sdl(), subgraph.name(), None)?;
81                }
82            }
83        }
84
85        let federated_graph = graphql_composition::compose(&subgraphs)
86            .warnings_are_fatal()
87            .into_result()
88            .unwrap();
89        let federated_graph = graphql_composition::render_federated_sdl(&federated_graph)?;
90
91        let mut this = Self {
92            http_client: reqwest::Client::new(),
93            config,
94            gateway_handle: None,
95            gateway_listen_address,
96            gateway_endpoint,
97            test_specific_temp_dir,
98            _mock_subgraphs: mock_subgraphs,
99            federated_graph,
100        };
101
102        if this.config.extension_path.is_none() {
103            this.build_extension(&extension_path)?;
104        }
105
106        this.start_servers(&extension_name, &extension_path)
107            .await
108            .map_err(|err| anyhow::anyhow!("Failed to start servers: {err}"))?;
109
110        Ok(this)
111    }
112
113    async fn start_servers(&mut self, extension_name: &str, extension_path: &Path) -> anyhow::Result<()> {
114        let extension_path = extension_path.display();
115        let config_path = self.test_specific_temp_dir.path().join("grafbase.toml");
116        let schema_path = self.test_specific_temp_dir.path().join("federated-schema.graphql");
117        let config = &self.config.gateway_configuration;
118        let enable_stdout = self.config.enable_stdout;
119        let enable_stderr = self.config.enable_stdout;
120        let enable_networking = self.config.enable_networking;
121        let enable_environment_variables = self.config.enable_environment_variables;
122        let max_pool_size = self.config.max_pool_size.unwrap_or(100);
123
124        let config = indoc::formatdoc! {r#"
125            [extensions.{extension_name}]
126            path = "{extension_path}"
127            stdout = {enable_stdout}
128            stderr = {enable_stderr}
129            networking = {enable_networking}
130            environment_variables = {enable_environment_variables}
131            max_pool_size = {max_pool_size}
132
133            {config}
134        "#};
135
136        println!("{config}");
137
138        std::fs::write(&config_path, config.as_bytes())
139            .map_err(|err| anyhow::anyhow!("Failed to write config at {:?}: {err}", config_path))?;
140        std::fs::write(&schema_path, self.federated_graph.as_bytes())
141            .map_err(|err| anyhow::anyhow!("Failed to write schema at {:?}: {err}", schema_path))?;
142
143        let args = &[
144            "--listen-address",
145            &self.gateway_listen_address.to_string(),
146            "--config",
147            &config_path.to_string_lossy(),
148            "--schema",
149            &schema_path.to_string_lossy(),
150            "--log",
151            self.config.log_level.as_ref(),
152        ];
153
154        let mut expr = duct::cmd(&self.config.gateway_path, args);
155
156        if !self.config.enable_stderr {
157            expr = expr.stderr_capture();
158        }
159
160        if !self.config.enable_stdout {
161            expr = expr.stdout_capture();
162        }
163
164        let gateway_handle = expr
165            .unchecked()
166            .start()
167            .map_err(|err| anyhow::anyhow!("Failed to start the gateway: {err}"))?;
168
169        let mut i = 0;
170        while !self.check_gateway_health().await? {
171            // printing every second only
172            if i % 10 == 0 {
173                match gateway_handle.try_wait() {
174                    Ok(Some(output)) => panic!(
175                        "Gateway process exited unexpectedly: {:?}\n{}\n{}",
176                        output.status,
177                        String::from_utf8_lossy(&output.stdout),
178                        String::from_utf8_lossy(&output.stderr)
179                    ),
180                    Ok(None) => (),
181                    Err(err) => panic!("Error waiting for gateway process: {}", err),
182                }
183                println!("Waiting for gateway to be ready...");
184            }
185            i += 1;
186            std::thread::sleep(Duration::from_millis(100));
187        }
188
189        self.gateway_handle = Some(gateway_handle);
190
191        Ok(())
192    }
193
194    async fn check_gateway_health(&self) -> anyhow::Result<bool> {
195        let url = self.gateway_endpoint.join("/health")?;
196
197        let Ok(result) = self.http_client.get(url).send().await else {
198            return Ok(false);
199        };
200
201        let result = result.error_for_status().is_ok();
202
203        Ok(result)
204    }
205
206    fn build_extension(&mut self, extension_path: &Path) -> anyhow::Result<()> {
207        let extension_path = extension_path.to_string_lossy();
208
209        // Only one test can build the extension at a time. The others must
210        // wait.
211        let mut lock_file = fslock::LockFile::open(".build.lock")?;
212        lock_file.lock()?;
213
214        let args = &["extension", "build", "--debug", "--output-dir", &*extension_path];
215        let mut expr = duct::cmd(&self.config.cli_path, args);
216
217        if !self.config.enable_stdout {
218            expr = expr.stdout_capture();
219        }
220
221        if !self.config.enable_stderr {
222            expr = expr.stderr_capture();
223        }
224
225        let output = expr.unchecked().run()?;
226        if !output.status.success() {
227            panic!(
228                "Failed to build extension: {}\n{}\n{}",
229                output.status,
230                String::from_utf8_lossy(&output.stdout),
231                String::from_utf8_lossy(&output.stderr)
232            );
233        }
234
235        lock_file.unlock()?;
236
237        Ok(())
238    }
239
240    /// Creates a new GraphQL query builder with the given query.
241    ///
242    /// # Arguments
243    ///
244    /// * `query` - The GraphQL query string to execute
245    ///
246    /// # Returns
247    ///
248    /// A [`QueryBuilder`] that can be used to customize and execute the query
249    pub fn graphql_query<Response>(&self, query: impl Into<String>) -> QueryBuilder<Response> {
250        let reqwest_builder = self
251            .http_client
252            .post(self.gateway_endpoint.clone())
253            .header(http::header::ACCEPT, "application/json");
254
255        QueryBuilder {
256            query: query.into(),
257            variables: None,
258            phantom: PhantomData,
259            reqwest_builder,
260        }
261    }
262
263    ///
264    /// # Arguments
265    ///
266    /// * `query` - The GraphQL subscription query string to execute
267    ///
268    /// # Returns
269    ///
270    /// A [`SubscriptionBuilder`] that can be used to customize and execute the subscription
271    pub fn graphql_subscription<Response>(
272        &self,
273        query: impl Into<String>,
274    ) -> anyhow::Result<SubscriptionBuilder<Response>> {
275        let mut url = self.gateway_endpoint.clone();
276
277        url.set_path("/ws");
278        url.set_scheme("ws").unwrap();
279
280        let mut request_builder = url.as_ref().into_client_request()?;
281
282        request_builder
283            .headers_mut()
284            .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("graphql-transport-ws"));
285
286        let operation = Operation {
287            query: query.into(),
288            variables: None,
289            phantom: PhantomData,
290        };
291
292        Ok(SubscriptionBuilder {
293            operation,
294            request_builder,
295        })
296    }
297
298    /// Returns the federated schema as a string.
299    pub fn federated_graph(&self) -> &str {
300        &self.federated_graph
301    }
302}
303
304pub(crate) fn free_port() -> anyhow::Result<u16> {
305    const INITIAL_PORT: u16 = 14712;
306
307    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
308    std::fs::create_dir_all(&test_dir)?;
309
310    let lock_file_path = test_dir.join("port-number.lock");
311    let port_number_file_path = test_dir.join("port-number.txt");
312
313    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
314    lock_file.lock()?;
315
316    let port = if port_number_file_path.exists() {
317        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
318    } else {
319        INITIAL_PORT
320    };
321
322    std::fs::write(&port_number_file_path, port.to_string())?;
323    lock_file.unlock()?;
324
325    Ok(port)
326}
327
328pub(crate) fn listen_address() -> anyhow::Result<SocketAddr> {
329    let port = free_port()?;
330    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
331}
332
333impl Drop for TestRunner {
334    fn drop(&mut self) {
335        let Some(handle) = self.gateway_handle.take() else {
336            return;
337        };
338
339        if let Err(err) = handle.kill() {
340            eprintln!("Failed to kill grafbase-gateway: {}", err)
341        }
342    }
343}
344
#[derive(serde::Serialize)]
#[must_use]
/// A builder for constructing GraphQL queries with customizable parameters and headers.
///
/// The builder itself is serialized as the JSON request body: `{"query": ..., "variables": ...}`.
pub struct QueryBuilder<Response> {
    // These two will be serialized into the request body, in this order.
    query: String,
    // Omitted from the body entirely while no variables have been set.
    #[serde(skip_serializing_if = "Option::is_none")]
    variables: Option<serde_json::Value>,

    // These won't be serialized.
    // Records the expected response type. Using `fn() -> Response` keeps the builder
    // `Send`/`Sync` regardless of `Response`.
    #[serde(skip)]
    phantom: PhantomData<fn() -> Response>,
    // The underlying HTTP request (URL and headers); the body is attached in `send`.
    #[serde(skip)]
    reqwest_builder: reqwest::RequestBuilder,
}
360
361impl<Response> QueryBuilder<Response> {
362    /// Adds variables to the GraphQL query.
363    ///
364    /// # Arguments
365    ///
366    /// * `variables` - The variables to include with the query, serializable to JSON
367    pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
368        self.variables = Some(serde_json::to_value(variables).unwrap());
369        self
370    }
371
372    /// Adds a header to the GraphQL request.
373    pub fn with_header(self, name: &str, value: &str) -> Self {
374        let Self {
375            phantom,
376            query,
377            mut reqwest_builder,
378            variables,
379        } = self;
380
381        reqwest_builder = reqwest_builder.header(name, value);
382
383        Self {
384            query,
385            variables,
386            phantom,
387            reqwest_builder,
388        }
389    }
390
391    /// Sends the GraphQL query and returns the response.
392    ///
393    /// # Returns
394    ///
395    /// The deserialized response from the GraphQL server
396    ///
397    /// # Errors
398    ///
399    /// Will return an error if:
400    /// - Request serialization fails
401    /// - Network request fails
402    /// - Response deserialization fails
403    pub async fn send(self) -> anyhow::Result<Response>
404    where
405        Response: for<'de> serde::Deserialize<'de>,
406    {
407        let json = serde_json::to_value(&self)?;
408        Ok(self.reqwest_builder.json(&json).send().await?.json().await?)
409    }
410}
411
#[must_use]
/// A builder for constructing GraphQL subscriptions with customizable variables and headers.
///
/// Created by [`TestRunner::graphql_subscription`]. Call [`SubscriptionBuilder::subscribe`] to
/// open the WebSocket connection and receive responses.
pub struct SubscriptionBuilder<Response> {
    /// The operation (query and optional variables) to subscribe with.
    operation: Operation<Response>,
    /// The WebSocket handshake request for the gateway's `/ws` URL, including any headers
    /// added through `with_header`.
    request_builder: Request,
}
418
/// A GraphQL operation sent over the subscription WebSocket.
///
/// Serialized as `{"query": ..., "variables": ...}`, with `variables` omitted when unset.
#[derive(serde::Serialize)]
struct Operation<Response> {
    query: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    variables: Option<serde_json::Value>,
    // Records the type each payload is decoded into; not serialized.
    #[serde(skip)]
    phantom: PhantomData<fn() -> Response>,
}
427
428impl<Response> GraphqlOperation for Operation<Response>
429where
430    Response: DeserializeOwned,
431{
432    type Response = Response;
433    type Error = serde_json::Error;
434
435    fn decode(&self, data: serde_json::Value) -> Result<Self::Response, Self::Error> {
436        serde_json::from_value(data)
437    }
438}
439
440impl<Response> SubscriptionBuilder<Response>
441where
442    Response: DeserializeOwned + 'static,
443{
444    /// Adds variables to the GraphQL subscription.
445    ///
446    /// # Arguments
447    ///
448    /// * `variables` - The variables to include with the subscription, serializable to JSON
449    pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
450        self.operation.variables = Some(serde_json::to_value(variables).unwrap());
451        self
452    }
453
454    /// Adds a header to the GraphQL request.
455    ///
456    /// # Arguments
457    ///
458    /// * `name` - The header name
459    /// * `value` - The header value
460    pub fn with_header<K>(mut self, name: K, value: HeaderValue) -> Self
461    where
462        K: IntoHeaderName,
463    {
464        self.request_builder.headers_mut().insert(name, value);
465        self
466    }
467
468    /// Subscribes to the GraphQL subscription and returns a stream of responses.
469    ///
470    /// # Returns
471    ///
472    /// A pinned stream that yields the deserialized subscription responses
473    ///
474    /// # Errors
475    ///
476    /// Will return an error if:
477    /// - WebSocket connection fails
478    /// - GraphQL subscription initialization fails
479    pub async fn subscribe(self) -> anyhow::Result<BoxStream<'static, Response>> {
480        let (connection, _) = async_tungstenite::tokio::connect_async(self.request_builder).await?;
481        let (client, actor) = graphql_ws_client::Client::build(connection).await?;
482
483        tokio::spawn(actor.into_future());
484
485        let stream = client
486            .subscribe(self.operation)
487            .await?
488            .map(move |item| -> Response { item.unwrap() });
489
490        Ok(Box::pin(stream))
491    }
492}