grafbase_sdk/test/
gateway.rs

1use std::{
2    fmt::Write,
3    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4    path::{Path, PathBuf},
5    time::Duration,
6};
7
8use crate::test::{
9    GraphqlRequest, LogLevel,
10    config::{CLI_BINARY_NAME, GATEWAY_BINARY_NAME, TestConfig},
11    request::Body,
12};
13
14use anyhow::Context;
15use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
16use graphql_composition::{LoadedExtension, Subgraphs};
17use itertools::Itertools;
18use regex::Regex;
19use tempfile::TempDir;
20use url::Url;
21
22/// A test runner that can start a gateway and execute GraphQL queries against it.
23pub struct TestGateway {
24    http_client: reqwest::Client,
25    config: TestConfig,
26    gateway_handle: Option<duct::Handle>,
27    gateway_listen_address: SocketAddr,
28    gateway_endpoint: Url,
29    test_specific_temp_dir: TempDir,
30    _mock_subgraphs: Vec<MockGraphQlServer>,
31    federated_sdl: String,
32}
33
34#[derive(Debug, serde::Deserialize)]
35struct ExtensionToml {
36    extension: ExtensionDefinition,
37}
38
39#[derive(Debug, serde::Deserialize)]
40struct ExtensionDefinition {
41    name: String,
42}
43
44impl TestGateway {
45    /// Creates a new test configuration builder.
46    pub fn builder() -> TestGatewayBuilder {
47        TestGatewayBuilder::new()
48    }
49
50    /// Full url of the GraphQL endpoint on the gateway.
51    pub fn url(&self) -> &Url {
52        &self.gateway_endpoint
53    }
54
55    /// Creates a new GraphQL query builder with the given query.
56    ///
57    /// # Arguments
58    ///
59    /// * `query` - The GraphQL query string to execute
60    ///
61    /// # Returns
62    ///
63    /// A [`QueryBuilder`] that can be used to customize and execute the query
64    pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
65        let builder = self.http_client.post(self.gateway_endpoint.clone());
66        GraphqlRequest {
67            builder,
68            body: query.into(),
69        }
70    }
71
72    /// Returns the federated schema as a string.
73    pub fn federated_sdl(&self) -> &str {
74        &self.federated_sdl
75    }
76}
77
78#[derive(Debug, Default, Clone)]
79/// Builder pattern to create a [`TestGateway`].
80pub struct TestGatewayBuilder {
81    gateway_path: Option<PathBuf>,
82    cli_path: Option<PathBuf>,
83    extension_path: Option<PathBuf>,
84    toml_config: Option<String>,
85    subgraphs: Vec<Subgraph>,
86    enable_stdout: Option<bool>,
87    enable_stderr: Option<bool>,
88    enable_networking: Option<bool>,
89    enable_environment_variables: Option<bool>,
90    stream_stdout_stderr: Option<bool>,
91    max_pool_size: Option<usize>,
92    log_level: Option<LogLevel>,
93}
94
95impl TestGatewayBuilder {
96    /// Creates a new [`TestConfigBuilder`] with default values.
97    pub(crate) fn new() -> Self {
98        Self::default()
99    }
100
101    /// Adds a subgraph to the test configuration.
102    pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
103        self.subgraphs.push(subgraph.into());
104        self
105    }
106
107    /// Specifies a custom path to the gateway binary. If not defined, the binary will be searched in the PATH.
108    pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
109        self.gateway_path = Some(gateway_path.into());
110        self
111    }
112
113    /// Specifies a custom path to the CLI binary. If not defined, the binary will be searched in the PATH.
114    pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
115        self.cli_path = Some(cli_path.into());
116        self
117    }
118
119    /// Specifies a path to a pre-built extension. If not defined, the extension will be built.
120    pub fn with_extension_path(mut self, extension_path: impl Into<PathBuf>) -> Self {
121        self.extension_path = Some(extension_path.into().canonicalize().unwrap());
122        self
123    }
124
125    /// Enables stdout output from the gateway and CLI. Useful for debugging errors in the gateway
126    /// and in the extension.
127    pub fn enable_stdout(mut self) -> Self {
128        self.enable_stdout = Some(true);
129        self
130    }
131
132    /// Enables stderr output from the gateway and CLI. Useful for debugging errors in the gateway
133    /// and in the extension.
134    pub fn enable_stderr(mut self) -> Self {
135        self.enable_stderr = Some(true);
136        self
137    }
138
139    /// Enables networking for the extension.
140    pub fn enable_networking(mut self) -> Self {
141        self.enable_networking = Some(true);
142        self
143    }
144
145    /// Enables environment variables for the extension.
146    pub fn enable_environment_variables(mut self) -> Self {
147        self.enable_environment_variables = Some(true);
148        self
149    }
150
151    /// Sets the maximum pool size for the extension.
152    pub fn max_pool_size(mut self, size: usize) -> Self {
153        self.max_pool_size = Some(size);
154        self
155    }
156
157    /// Sets the log level for the gateway process output.
158    pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
159        self.log_level = Some(level.into());
160        self
161    }
162
163    /// Sets the TOML configuration for the gateway. The extension and subgraphs will be
164    /// automatically added to the configuration.
165    pub fn toml_config(mut self, cfg: impl ToString) -> Self {
166        self.toml_config = Some(cfg.to_string());
167        self
168    }
169
170    /// Stream stdout and stderr from the gateway & cli commands.
171    /// Useful if you need to debug subscriptions for example. Not recommended in a CI for
172    /// reporting clarity.
173    pub fn stream_stdout_stderr(mut self) -> Self {
174        self.stream_stdout_stderr = Some(true);
175        self
176    }
177
178    /// Build the [`TestGateway`]
179    pub async fn build(self) -> anyhow::Result<TestGateway> {
180        let Self {
181            gateway_path,
182            cli_path,
183            extension_path,
184            enable_stdout,
185            enable_stderr,
186            subgraphs: mock_subgraphs,
187            enable_networking,
188            enable_environment_variables,
189            max_pool_size,
190            log_level,
191            toml_config,
192            stream_stdout_stderr,
193        } = self;
194
195        let gateway_path = match gateway_path {
196            Some(path) => path,
197            None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
198        };
199
200        let cli_path = match cli_path {
201            Some(path) => path,
202            None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
203        };
204
205        let log_level = log_level.unwrap_or_default();
206
207        TestGateway::setup(TestConfig {
208            gateway_path,
209            cli_path,
210            toml_config: toml_config.unwrap_or_default(),
211            extension_path,
212            enable_stdout,
213            enable_stderr,
214            mock_subgraphs,
215            enable_networking,
216            enable_environment_variables,
217            max_pool_size,
218            log_level,
219            stream_stdout_stderr,
220        })
221        .await
222    }
223}
224
225#[allow(clippy::panic)]
226impl TestGateway {
227    async fn setup(mut config: TestConfig) -> anyhow::Result<Self> {
228        let test_specific_temp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
229        let gateway_listen_address = listen_address()?;
230        let gateway_endpoint = Url::parse(&format!("http://{gateway_listen_address}/graphql"))?;
231
232        let extension_toml_path = std::env::current_dir()?.join("extension.toml");
233        let extension_toml = std::fs::read_to_string(&extension_toml_path)?;
234        let extension_toml: ExtensionToml = toml::from_str(&extension_toml)?;
235        let extension_name = extension_toml.extension.name;
236
237        let mut mock_subgraphs = Vec::new();
238        let mut subgraphs = Subgraphs::default();
239
240        let extension_path = match config.extension_path {
241            Some(ref path) => path.to_path_buf(),
242            None => std::env::current_dir()?.join("build"),
243        };
244
245        let extension_url = url::Url::from_file_path(&extension_path).unwrap();
246        subgraphs.ingest_loaded_extensions(std::iter::once(LoadedExtension::new(
247            extension_url.to_string(),
248            extension_name.clone(),
249        )));
250
251        let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
252        let rep = format!(r#"@link(url: "{extension_url}""#);
253        for subgraph in config.mock_subgraphs.drain(..) {
254            match subgraph {
255                Subgraph::Graphql(subgraph) => {
256                    let mock_graph = subgraph.start().await;
257                    let sdl = re.replace_all(mock_graph.schema(), &rep);
258                    subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
259                    mock_subgraphs.push(mock_graph);
260                }
261                Subgraph::Virtual(subgraph) => {
262                    let sdl = re.replace_all(subgraph.schema(), &rep);
263                    subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
264                }
265            }
266        }
267
268        let federated_graph = match graphql_composition::compose(&subgraphs)
269            .warnings_are_fatal()
270            .into_result()
271        {
272            Ok(graph) => graph,
273            Err(diagnostics) => {
274                return Err(anyhow::anyhow!(
275                    "Failed to compose subgraphs:\n{}\n",
276                    diagnostics
277                        .iter_messages()
278                        .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
279                ));
280            }
281        };
282        let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
283
284        let mut this = Self {
285            http_client: reqwest::Client::new(),
286            config,
287            gateway_handle: None,
288            gateway_listen_address,
289            gateway_endpoint,
290            test_specific_temp_dir,
291            _mock_subgraphs: mock_subgraphs,
292            federated_sdl,
293        };
294
295        if this.config.extension_path.is_none() {
296            this.build_extension(&extension_path)?;
297        }
298
299        this.start_servers(&extension_name, &extension_path)
300            .await
301            .map_err(|err| anyhow::anyhow!("Failed to start servers: {err}"))?;
302
303        Ok(this)
304    }
305
306    async fn start_servers(&mut self, extension_name: &str, extension_path: &Path) -> anyhow::Result<()> {
307        let extension_path = extension_path.display();
308        let config_path = self.test_specific_temp_dir.path().join("grafbase.toml");
309        let schema_path = self.test_specific_temp_dir.path().join("federated-schema.graphql");
310
311        let config = {
312            let max_pool_size = self.config.max_pool_size.unwrap_or(100);
313            let mut config = indoc::formatdoc! {r#"
314                [extensions.{extension_name}]
315                path = "{extension_path}"
316                max_pool_size = {max_pool_size}
317            "#};
318            if let Some(enabled) = self.config.enable_stderr {
319                writeln!(config, "stderr = {enabled}").unwrap();
320            }
321            if let Some(enabled) = self.config.enable_stdout {
322                writeln!(config, "stdout = {enabled}").unwrap();
323            }
324            if let Some(enabled) = self.config.enable_networking {
325                writeln!(config, "networking = {enabled}").unwrap();
326            }
327            if let Some(enabled) = self.config.enable_environment_variables {
328                writeln!(config, "environment_variables = {enabled}").unwrap();
329            }
330            config.push_str("\n\n");
331            config.push_str(&self.config.toml_config);
332            config
333        };
334        println!("{config}");
335
336        std::fs::write(&config_path, config.as_bytes())
337            .map_err(|err| anyhow::anyhow!("Failed to write config at {:?}: {err}", config_path))?;
338        std::fs::write(&schema_path, self.federated_sdl.as_bytes())
339            .map_err(|err| anyhow::anyhow!("Failed to write schema at {:?}: {err}", schema_path))?;
340
341        let args = &[
342            "--listen-address",
343            &self.gateway_listen_address.to_string(),
344            "--config",
345            &config_path.to_string_lossy(),
346            "--schema",
347            &schema_path.to_string_lossy(),
348            "--log",
349            self.config.log_level.as_ref(),
350        ];
351
352        let gateway_handle = {
353            let cmd = duct::cmd(&self.config.gateway_path, args);
354            if self.config.stream_stdout_stderr.unwrap_or(false) {
355                cmd
356            } else {
357                cmd.stdout_capture().stderr_capture()
358            }
359        }
360        .unchecked()
361        .stderr_to_stdout()
362        .start()
363        .map_err(|err| anyhow::anyhow!("Failed to start the gateway: {err}"))?;
364
365        let mut i = 0;
366        while !self.check_gateway_health().await? {
367            // printing every second only
368            if i % 10 == 0 {
369                match gateway_handle.try_wait() {
370                    Ok(Some(output)) => panic!(
371                        "Gateway process exited unexpectedly: {}\n{}\n{}",
372                        output.status,
373                        String::from_utf8_lossy(&output.stdout),
374                        String::from_utf8_lossy(&output.stderr)
375                    ),
376                    Ok(None) => (),
377                    Err(err) => panic!("Error waiting for gateway process: {err}"),
378                }
379                println!("Waiting for gateway to be ready...");
380            }
381            i += 1;
382            std::thread::sleep(Duration::from_millis(100));
383        }
384
385        self.gateway_handle = Some(gateway_handle);
386
387        Ok(())
388    }
389
390    async fn check_gateway_health(&self) -> anyhow::Result<bool> {
391        let url = self.gateway_endpoint.join("/health")?;
392
393        let Ok(result) = self.http_client.get(url).send().await else {
394            return Ok(false);
395        };
396
397        let result = result.error_for_status().is_ok();
398
399        Ok(result)
400    }
401
402    fn build_extension(&mut self, extension_path: &Path) -> anyhow::Result<()> {
403        let extension_path = extension_path.to_string_lossy();
404
405        // Only one test can build the extension at a time. The others must
406        // wait.
407        let mut lock_file = fslock::LockFile::open(".build.lock")?;
408        lock_file.lock()?;
409
410        let args = &["extension", "build", "--debug", "--output-dir", &*extension_path];
411        let output = {
412            let cmd = duct::cmd(&self.config.cli_path, args);
413            if self.config.stream_stdout_stderr.unwrap_or(false) {
414                cmd
415            } else {
416                cmd.stdout_capture().stderr_capture()
417            }
418        }
419        .unchecked()
420        .stderr_to_stdout()
421        .run()?;
422        if !output.status.success() {
423            panic!(
424                "Failed to build extension: {}\n{}\n{}",
425                output.status,
426                String::from_utf8_lossy(&output.stdout),
427                String::from_utf8_lossy(&output.stderr)
428            );
429        }
430
431        lock_file.unlock()?;
432
433        Ok(())
434    }
435}
436
437pub(crate) fn free_port() -> anyhow::Result<u16> {
438    const INITIAL_PORT: u16 = 14712;
439
440    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
441    std::fs::create_dir_all(&test_dir)?;
442
443    let lock_file_path = test_dir.join("port-number.lock");
444    let port_number_file_path = test_dir.join("port-number.txt");
445
446    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
447    lock_file.lock()?;
448
449    let port = if port_number_file_path.exists() {
450        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
451    } else {
452        INITIAL_PORT
453    };
454
455    std::fs::write(&port_number_file_path, port.to_string())?;
456    lock_file.unlock()?;
457
458    Ok(port)
459}
460
461pub(crate) fn listen_address() -> anyhow::Result<SocketAddr> {
462    let port = free_port()?;
463    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
464}
465
466impl Drop for TestGateway {
467    fn drop(&mut self) {
468        let Some(handle) = self.gateway_handle.take() else {
469            return;
470        };
471
472        if let Err(err) = handle.kill() {
473            eprintln!("Failed to kill grafbase-gateway: {err}")
474        }
475    }
476}