grafbase_sdk/test/
gateway.rs

1use std::{
2    collections::hash_map::Entry,
3    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4    path::{Path, PathBuf},
5    str::FromStr,
6    time::Duration,
7};
8
9use crate::test::{
10    GraphqlRequest, LogLevel,
11    config::{
12        CLI_BINARY_NAME, ExtensionConfig, ExtensionToml, GATEWAY_BINARY_NAME, GatewayToml, StructuredExtensionConfig,
13    },
14    request::{Body, IntrospectionRequest},
15};
16
17use anyhow::{Context, anyhow};
18use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
19use graphql_composition::{LoadedExtension, Subgraphs};
20use itertools::Itertools;
21use regex::Regex;
22use tempfile::TempDir;
23use url::Url;
24
/// A test runner that can start a gateway and execute GraphQL queries against it.
///
/// Created through [`TestGateway::builder`]. Dropping it kills the gateway process first (in
/// `Drop::drop`). Only afterwards are the fields dropped, in declaration order. As a result, the
/// temporary configuration directory and the mock subgraph servers outlive the gateway process.
pub struct TestGateway {
    // Client reused for every request sent to the gateway.
    http_client: reqwest::Client,
    // Handle to the spawned gateway process; killed on drop.
    handle: duct::Handle,
    // GraphQL endpoint of the gateway: `http://<listen address>/graphql`.
    url: Url,
    // Composed schema that was written to `schema.graphql` and passed via `--schema`.
    federated_sdl: String,
    // Kept to drop them at the right time.
    // Never read, only held for their destructors, hence `#[allow(unused)]`: the temporary
    // directory holds grafbase.toml / schema.graphql, and the mock servers back the subgraphs.
    #[allow(unused)]
    tmp_dir: TempDir,
    #[allow(unused)]
    mock_subgraphs: Vec<MockGraphQlServer>,
}
37
38impl TestGateway {
39    /// Creates a new test configuration builder.
40    pub fn builder() -> TestGatewayBuilder {
41        TestGatewayBuilder::new()
42    }
43
44    /// Full url of the GraphQL endpoint on the gateway.
45    pub fn url(&self) -> &Url {
46        &self.url
47    }
48
49    /// Creates a new GraphQL query builder with the given query.
50    ///
51    /// # Arguments
52    ///
53    /// * `query` - The GraphQL query string to execute
54    ///
55    /// # Returns
56    ///
57    /// A [`QueryBuilder`] that can be used to customize and execute the query
58    pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
59        let builder = self.http_client.post(self.url.clone());
60        GraphqlRequest {
61            builder,
62            body: query.into(),
63        }
64    }
65
66    /// Returns the federated schema as a string.
67    pub fn federated_sdl(&self) -> &str {
68        &self.federated_sdl
69    }
70
71    /// Execute a GraphQL introspection query to retrieve the API schema as a string.
72    /// Beware that introspection must be explicitly enabled with in the TOML config:
73    /// ```toml
74    /// [graph]
75    /// introspection = true
76    /// ```
77    pub fn introspect(&self) -> IntrospectionRequest {
78        let operation = cynic_introspection::IntrospectionQuery::with_capabilities(
79            cynic_introspection::SpecificationVersion::October2021.capabilities(),
80        );
81        IntrospectionRequest(self.query(Body {
82            query: Some(operation.query),
83            variables: None,
84        }))
85    }
86
87    /// Checks if the gateway is healthy by sending a request to the `/health` endpoint.
88    pub async fn health(&self) -> anyhow::Result<()> {
89        let url = self.url.join("/health")?;
90        let _ = self.http_client.get(url).send().await?.error_for_status()?;
91        Ok(())
92    }
93}
94
/// Builder pattern to create a [`TestGateway`].
///
/// Every setting is optional. [`TestGatewayBuilder::build`] falls back to a default for any
/// setting left unset.
#[derive(Debug, Default, Clone)]
pub struct TestGatewayBuilder {
    // Explicit gateway binary; when `None` it is looked up in the PATH.
    gateway_path: Option<PathBuf>,
    // Explicit CLI binary; when `None` it is looked up in the PATH.
    cli_path: Option<PathBuf>,
    // Raw gateway TOML; the extension under test is injected into it during `build`.
    toml_config: Option<String>,
    // Subgraphs composed into the federated schema.
    subgraphs: Vec<Subgraph>,
    // `Some(true)` streams child process output instead of capturing it; `None` means capture.
    stream_stdout_stderr: Option<bool>,
    // Gateway `--log` level; when `None`, `LogLevel::default()` is used.
    log_level: Option<LogLevel>,
}
105
106impl TestGatewayBuilder {
107    /// Creates a new [`TestConfigBuilder`] with default values.
108    pub(crate) fn new() -> Self {
109        Self::default()
110    }
111
112    /// Adds a subgraph to the test configuration.
113    pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
114        self.subgraphs.push(subgraph.into());
115        self
116    }
117
118    /// Specifies a custom path to the gateway binary. If not defined, the binary will be searched in the PATH.
119    pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
120        self.gateway_path = Some(gateway_path.into());
121        self
122    }
123
124    /// Specifies a custom path to the CLI binary. If not defined, the binary will be searched in the PATH.
125    pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
126        self.cli_path = Some(cli_path.into());
127        self
128    }
129
130    /// Sets the TOML configuration for the gateway. The extension and subgraphs will be
131    /// automatically added to the configuration.
132    pub fn toml_config(mut self, cfg: impl ToString) -> Self {
133        self.toml_config = Some(cfg.to_string());
134        self
135    }
136
137    /// Sets the log level for the gateway process output.
138    pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
139        self.log_level = Some(level.into());
140        self
141    }
142
143    /// Stream stdout and stderr from the gateway & cli commands.
144    /// Useful if you need to debug subscriptions for example. Not recommended in a CI for
145    /// reporting clarity.
146    pub fn stream_stdout_stderr(mut self) -> Self {
147        self.stream_stdout_stderr = Some(true);
148        self
149    }
150
151    /// Build the [`TestGateway`]
152    pub async fn build(self) -> anyhow::Result<TestGateway> {
153        println!("Building the gateway:");
154
155        let gateway_path = match self.gateway_path {
156            Some(path) => path,
157            None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
158        };
159
160        let cli_path = match self.cli_path {
161            Some(path) => path,
162            None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
163        };
164
165        let log_level = self.log_level.unwrap_or_default();
166
167        let extension_path = std::env::current_dir()?;
168        let extension_name =
169            toml::from_str::<ExtensionToml>(&std::fs::read_to_string(extension_path.join("extension.toml"))?)?
170                .extension
171                .name;
172
173        // Ensure current extension is built and up to date.
174        {
175            println!("* Building current extension.");
176            let lock_path = extension_path.join(".build.lock");
177            let mut lock_file = fslock::LockFile::open(&lock_path)?;
178            lock_file.lock()?;
179
180            let output = {
181                let cmd = duct::cmd(&cli_path, &["extension", "build", "--debug"]).dir(&extension_path);
182                if self.stream_stdout_stderr.unwrap_or(false) {
183                    cmd
184                } else {
185                    cmd.stdout_capture().stderr_capture()
186                }
187            }
188            .unchecked()
189            .stderr_to_stdout()
190            .run()?;
191
192            if !output.status.success() {
193                return Err(anyhow!(
194                    "Failed to build extension: {}\n{}\n{}",
195                    output.status,
196                    String::from_utf8_lossy(&output.stdout),
197                    String::from_utf8_lossy(&output.stderr)
198                ));
199            }
200
201            lock_file.unlock()?;
202            anyhow::Ok(())
203        }?;
204
205        println!("* Preparing the grafbase.toml & schema.graphql files.");
206        // Update grafbase TOML with current extension path.
207        let mut toml_config: GatewayToml = toml::from_str(&self.toml_config.unwrap_or_default())?;
208        match toml_config.extensions.entry(extension_name.clone()) {
209            Entry::Occupied(mut entry) => match entry.get_mut() {
210                ExtensionConfig::Version(_) => {
211                    return Err(anyhow!(
212                        "Current extension {extension_name} cannot be specified with a version"
213                    ));
214                }
215                ExtensionConfig::Structured(config) => {
216                    config
217                        .path
218                        .get_or_insert_with(|| extension_path.join("build").to_string_lossy().into_owned());
219                }
220            },
221            Entry::Vacant(entry) => {
222                entry.insert(ExtensionConfig::Structured(StructuredExtensionConfig {
223                    path: Some(extension_path.join("build").to_string_lossy().into_owned()),
224                    version: None,
225                    rest: Default::default(),
226                }));
227            }
228        }
229
230        // Composition
231        let (federated_sdl, mock_subgraphs) = {
232            let extensions = toml_config
233                .extensions
234                .iter()
235                .map(|(name, config)| match config {
236                    ExtensionConfig::Version(version) => Ok(LoadedExtension::new(
237                        format!("https://extensions.grafbase.com/{extension_name}/{version}"),
238                        name.clone(),
239                    )),
240                    ExtensionConfig::Structured(StructuredExtensionConfig {
241                        path: None, version, ..
242                    }) => Ok(LoadedExtension::new(
243                        format!(
244                            "https://extensions.grafbase.com/{extension_name}/{}",
245                            version
246                                .as_ref()
247                                .ok_or_else(|| anyhow!("Missing path or version for extension '{name}'"))?
248                        ),
249                        name.clone(),
250                    )),
251                    ExtensionConfig::Structured(StructuredExtensionConfig { path: Some(path), .. }) => {
252                        let mut path = PathBuf::from_str(path.as_str())
253                            .context(format!("Invalid path for extension {name}: {path}"))?;
254                        if path.is_relative() {
255                            path = extension_path.join(path);
256                        }
257                        anyhow::Ok(LoadedExtension::new(
258                            Url::from_file_path(&path)
259                                .map_err(|_| anyhow!("Invalid path for extension {name}: {}", path.display()))?
260                                .to_string(),
261                            name.to_owned(),
262                        ))
263                    }
264                })
265                .collect::<anyhow::Result<Vec<_>>>()?;
266
267            compose(self.subgraphs, &extension_path, extensions).await
268        }?;
269
270        // Build test dir
271        let tmp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
272        let config_path = tmp_dir.path().join("grafbase.toml");
273        let schema_path = tmp_dir.path().join("schema.graphql");
274
275        std::fs::write(&config_path, toml::to_string(&toml_config)?).context("Failed to write grafbase.toml")?;
276        std::fs::write(&schema_path, &federated_sdl).context("Failed to write schema.graphql")?;
277
278        // Install other extensions if necessary.
279        if toml_config.extensions.len() > 1 {
280            println!("* Installing other extensions.");
281            let output = {
282                let cmd = duct::cmd(&cli_path, &["extension", "install"]).dir(tmp_dir.path());
283                if self.stream_stdout_stderr.unwrap_or(false) {
284                    cmd
285                } else {
286                    cmd.stdout_capture().stderr_capture()
287                }
288            }
289            .unchecked()
290            .stderr_to_stdout()
291            .run()?;
292
293            if !output.status.success() {
294                return Err(anyhow!(
295                    "Failed to install extensions: {}\n{}\n{}",
296                    output.status,
297                    String::from_utf8_lossy(&output.stdout),
298                    String::from_utf8_lossy(&output.stderr)
299                ));
300            }
301        }
302
303        println!("* Starting the gateway.");
304        let listen_address = new_listen_address()?;
305        let url = Url::parse(&format!("http://{listen_address}/graphql")).unwrap();
306
307        let handle = {
308            let cmd = duct::cmd(
309                &gateway_path,
310                &[
311                    "--listen-address",
312                    &listen_address.to_string(),
313                    "--config",
314                    &config_path.to_string_lossy(),
315                    "--schema",
316                    &schema_path.to_string_lossy(),
317                    "--log",
318                    log_level.as_ref(),
319                ],
320            )
321            .dir(tmp_dir.path());
322            if self.stream_stdout_stderr.unwrap_or(false) {
323                cmd
324            } else {
325                cmd.stdout_capture().stderr_capture()
326            }
327        }
328        .unchecked()
329        .stderr_to_stdout()
330        .start()
331        .map_err(|err| anyhow!("Failed to start the gateway: {err}"))?;
332
333        let gateway = TestGateway {
334            http_client: reqwest::Client::new(),
335            handle,
336            url,
337            tmp_dir,
338            mock_subgraphs,
339            federated_sdl,
340        };
341
342        let mut i = 0;
343        while gateway.health().await.is_err() {
344            // printing every second only
345            if i % 10 == 0 {
346                match gateway.handle.try_wait() {
347                    Ok(Some(output)) => {
348                        return Err(anyhow!(
349                            "Gateway process exited unexpectedly: {}\n{}\n{}",
350                            output.status,
351                            String::from_utf8_lossy(&output.stdout),
352                            String::from_utf8_lossy(&output.stderr)
353                        ));
354                    }
355                    Ok(None) => (),
356                    Err(err) => return Err(anyhow!("Error waiting for gateway process: {err}")),
357                }
358                println!("Waiting for gateway to be ready...");
359            }
360            i += 1;
361            tokio::time::sleep(Duration::from_millis(100)).await;
362        }
363
364        Ok(gateway)
365    }
366}
367
368pub(crate) fn new_listen_address() -> anyhow::Result<SocketAddr> {
369    let port = free_port()?;
370    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
371}
372
373pub(crate) fn free_port() -> anyhow::Result<u16> {
374    const INITIAL_PORT: u16 = 14712;
375
376    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
377    std::fs::create_dir_all(&test_dir)?;
378
379    let lock_file_path = test_dir.join("port-number.lock");
380    let port_number_file_path = test_dir.join("port-number.txt");
381
382    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
383    lock_file.lock()?;
384
385    let port = if port_number_file_path.exists() {
386        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
387    } else {
388        INITIAL_PORT
389    };
390
391    std::fs::write(&port_number_file_path, port.to_string())?;
392    lock_file.unlock()?;
393
394    Ok(port)
395}
396
397async fn compose(
398    subgraphs: impl IntoIterator<Item = Subgraph>,
399    extension_path: &Path,
400    extensions: impl IntoIterator<Item = LoadedExtension>,
401) -> anyhow::Result<(String, Vec<MockGraphQlServer>)> {
402    let mut mock_subgraphs = Vec::new();
403    let mut composition_subgraphs = Subgraphs::default();
404
405    composition_subgraphs.ingest_loaded_extensions(extensions);
406
407    let extension_url = url::Url::from_file_path(extension_path.join("build")).unwrap();
408    let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
409    let rep = format!(r#"@link(url: "{extension_url}""#);
410
411    for subgraph in subgraphs {
412        match subgraph {
413            Subgraph::Graphql(subgraph) => {
414                let mock_graph = subgraph.start().await;
415                let sdl = re.replace_all(mock_graph.schema(), &rep);
416                composition_subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
417                mock_subgraphs.push(mock_graph);
418            }
419            Subgraph::Virtual(subgraph) => {
420                let sdl = re.replace_all(subgraph.schema(), &rep);
421                composition_subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
422            }
423        }
424    }
425
426    let federated_graph = match graphql_composition::compose(&composition_subgraphs)
427        .warnings_are_fatal()
428        .into_result()
429    {
430        Ok(graph) => graph,
431        Err(diagnostics) => {
432            return Err(anyhow!(
433                "Failed to compose subgraphs:\n{}\n",
434                diagnostics
435                    .iter_messages()
436                    .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
437            ));
438        }
439    };
440    let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
441
442    Ok((federated_sdl, mock_subgraphs))
443}
444
445impl Drop for TestGateway {
446    fn drop(&mut self) {
447        if let Err(err) = self.handle.kill() {
448            eprintln!("Failed to kill grafbase-gateway: {err}")
449        }
450    }
451}