grafbase_sdk/test/
gateway.rs

1use std::{
2    collections::hash_map::Entry,
3    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4    path::{Path, PathBuf},
5    time::Duration,
6};
7
8use crate::test::{
9    GraphqlRequest, LogLevel,
10    config::{
11        CLI_BINARY_NAME, ExtensionConfig, ExtensionToml, GATEWAY_BINARY_NAME, GatewayToml, StructuredExtensionConfig,
12    },
13    request::{Body, IntrospectionRequest},
14};
15
16use anyhow::{Context, anyhow};
17use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
18use graphql_composition::Subgraphs;
19use itertools::Itertools;
20use regex::Regex;
21use tempfile::TempDir;
22use url::Url;
23
/// A test runner that can start a gateway and execute GraphQL queries against it.
///
/// Dropping it attempts to kill the gateway process (see the `Drop` impl for `TestGateway`).
pub struct TestGateway {
    /// HTTP client used for GraphQL and `/health` requests.
    http_client: reqwest::Client,
    /// Handle on the running gateway process.
    handle: duct::Handle,
    /// Full URL of the gateway's `/graphql` endpoint.
    url: Url,
    /// Federated SDL written to `schema.graphql` and served by the gateway.
    federated_sdl: String,
    // Kept to drop them at the right time.
    // `Drop::drop` (which requests the gateway kill) runs before any field is dropped, and
    // fields are then dropped in declaration order. So the temporary directory holding
    // grafbase.toml / schema.graphql and the mock subgraph servers outlive the kill request.
    // Do not reorder these fields without considering that.
    #[allow(unused)]
    tmp_dir: TempDir,
    #[allow(unused)]
    mock_subgraphs: Vec<MockGraphQlServer>,
}
36
37impl TestGateway {
38    /// Creates a new test configuration builder.
39    pub fn builder() -> TestGatewayBuilder {
40        TestGatewayBuilder::new()
41    }
42
43    /// Full url of the GraphQL endpoint on the gateway.
44    pub fn url(&self) -> &Url {
45        &self.url
46    }
47
48    /// Creates a new GraphQL query builder with the given query.
49    ///
50    /// # Arguments
51    ///
52    /// * `query` - The GraphQL query string to execute
53    ///
54    /// # Returns
55    ///
56    /// A [`QueryBuilder`] that can be used to customize and execute the query
57    pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
58        let builder = self.http_client.post(self.url.clone());
59        GraphqlRequest {
60            builder,
61            body: query.into(),
62        }
63    }
64
65    /// Returns the federated schema as a string.
66    pub fn federated_sdl(&self) -> &str {
67        &self.federated_sdl
68    }
69
70    /// Execute a GraphQL introspection query to retrieve the API schema as a string.
71    /// Beware that introspection must be explicitly enabled with in the TOML config:
72    /// ```toml
73    /// [graph]
74    /// introspection = true
75    /// ```
76    pub fn introspect(&self) -> IntrospectionRequest {
77        let operation = cynic_introspection::IntrospectionQuery::with_capabilities(
78            cynic_introspection::SpecificationVersion::October2021.capabilities(),
79        );
80        IntrospectionRequest(self.query(Body {
81            query: Some(operation.query),
82            variables: None,
83        }))
84    }
85
86    /// Checks if the gateway is healthy by sending a request to the `/health` endpoint.
87    pub async fn health(&self) -> anyhow::Result<()> {
88        let url = self.url.join("/health")?;
89        let _ = self.http_client.get(url).send().await?.error_for_status()?;
90        Ok(())
91    }
92}
93
#[derive(Debug, Default, Clone)]
/// Builder pattern to create a [`TestGateway`].
pub struct TestGatewayBuilder {
    /// Explicit gateway binary; `None` means `GATEWAY_BINARY_NAME` is looked up in the `PATH`.
    gateway_path: Option<PathBuf>,
    /// Explicit CLI binary; `None` means `CLI_BINARY_NAME` is looked up in the `PATH`.
    cli_path: Option<PathBuf>,
    /// Raw gateway TOML configuration; `None` is parsed as an empty document.
    toml_config: Option<String>,
    /// Subgraphs composed into the federated schema served by the gateway.
    subgraphs: Vec<Subgraph>,
    /// `Some(true)` streams child-process output; `None` (the default) captures it.
    stream_stdout_stderr: Option<bool>,
    /// Gateway `--log` level; `None` falls back to `LogLevel::default()`.
    log_level: Option<LogLevel>,
}
104
105impl TestGatewayBuilder {
106    /// Creates a new [`TestConfigBuilder`] with default values.
107    pub(crate) fn new() -> Self {
108        Self::default()
109    }
110
111    /// Adds a subgraph to the test configuration.
112    pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
113        self.subgraphs.push(subgraph.into());
114        self
115    }
116
117    /// Specifies a custom path to the gateway binary. If not defined, the binary will be searched in the PATH.
118    pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
119        self.gateway_path = Some(gateway_path.into());
120        self
121    }
122
123    /// Specifies a custom path to the CLI binary. If not defined, the binary will be searched in the PATH.
124    pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
125        self.cli_path = Some(cli_path.into());
126        self
127    }
128
129    /// Sets the TOML configuration for the gateway. The extension and subgraphs will be
130    /// automatically added to the configuration.
131    pub fn toml_config(mut self, cfg: impl ToString) -> Self {
132        self.toml_config = Some(cfg.to_string());
133        self
134    }
135
136    /// Sets the log level for the gateway process output.
137    pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
138        self.log_level = Some(level.into());
139        self
140    }
141
142    /// Stream stdout and stderr from the gateway & cli commands.
143    /// Useful if you need to debug subscriptions for example. Not recommended in a CI for
144    /// reporting clarity.
145    pub fn stream_stdout_stderr(mut self) -> Self {
146        self.stream_stdout_stderr = Some(true);
147        self
148    }
149
150    /// Build the [`TestGateway`]
151    pub async fn build(self) -> anyhow::Result<TestGateway> {
152        println!("Building the gateway:");
153
154        let gateway_path = match self.gateway_path {
155            Some(path) => path,
156            None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
157        };
158
159        let cli_path = match self.cli_path {
160            Some(path) => path,
161            None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
162        };
163
164        let log_level = self.log_level.unwrap_or_default();
165
166        let extension_path = std::env::current_dir()?;
167        let extension_name =
168            toml::from_str::<ExtensionToml>(&std::fs::read_to_string(extension_path.join("extension.toml"))?)?
169                .extension
170                .name;
171
172        // Ensure current extension is built and up to date.
173        {
174            println!("* Building current extension.");
175            let lock_path = extension_path.join(".build.lock");
176            let mut lock_file = fslock::LockFile::open(&lock_path)?;
177            lock_file.lock()?;
178
179            let output = {
180                let cmd = duct::cmd(&cli_path, &["extension", "build", "--debug"]).dir(&extension_path);
181                if self.stream_stdout_stderr.unwrap_or(false) {
182                    cmd
183                } else {
184                    cmd.stdout_capture().stderr_capture()
185                }
186            }
187            .unchecked()
188            .stderr_to_stdout()
189            .run()?;
190
191            if !output.status.success() {
192                return Err(anyhow!(
193                    "Failed to build extension: {}\n{}\n{}",
194                    output.status,
195                    String::from_utf8_lossy(&output.stdout),
196                    String::from_utf8_lossy(&output.stderr)
197                ));
198            }
199
200            lock_file.unlock()?;
201            anyhow::Ok(())
202        }?;
203
204        println!("* Preparing the grafbase.toml & schema.graphql files.");
205        // Update grafbase TOML with current extension path.
206        let mut toml_config: GatewayToml = toml::from_str(&self.toml_config.unwrap_or_default())?;
207        match toml_config.extensions.entry(extension_name.clone()) {
208            Entry::Occupied(mut entry) => match entry.get_mut() {
209                ExtensionConfig::Version(_) => {
210                    return Err(anyhow!(
211                        "Current extension {extension_name} cannot be specified with a version"
212                    ));
213                }
214                ExtensionConfig::Structured(config) => {
215                    config
216                        .path
217                        .get_or_insert_with(|| extension_path.join("build").to_string_lossy().into_owned());
218                }
219            },
220            Entry::Vacant(entry) => {
221                entry.insert(ExtensionConfig::Structured(StructuredExtensionConfig {
222                    path: Some(extension_path.join("build").to_string_lossy().into_owned()),
223                    version: None,
224                    rest: Default::default(),
225                }));
226            }
227        }
228
229        // Composition
230        let (federated_sdl, mock_subgraphs) = compose(self.subgraphs, &extension_path).await?;
231
232        if toml_config.wasm.cache_path.is_none() {
233            toml_config.wasm.cache_path = Some(extension_path.join("build").join("wasm-cache"));
234        }
235
236        // Build test dir
237        let tmp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
238        let config_path = tmp_dir.path().join("grafbase.toml");
239        let schema_path = tmp_dir.path().join("schema.graphql");
240
241        std::fs::write(&config_path, toml::to_string(&toml_config)?).context("Failed to write grafbase.toml")?;
242        std::fs::write(&schema_path, &federated_sdl).context("Failed to write schema.graphql")?;
243
244        // Install other extensions if necessary.
245        if toml_config.extensions.len() > 1 {
246            println!("* Installing other extensions.");
247            let output = {
248                let cmd = duct::cmd(&cli_path, &["extension", "install"]).dir(tmp_dir.path());
249                if self.stream_stdout_stderr.unwrap_or(false) {
250                    cmd
251                } else {
252                    cmd.stdout_capture().stderr_capture()
253                }
254            }
255            .unchecked()
256            .stderr_to_stdout()
257            .run()?;
258
259            if !output.status.success() {
260                return Err(anyhow!(
261                    "Failed to install extensions: {}\n{}\n{}",
262                    output.status,
263                    String::from_utf8_lossy(&output.stdout),
264                    String::from_utf8_lossy(&output.stderr)
265                ));
266            }
267        }
268
269        println!("* Starting the gateway.");
270        let listen_address = new_listen_address()?;
271        let url = Url::parse(&format!("http://{listen_address}/graphql")).unwrap();
272
273        let handle = {
274            let cmd = duct::cmd(
275                &gateway_path,
276                &[
277                    "--listen-address",
278                    &listen_address.to_string(),
279                    "--config",
280                    &config_path.to_string_lossy(),
281                    "--schema",
282                    &schema_path.to_string_lossy(),
283                    "--log",
284                    log_level.as_ref(),
285                ],
286            )
287            .dir(tmp_dir.path());
288            if self.stream_stdout_stderr.unwrap_or(false) {
289                cmd
290            } else {
291                cmd.stdout_capture().stderr_capture()
292            }
293        }
294        .unchecked()
295        .stderr_to_stdout()
296        .start()
297        .map_err(|err| anyhow!("Failed to start the gateway: {err}"))?;
298
299        let gateway = TestGateway {
300            http_client: reqwest::Client::new(),
301            handle,
302            url,
303            tmp_dir,
304            mock_subgraphs,
305            federated_sdl,
306        };
307
308        let mut i = 0;
309        while gateway.health().await.is_err() {
310            // printing every second only
311            if i % 10 == 0 {
312                match gateway.handle.try_wait() {
313                    Ok(Some(output)) => {
314                        return Err(anyhow!(
315                            "Gateway process exited unexpectedly: {}\n{}\n{}",
316                            output.status,
317                            String::from_utf8_lossy(&output.stdout),
318                            String::from_utf8_lossy(&output.stderr)
319                        ));
320                    }
321                    Ok(None) => (),
322                    Err(err) => return Err(anyhow!("Error waiting for gateway process: {err}")),
323                }
324                println!("Waiting for gateway to be ready...");
325            }
326            i += 1;
327            tokio::time::sleep(Duration::from_millis(100)).await;
328        }
329
330        Ok(gateway)
331    }
332}
333
334pub(crate) fn new_listen_address() -> anyhow::Result<SocketAddr> {
335    let port = free_port()?;
336    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
337}
338
339pub(crate) fn free_port() -> anyhow::Result<u16> {
340    const INITIAL_PORT: u16 = 14712;
341
342    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
343    std::fs::create_dir_all(&test_dir)?;
344
345    let lock_file_path = test_dir.join("port-number.lock");
346    let port_number_file_path = test_dir.join("port-number.txt");
347
348    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
349    lock_file.lock()?;
350
351    let port = if port_number_file_path.exists() {
352        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
353    } else {
354        INITIAL_PORT
355    };
356
357    std::fs::write(&port_number_file_path, port.to_string())?;
358    lock_file.unlock()?;
359
360    Ok(port)
361}
362
363async fn compose(
364    subgraphs: impl IntoIterator<Item = Subgraph>,
365    extension_path: &Path,
366) -> anyhow::Result<(String, Vec<MockGraphQlServer>)> {
367    let mut mock_subgraphs = Vec::new();
368    let mut composition_subgraphs = Subgraphs::default();
369
370    let extension_url = url::Url::from_file_path(extension_path.join("build")).unwrap();
371    let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
372    let rep = format!(r#"@link(url: "{extension_url}""#);
373
374    for subgraph in subgraphs {
375        match subgraph {
376            Subgraph::Graphql(subgraph) => {
377                let mock_graph = subgraph.start().await;
378                let sdl = re.replace_all(mock_graph.schema(), &rep);
379                composition_subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
380                mock_subgraphs.push(mock_graph);
381            }
382            Subgraph::Virtual(subgraph) => {
383                let sdl = re.replace_all(subgraph.schema(), &rep);
384                composition_subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
385            }
386        }
387    }
388
389    let federated_graph = match graphql_composition::compose(&mut composition_subgraphs)
390        .warnings_are_fatal()
391        .into_result()
392    {
393        Ok(graph) => graph,
394        Err(diagnostics) => {
395            return Err(anyhow!(
396                "Failed to compose subgraphs:\n{}\n",
397                diagnostics
398                    .iter_messages()
399                    .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
400            ));
401        }
402    };
403    let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
404
405    Ok((federated_sdl, mock_subgraphs))
406}
407
408impl Drop for TestGateway {
409    fn drop(&mut self) {
410        if let Err(err) = self.handle.kill() {
411            eprintln!("Failed to kill grafbase-gateway: {err}")
412        }
413    }
414}