grafbase_sdk/test/gateway.rs

1use std::{
2    collections::hash_map::Entry,
3    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4    path::{Path, PathBuf},
5    str::FromStr,
6    time::Duration,
7};
8
9use crate::test::{
10    GraphqlRequest, LogLevel,
11    config::{
12        CLI_BINARY_NAME, ExtensionConfig, ExtensionToml, GATEWAY_BINARY_NAME, GatewayToml, StructuredExtensionConfig,
13    },
14    request::{Body, IntrospectionRequest},
15};
16
17use anyhow::{Context, anyhow};
18use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
19use graphql_composition::{LoadedExtension, Subgraphs};
20use itertools::Itertools;
21use regex::Regex;
22use tempfile::TempDir;
23use url::Url;
24
25/// A test runner that can start a gateway and execute GraphQL queries against it.
26pub struct TestGateway {
27    http_client: reqwest::Client,
28    handle: duct::Handle,
29    url: Url,
30    federated_sdl: String,
31    // Kept to drop them at the right time.
32    #[allow(unused)]
33    tmp_dir: TempDir,
34    #[allow(unused)]
35    mock_subgraphs: Vec<MockGraphQlServer>,
36}
37
38impl TestGateway {
39    /// Creates a new test configuration builder.
40    pub fn builder() -> TestGatewayBuilder {
41        TestGatewayBuilder::new()
42    }
43
44    /// Full url of the GraphQL endpoint on the gateway.
45    pub fn url(&self) -> &Url {
46        &self.url
47    }
48
49    /// Creates a new GraphQL query builder with the given query.
50    ///
51    /// # Arguments
52    ///
53    /// * `query` - The GraphQL query string to execute
54    ///
55    /// # Returns
56    ///
57    /// A [`QueryBuilder`] that can be used to customize and execute the query
58    pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
59        let builder = self.http_client.post(self.url.clone());
60        GraphqlRequest {
61            builder,
62            body: query.into(),
63        }
64    }
65
66    /// Returns the federated schema as a string.
67    pub fn federated_sdl(&self) -> &str {
68        &self.federated_sdl
69    }
70
71    /// Execute a GraphQL introspection query to retrieve the API schema as a string.
72    /// Beware that introspection must be explicitly enabled with in the TOML config:
73    /// ```toml
74    /// [graph]
75    /// introspection = true
76    /// ```
77    pub fn introspect(&self) -> IntrospectionRequest {
78        let operation = cynic_introspection::IntrospectionQuery::with_capabilities(
79            cynic_introspection::SpecificationVersion::October2021.capabilities(),
80        );
81        IntrospectionRequest(self.query(Body {
82            query: Some(operation.query),
83            variables: None,
84        }))
85    }
86
87    /// Checks if the gateway is healthy by sending a request to the `/health` endpoint.
88    pub async fn health(&self) -> anyhow::Result<()> {
89        let url = self.url.join("/health")?;
90        let _ = self.http_client.get(url).send().await?.error_for_status()?;
91        Ok(())
92    }
93}
94
95#[derive(Debug, Default, Clone)]
96/// Builder pattern to create a [`TestGateway`].
97pub struct TestGatewayBuilder {
98    gateway_path: Option<PathBuf>,
99    cli_path: Option<PathBuf>,
100    toml_config: Option<String>,
101    subgraphs: Vec<Subgraph>,
102    stream_stdout_stderr: Option<bool>,
103    log_level: Option<LogLevel>,
104}
105
106impl TestGatewayBuilder {
107    /// Creates a new [`TestConfigBuilder`] with default values.
108    pub(crate) fn new() -> Self {
109        Self::default()
110    }
111
112    /// Adds a subgraph to the test configuration.
113    pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
114        self.subgraphs.push(subgraph.into());
115        self
116    }
117
118    /// Specifies a custom path to the gateway binary. If not defined, the binary will be searched in the PATH.
119    pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
120        self.gateway_path = Some(gateway_path.into());
121        self
122    }
123
124    /// Specifies a custom path to the CLI binary. If not defined, the binary will be searched in the PATH.
125    pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
126        self.cli_path = Some(cli_path.into());
127        self
128    }
129
130    /// Sets the TOML configuration for the gateway. The extension and subgraphs will be
131    /// automatically added to the configuration.
132    pub fn toml_config(mut self, cfg: impl ToString) -> Self {
133        self.toml_config = Some(cfg.to_string());
134        self
135    }
136
137    /// Sets the log level for the gateway process output.
138    pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
139        self.log_level = Some(level.into());
140        self
141    }
142
143    /// Stream stdout and stderr from the gateway & cli commands.
144    /// Useful if you need to debug subscriptions for example. Not recommended in a CI for
145    /// reporting clarity.
146    pub fn stream_stdout_stderr(mut self) -> Self {
147        self.stream_stdout_stderr = Some(true);
148        self
149    }
150
151    /// Build the [`TestGateway`]
152    pub async fn build(self) -> anyhow::Result<TestGateway> {
153        println!("Building the gateway:");
154
155        let gateway_path = match self.gateway_path {
156            Some(path) => path,
157            None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
158        };
159
160        let cli_path = match self.cli_path {
161            Some(path) => path,
162            None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
163        };
164
165        let log_level = self.log_level.unwrap_or_default();
166
167        let extension_path = std::env::current_dir()?;
168        let extension_name =
169            toml::from_str::<ExtensionToml>(&std::fs::read_to_string(extension_path.join("extension.toml"))?)?
170                .extension
171                .name;
172
173        // Ensure current extension is built and up to date.
174        {
175            println!("* Building current extension.");
176            let lock_path = extension_path.join(".build.lock");
177            let mut lock_file = fslock::LockFile::open(&lock_path)?;
178            lock_file.lock()?;
179
180            let output = {
181                let cmd = duct::cmd(&cli_path, &["extension", "build", "--debug"]).dir(&extension_path);
182                if self.stream_stdout_stderr.unwrap_or(false) {
183                    cmd
184                } else {
185                    cmd.stdout_capture().stderr_capture()
186                }
187            }
188            .unchecked()
189            .stderr_to_stdout()
190            .run()?;
191
192            if !output.status.success() {
193                return Err(anyhow!(
194                    "Failed to build extension: {}\n{}\n{}",
195                    output.status,
196                    String::from_utf8_lossy(&output.stdout),
197                    String::from_utf8_lossy(&output.stderr)
198                ));
199            }
200
201            lock_file.unlock()?;
202            anyhow::Ok(())
203        }?;
204
205        println!("* Preparing the grafbase.toml & schema.graphql files.");
206        // Update grafbase TOML with current extension path.
207        let mut toml_config: GatewayToml = toml::from_str(&self.toml_config.unwrap_or_default())?;
208        match toml_config.extensions.entry(extension_name.clone()) {
209            Entry::Occupied(mut entry) => match entry.get_mut() {
210                ExtensionConfig::Version(_) => {
211                    return Err(anyhow!(
212                        "Current extension {extension_name} cannot be specified with a version"
213                    ));
214                }
215                ExtensionConfig::Structured(config) => {
216                    config
217                        .path
218                        .get_or_insert_with(|| extension_path.join("build").to_string_lossy().into_owned());
219                }
220            },
221            Entry::Vacant(entry) => {
222                entry.insert(ExtensionConfig::Structured(StructuredExtensionConfig {
223                    path: Some(extension_path.join("build").to_string_lossy().into_owned()),
224                    version: None,
225                    rest: Default::default(),
226                }));
227            }
228        }
229
230        // Composition
231        let (federated_sdl, mock_subgraphs) = {
232            let extensions = toml_config
233                .extensions
234                .iter()
235                .map(|(name, config)| match config {
236                    ExtensionConfig::Version(version) => Ok(LoadedExtension::new(
237                        format!("https://extensions.grafbase.com/{extension_name}/{version}"),
238                        name.clone(),
239                    )),
240                    ExtensionConfig::Structured(StructuredExtensionConfig {
241                        path: None, version, ..
242                    }) => Ok(LoadedExtension::new(
243                        format!(
244                            "https://extensions.grafbase.com/{extension_name}/{}",
245                            version
246                                .as_ref()
247                                .ok_or_else(|| anyhow!("Missing path or version for extension '{name}'"))?
248                        ),
249                        name.clone(),
250                    )),
251                    ExtensionConfig::Structured(StructuredExtensionConfig { path: Some(path), .. }) => {
252                        let mut path = PathBuf::from_str(path.as_str())
253                            .context(format!("Invalid path for extension {name}: {path}"))?;
254                        if path.is_relative() {
255                            path = extension_path.join(path);
256                        }
257                        anyhow::Ok(LoadedExtension::new(
258                            Url::from_file_path(&path)
259                                .map_err(|_| anyhow!("Invalid path for extension {name}: {}", path.display()))?
260                                .to_string(),
261                            name.to_owned(),
262                        ))
263                    }
264                })
265                .collect::<anyhow::Result<Vec<_>>>()?;
266
267            compose(self.subgraphs, &extension_path, extensions).await
268        }?;
269
270        if toml_config.wasm.cache_path.is_none() {
271            toml_config.wasm.cache_path = Some(extension_path.join("build").join("wasm-cache"));
272        }
273
274        // Build test dir
275        let tmp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
276        let config_path = tmp_dir.path().join("grafbase.toml");
277        let schema_path = tmp_dir.path().join("schema.graphql");
278
279        std::fs::write(&config_path, toml::to_string(&toml_config)?).context("Failed to write grafbase.toml")?;
280        std::fs::write(&schema_path, &federated_sdl).context("Failed to write schema.graphql")?;
281
282        // Install other extensions if necessary.
283        if toml_config.extensions.len() > 1 {
284            println!("* Installing other extensions.");
285            let output = {
286                let cmd = duct::cmd(&cli_path, &["extension", "install"]).dir(tmp_dir.path());
287                if self.stream_stdout_stderr.unwrap_or(false) {
288                    cmd
289                } else {
290                    cmd.stdout_capture().stderr_capture()
291                }
292            }
293            .unchecked()
294            .stderr_to_stdout()
295            .run()?;
296
297            if !output.status.success() {
298                return Err(anyhow!(
299                    "Failed to install extensions: {}\n{}\n{}",
300                    output.status,
301                    String::from_utf8_lossy(&output.stdout),
302                    String::from_utf8_lossy(&output.stderr)
303                ));
304            }
305        }
306
307        println!("* Starting the gateway.");
308        let listen_address = new_listen_address()?;
309        let url = Url::parse(&format!("http://{listen_address}/graphql")).unwrap();
310
311        let handle = {
312            let cmd = duct::cmd(
313                &gateway_path,
314                &[
315                    "--listen-address",
316                    &listen_address.to_string(),
317                    "--config",
318                    &config_path.to_string_lossy(),
319                    "--schema",
320                    &schema_path.to_string_lossy(),
321                    "--log",
322                    log_level.as_ref(),
323                ],
324            )
325            .dir(tmp_dir.path());
326            if self.stream_stdout_stderr.unwrap_or(false) {
327                cmd
328            } else {
329                cmd.stdout_capture().stderr_capture()
330            }
331        }
332        .unchecked()
333        .stderr_to_stdout()
334        .start()
335        .map_err(|err| anyhow!("Failed to start the gateway: {err}"))?;
336
337        let gateway = TestGateway {
338            http_client: reqwest::Client::new(),
339            handle,
340            url,
341            tmp_dir,
342            mock_subgraphs,
343            federated_sdl,
344        };
345
346        let mut i = 0;
347        while gateway.health().await.is_err() {
348            // printing every second only
349            if i % 10 == 0 {
350                match gateway.handle.try_wait() {
351                    Ok(Some(output)) => {
352                        return Err(anyhow!(
353                            "Gateway process exited unexpectedly: {}\n{}\n{}",
354                            output.status,
355                            String::from_utf8_lossy(&output.stdout),
356                            String::from_utf8_lossy(&output.stderr)
357                        ));
358                    }
359                    Ok(None) => (),
360                    Err(err) => return Err(anyhow!("Error waiting for gateway process: {err}")),
361                }
362                println!("Waiting for gateway to be ready...");
363            }
364            i += 1;
365            tokio::time::sleep(Duration::from_millis(100)).await;
366        }
367
368        Ok(gateway)
369    }
370}
371
372pub(crate) fn new_listen_address() -> anyhow::Result<SocketAddr> {
373    let port = free_port()?;
374    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
375}
376
377pub(crate) fn free_port() -> anyhow::Result<u16> {
378    const INITIAL_PORT: u16 = 14712;
379
380    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
381    std::fs::create_dir_all(&test_dir)?;
382
383    let lock_file_path = test_dir.join("port-number.lock");
384    let port_number_file_path = test_dir.join("port-number.txt");
385
386    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
387    lock_file.lock()?;
388
389    let port = if port_number_file_path.exists() {
390        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
391    } else {
392        INITIAL_PORT
393    };
394
395    std::fs::write(&port_number_file_path, port.to_string())?;
396    lock_file.unlock()?;
397
398    Ok(port)
399}
400
401async fn compose(
402    subgraphs: impl IntoIterator<Item = Subgraph>,
403    extension_path: &Path,
404    extensions: impl IntoIterator<Item = LoadedExtension>,
405) -> anyhow::Result<(String, Vec<MockGraphQlServer>)> {
406    let mut mock_subgraphs = Vec::new();
407    let mut composition_subgraphs = Subgraphs::default();
408
409    composition_subgraphs.ingest_loaded_extensions(extensions);
410
411    let extension_url = url::Url::from_file_path(extension_path.join("build")).unwrap();
412    let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
413    let rep = format!(r#"@link(url: "{extension_url}""#);
414
415    for subgraph in subgraphs {
416        match subgraph {
417            Subgraph::Graphql(subgraph) => {
418                let mock_graph = subgraph.start().await;
419                let sdl = re.replace_all(mock_graph.schema(), &rep);
420                composition_subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
421                mock_subgraphs.push(mock_graph);
422            }
423            Subgraph::Virtual(subgraph) => {
424                let sdl = re.replace_all(subgraph.schema(), &rep);
425                composition_subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
426            }
427        }
428    }
429
430    let federated_graph = match graphql_composition::compose(composition_subgraphs)
431        .warnings_are_fatal()
432        .into_result()
433    {
434        Ok(graph) => graph,
435        Err(diagnostics) => {
436            return Err(anyhow!(
437                "Failed to compose subgraphs:\n{}\n",
438                diagnostics
439                    .iter_messages()
440                    .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
441            ));
442        }
443    };
444    let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
445
446    Ok((federated_sdl, mock_subgraphs))
447}
448
449impl Drop for TestGateway {
450    fn drop(&mut self) {
451        if let Err(err) = self.handle.kill() {
452            eprintln!("Failed to kill grafbase-gateway: {err}")
453        }
454    }
455}