grafbase_sdk/test/gateway.rs

1use std::{
2    collections::hash_map::Entry,
3    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4    path::{Path, PathBuf},
5    str::FromStr,
6    time::Duration,
7};
8
9use crate::test::{
10    GraphqlRequest, LogLevel,
11    config::{
12        CLI_BINARY_NAME, ExtensionConfig, ExtensionToml, GATEWAY_BINARY_NAME, GatewayToml, StructuredExtensionConfig,
13    },
14    request::{Body, IntrospectionRequest},
15};
16
17use anyhow::{Context, anyhow};
18use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
19use graphql_composition::{LoadedExtension, Subgraphs};
20use itertools::Itertools;
21use regex::Regex;
22use tempfile::TempDir;
23use url::Url;
24
25/// A test runner that can start a gateway and execute GraphQL queries against it.
26pub struct TestGateway {
27    http_client: reqwest::Client,
28    handle: duct::Handle,
29    url: Url,
30    federated_sdl: String,
31    // Kept to drop them at the right time.
32    #[allow(unused)]
33    tmp_dir: TempDir,
34    #[allow(unused)]
35    mock_subgraphs: Vec<MockGraphQlServer>,
36}
37
38impl TestGateway {
39    /// Creates a new test configuration builder.
40    pub fn builder() -> TestGatewayBuilder {
41        TestGatewayBuilder::new()
42    }
43
44    /// Full url of the GraphQL endpoint on the gateway.
45    pub fn url(&self) -> &Url {
46        &self.url
47    }
48
49    /// Creates a new GraphQL query builder with the given query.
50    ///
51    /// # Arguments
52    ///
53    /// * `query` - The GraphQL query string to execute
54    ///
55    /// # Returns
56    ///
57    /// A [`QueryBuilder`] that can be used to customize and execute the query
58    pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
59        let builder = self.http_client.post(self.url.clone());
60        GraphqlRequest {
61            builder,
62            body: query.into(),
63        }
64    }
65
66    /// Returns the federated schema as a string.
67    pub fn federated_sdl(&self) -> &str {
68        &self.federated_sdl
69    }
70
71    /// Execute a GraphQL introspection query to retrieve the API schema as a string.
72    /// Beware that introspection must be explicitly enabled with in the TOML config:
73    /// ```toml
74    /// [graph]
75    /// introspection = true
76    /// ```
77    pub fn introspect(&self) -> IntrospectionRequest {
78        let operation = cynic_introspection::IntrospectionQuery::with_capabilities(
79            cynic_introspection::SpecificationVersion::October2021.capabilities(),
80        );
81        IntrospectionRequest(self.query(Body {
82            query: Some(operation.query),
83            variables: None,
84        }))
85    }
86
87    /// Checks if the gateway is healthy by sending a request to the `/health` endpoint.
88    pub async fn health(&self) -> anyhow::Result<()> {
89        let url = self.url.join("/health")?;
90        let _ = self.http_client.get(url).send().await?.error_for_status()?;
91        Ok(())
92    }
93}
94
95#[derive(Debug, Default, Clone)]
96/// Builder pattern to create a [`TestGateway`].
97pub struct TestGatewayBuilder {
98    gateway_path: Option<PathBuf>,
99    cli_path: Option<PathBuf>,
100    toml_config: Option<String>,
101    subgraphs: Vec<Subgraph>,
102    stream_stdout_stderr: Option<bool>,
103    log_level: Option<LogLevel>,
104}
105
106impl TestGatewayBuilder {
107    /// Creates a new [`TestConfigBuilder`] with default values.
108    pub(crate) fn new() -> Self {
109        Self::default()
110    }
111
112    /// Adds a subgraph to the test configuration.
113    pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
114        self.subgraphs.push(subgraph.into());
115        self
116    }
117
118    /// Specifies a custom path to the gateway binary. If not defined, the binary will be searched in the PATH.
119    pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
120        self.gateway_path = Some(gateway_path.into());
121        self
122    }
123
124    /// Specifies a custom path to the CLI binary. If not defined, the binary will be searched in the PATH.
125    pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
126        self.cli_path = Some(cli_path.into());
127        self
128    }
129
130    /// Sets the TOML configuration for the gateway. The extension and subgraphs will be
131    /// automatically added to the configuration.
132    pub fn toml_config(mut self, cfg: impl ToString) -> Self {
133        self.toml_config = Some(cfg.to_string());
134        self
135    }
136
137    /// Sets the log level for the gateway process output.
138    pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
139        self.log_level = Some(level.into());
140        self
141    }
142
143    /// Stream stdout and stderr from the gateway & cli commands.
144    /// Useful if you need to debug subscriptions for example. Not recommended in a CI for
145    /// reporting clarity.
146    pub fn stream_stdout_stderr(mut self) -> Self {
147        self.stream_stdout_stderr = Some(true);
148        self
149    }
150
151    /// Build the [`TestGateway`]
152    pub async fn build(self) -> anyhow::Result<TestGateway> {
153        println!("Building the gateway:");
154
155        let gateway_path = match self.gateway_path {
156            Some(path) => path,
157            None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
158        };
159
160        let cli_path = match self.cli_path {
161            Some(path) => path,
162            None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
163        };
164
165        let log_level = self.log_level.unwrap_or_default();
166
167        let extension_path = std::env::current_dir()?;
168        let extension_name =
169            toml::from_str::<ExtensionToml>(&std::fs::read_to_string(extension_path.join("extension.toml"))?)?
170                .extension
171                .name;
172
173        // Ensure current extension is built and up to date.
174        {
175            println!("* Building current extension.");
176            let lock_path = extension_path.join(".build.lock");
177            let mut lock_file = fslock::LockFile::open(&lock_path)?;
178            lock_file.lock()?;
179
180            let output = {
181                let cmd = duct::cmd(&cli_path, &["extension", "build", "--debug"]).dir(&extension_path);
182                if self.stream_stdout_stderr.unwrap_or(false) {
183                    cmd
184                } else {
185                    cmd.stdout_capture().stderr_capture()
186                }
187            }
188            .unchecked()
189            .stderr_to_stdout()
190            .run()?;
191
192            if !output.status.success() {
193                return Err(anyhow!(
194                    "Failed to build extension: {}\n{}\n{}",
195                    output.status,
196                    String::from_utf8_lossy(&output.stdout),
197                    String::from_utf8_lossy(&output.stderr)
198                ));
199            }
200
201            lock_file.unlock()?;
202            anyhow::Ok(())
203        }?;
204
205        println!("* Preparing the grafbase.toml & schema.graphql files.");
206        // Update grafbase TOML with current extension path.
207        let mut toml_config: GatewayToml = toml::from_str(&self.toml_config.unwrap_or_default())?;
208        match toml_config.extensions.entry(extension_name.clone()) {
209            Entry::Occupied(mut entry) => match entry.get_mut() {
210                ExtensionConfig::Version(_) => {
211                    return Err(anyhow!(
212                        "Current extension {extension_name} cannot be specified with a version"
213                    ));
214                }
215                ExtensionConfig::Structured(config) => {
216                    config
217                        .path
218                        .get_or_insert_with(|| extension_path.join("build").to_string_lossy().into_owned());
219                }
220            },
221            Entry::Vacant(entry) => {
222                entry.insert(ExtensionConfig::Structured(StructuredExtensionConfig {
223                    path: Some(extension_path.join("build").to_string_lossy().into_owned()),
224                    version: None,
225                    rest: Default::default(),
226                }));
227            }
228        }
229
230        // Composition
231        let (federated_sdl, mock_subgraphs) = {
232            let extensions = toml_config
233                .extensions
234                .iter()
235                .map(|(name, config)| match config {
236                    ExtensionConfig::Version(version) => Ok(new_loaded_extension(
237                        format!("https://extensions.grafbase.com/{extension_name}/{version}")
238                            .parse()
239                            .unwrap(),
240                        name.clone(),
241                    )),
242                    ExtensionConfig::Structured(StructuredExtensionConfig {
243                        path: None, version, ..
244                    }) => Ok(new_loaded_extension(
245                        format!(
246                            "https://extensions.grafbase.com/{extension_name}/{}",
247                            version
248                                .as_ref()
249                                .ok_or_else(|| anyhow!("Missing path or version for extension '{name}'"))?
250                        )
251                        .parse()
252                        .unwrap(),
253                        name.clone(),
254                    )),
255                    ExtensionConfig::Structured(StructuredExtensionConfig { path: Some(path), .. }) => {
256                        let mut path = PathBuf::from_str(path.as_str())
257                            .context(format!("Invalid path for extension {name}: {path}"))?;
258                        if path.is_relative() {
259                            path = extension_path.join(path);
260                        }
261                        anyhow::Ok(new_loaded_extension(
262                            Url::from_file_path(&path).unwrap(),
263                            name.to_owned(),
264                        ))
265                    }
266                })
267                .collect::<anyhow::Result<Vec<_>>>()?;
268
269            compose(self.subgraphs, &extension_path, extensions).await
270        }?;
271
272        if toml_config.wasm.cache_path.is_none() {
273            toml_config.wasm.cache_path = Some(extension_path.join("build").join("wasm-cache"));
274        }
275
276        // Build test dir
277        let tmp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
278        let config_path = tmp_dir.path().join("grafbase.toml");
279        let schema_path = tmp_dir.path().join("schema.graphql");
280
281        std::fs::write(&config_path, toml::to_string(&toml_config)?).context("Failed to write grafbase.toml")?;
282        std::fs::write(&schema_path, &federated_sdl).context("Failed to write schema.graphql")?;
283
284        // Install other extensions if necessary.
285        if toml_config.extensions.len() > 1 {
286            println!("* Installing other extensions.");
287            let output = {
288                let cmd = duct::cmd(&cli_path, &["extension", "install"]).dir(tmp_dir.path());
289                if self.stream_stdout_stderr.unwrap_or(false) {
290                    cmd
291                } else {
292                    cmd.stdout_capture().stderr_capture()
293                }
294            }
295            .unchecked()
296            .stderr_to_stdout()
297            .run()?;
298
299            if !output.status.success() {
300                return Err(anyhow!(
301                    "Failed to install extensions: {}\n{}\n{}",
302                    output.status,
303                    String::from_utf8_lossy(&output.stdout),
304                    String::from_utf8_lossy(&output.stderr)
305                ));
306            }
307        }
308
309        println!("* Starting the gateway.");
310        let listen_address = new_listen_address()?;
311        let url = Url::parse(&format!("http://{listen_address}/graphql")).unwrap();
312
313        let handle = {
314            let cmd = duct::cmd(
315                &gateway_path,
316                &[
317                    "--listen-address",
318                    &listen_address.to_string(),
319                    "--config",
320                    &config_path.to_string_lossy(),
321                    "--schema",
322                    &schema_path.to_string_lossy(),
323                    "--log",
324                    log_level.as_ref(),
325                ],
326            )
327            .dir(tmp_dir.path());
328            if self.stream_stdout_stderr.unwrap_or(false) {
329                cmd
330            } else {
331                cmd.stdout_capture().stderr_capture()
332            }
333        }
334        .unchecked()
335        .stderr_to_stdout()
336        .start()
337        .map_err(|err| anyhow!("Failed to start the gateway: {err}"))?;
338
339        let gateway = TestGateway {
340            http_client: reqwest::Client::new(),
341            handle,
342            url,
343            tmp_dir,
344            mock_subgraphs,
345            federated_sdl,
346        };
347
348        let mut i = 0;
349        while gateway.health().await.is_err() {
350            // printing every second only
351            if i % 10 == 0 {
352                match gateway.handle.try_wait() {
353                    Ok(Some(output)) => {
354                        return Err(anyhow!(
355                            "Gateway process exited unexpectedly: {}\n{}\n{}",
356                            output.status,
357                            String::from_utf8_lossy(&output.stdout),
358                            String::from_utf8_lossy(&output.stderr)
359                        ));
360                    }
361                    Ok(None) => (),
362                    Err(err) => return Err(anyhow!("Error waiting for gateway process: {err}")),
363                }
364                println!("Waiting for gateway to be ready...");
365            }
366            i += 1;
367            tokio::time::sleep(Duration::from_millis(100)).await;
368        }
369
370        Ok(gateway)
371    }
372}
373
374pub(crate) fn new_listen_address() -> anyhow::Result<SocketAddr> {
375    let port = free_port()?;
376    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
377}
378
379pub(crate) fn free_port() -> anyhow::Result<u16> {
380    const INITIAL_PORT: u16 = 14712;
381
382    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
383    std::fs::create_dir_all(&test_dir)?;
384
385    let lock_file_path = test_dir.join("port-number.lock");
386    let port_number_file_path = test_dir.join("port-number.txt");
387
388    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
389    lock_file.lock()?;
390
391    let port = if port_number_file_path.exists() {
392        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
393    } else {
394        INITIAL_PORT
395    };
396
397    std::fs::write(&port_number_file_path, port.to_string())?;
398    lock_file.unlock()?;
399
400    Ok(port)
401}
402
403async fn compose(
404    subgraphs: impl IntoIterator<Item = Subgraph>,
405    extension_path: &Path,
406    extensions: impl IntoIterator<Item = LoadedExtension>,
407) -> anyhow::Result<(String, Vec<MockGraphQlServer>)> {
408    let mut mock_subgraphs = Vec::new();
409    let mut composition_subgraphs = Subgraphs::default();
410
411    composition_subgraphs.ingest_loaded_extensions(extensions);
412
413    let extension_url = url::Url::from_file_path(extension_path.join("build")).unwrap();
414    let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
415    let rep = format!(r#"@link(url: "{extension_url}""#);
416
417    for subgraph in subgraphs {
418        match subgraph {
419            Subgraph::Graphql(subgraph) => {
420                let mock_graph = subgraph.start().await;
421                let sdl = re.replace_all(mock_graph.schema(), &rep);
422                composition_subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
423                mock_subgraphs.push(mock_graph);
424            }
425            Subgraph::Virtual(subgraph) => {
426                let sdl = re.replace_all(subgraph.schema(), &rep);
427                composition_subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
428            }
429        }
430    }
431
432    let federated_graph = match graphql_composition::compose(&mut composition_subgraphs)
433        .warnings_are_fatal()
434        .into_result()
435    {
436        Ok(graph) => graph,
437        Err(diagnostics) => {
438            return Err(anyhow!(
439                "Failed to compose subgraphs:\n{}\n",
440                diagnostics
441                    .iter_messages()
442                    .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
443            ));
444        }
445    };
446    let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
447
448    Ok((federated_sdl, mock_subgraphs))
449}
450
451impl Drop for TestGateway {
452    fn drop(&mut self) {
453        if let Err(err) = self.handle.kill() {
454            eprintln!("Failed to kill grafbase-gateway: {err}")
455        }
456    }
457}
458
459fn new_loaded_extension(url: Url, name: String) -> LoadedExtension {
460    LoadedExtension {
461        link_url: url.to_string(),
462        url,
463        name,
464    }
465}