1use std::{
2 fmt::Write,
3 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4 path::{Path, PathBuf},
5 time::Duration,
6};
7
8use crate::test::{
9 GraphqlRequest, LogLevel,
10 config::{CLI_BINARY_NAME, GATEWAY_BINARY_NAME, TestConfig},
11 request::Body,
12};
13
14use anyhow::Context;
15use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
16use graphql_composition::{LoadedExtension, Subgraphs};
17use itertools::Itertools;
18use regex::Regex;
19use tempfile::TempDir;
20use url::Url;
21
/// A `grafbase-gateway` process started for a single extension test.
///
/// Created through [`TestGateway::builder`]. Dropping it kills the gateway process.
pub struct TestGateway {
    /// Client used for the `/health` readiness checks and for requests built by `query`.
    http_client: reqwest::Client,
    /// Resolved test configuration (binary paths, extension options, extra TOML, log level).
    config: TestConfig,
    /// Handle to the gateway child process; `None` until the process has been started.
    gateway_handle: Option<duct::Handle>,
    /// Local address the gateway is told to listen on (`--listen-address`).
    gateway_listen_address: SocketAddr,
    /// `http://<gateway_listen_address>/graphql`.
    gateway_endpoint: Url,
    /// Holds the generated `grafbase.toml` and `federated-schema.graphql` for this test.
    test_specific_temp_dir: TempDir,
    /// Never read; presumably held so the mock subgraph servers live as long as the gateway.
    _mock_subgraphs: Vec<MockGraphQlServer>,
    /// Federated SDL composed from all subgraphs and passed to the gateway via `--schema`.
    federated_sdl: String,
}
33
/// The subset of the extension's `extension.toml` manifest the test harness reads.
#[derive(Debug, serde::Deserialize)]
struct ExtensionToml {
    /// The `[extension]` table.
    extension: ExtensionDefinition,
}
38
/// The `[extension]` table of `extension.toml`; only the name is needed here.
#[derive(Debug, serde::Deserialize)]
struct ExtensionDefinition {
    /// Extension name, used as the `[extensions.<name>]` key in the generated gateway config.
    name: String,
}
43
44impl TestGateway {
45 pub fn builder() -> TestGatewayBuilder {
47 TestGatewayBuilder::new()
48 }
49
50 pub fn url(&self) -> &Url {
52 &self.gateway_endpoint
53 }
54
55 pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
65 let builder = self.http_client.post(self.gateway_endpoint.clone());
66 GraphqlRequest {
67 builder,
68 body: query.into(),
69 }
70 }
71
72 pub fn federated_sdl(&self) -> &str {
74 &self.federated_sdl
75 }
76}
77
/// Collects the options used to start a [`TestGateway`].
#[derive(Debug, Default, Clone)]
pub struct TestGatewayBuilder {
    /// Gateway binary; looked up in `PATH` when `None`.
    gateway_path: Option<PathBuf>,
    /// Grafbase CLI binary (used to build the extension); looked up in `PATH` when `None`.
    cli_path: Option<PathBuf>,
    /// Pre-built extension directory; when `None` the extension is built into `./build`.
    extension_path: Option<PathBuf>,
    /// Extra gateway TOML appended after the generated extension section.
    toml_config: Option<String>,
    /// Subgraphs composed into the federated schema.
    subgraphs: Vec<Subgraph>,
    /// When set, written as `stdout = ...` in the extension's config section.
    enable_stdout: Option<bool>,
    /// When set, written as `stderr = ...` in the extension's config section.
    enable_stderr: Option<bool>,
    /// When set, written as `networking = ...` in the extension's config section.
    enable_networking: Option<bool>,
    /// When set, written as `environment_variables = ...` in the extension's config section.
    enable_environment_variables: Option<bool>,
    /// When `Some(true)`, child process output is not captured.
    stream_stdout_stderr: Option<bool>,
    /// Extension `max_pool_size`; the gateway config uses 100 when `None`.
    max_pool_size: Option<usize>,
    /// Gateway `--log` level; `LogLevel::default()` when `None`.
    log_level: Option<LogLevel>,
}
94
95impl TestGatewayBuilder {
96 pub(crate) fn new() -> Self {
98 Self::default()
99 }
100
101 pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
103 self.subgraphs.push(subgraph.into());
104 self
105 }
106
107 pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
109 self.gateway_path = Some(gateway_path.into());
110 self
111 }
112
113 pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
115 self.cli_path = Some(cli_path.into());
116 self
117 }
118
119 pub fn with_extension_path(mut self, extension_path: impl Into<PathBuf>) -> Self {
121 self.extension_path = Some(extension_path.into().canonicalize().unwrap());
122 self
123 }
124
125 pub fn enable_stdout(mut self) -> Self {
128 self.enable_stdout = Some(true);
129 self
130 }
131
132 pub fn enable_stderr(mut self) -> Self {
135 self.enable_stderr = Some(true);
136 self
137 }
138
139 pub fn enable_networking(mut self) -> Self {
141 self.enable_networking = Some(true);
142 self
143 }
144
145 pub fn enable_environment_variables(mut self) -> Self {
147 self.enable_environment_variables = Some(true);
148 self
149 }
150
151 pub fn max_pool_size(mut self, size: usize) -> Self {
153 self.max_pool_size = Some(size);
154 self
155 }
156
157 pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
159 self.log_level = Some(level.into());
160 self
161 }
162
163 pub fn toml_config(mut self, cfg: impl ToString) -> Self {
166 self.toml_config = Some(cfg.to_string());
167 self
168 }
169
170 pub fn stream_stdout_stderr(mut self) -> Self {
174 self.stream_stdout_stderr = Some(true);
175 self
176 }
177
178 pub async fn build(self) -> anyhow::Result<TestGateway> {
180 let Self {
181 gateway_path,
182 cli_path,
183 extension_path,
184 enable_stdout,
185 enable_stderr,
186 subgraphs: mock_subgraphs,
187 enable_networking,
188 enable_environment_variables,
189 max_pool_size,
190 log_level,
191 toml_config,
192 stream_stdout_stderr,
193 } = self;
194
195 let gateway_path = match gateway_path {
196 Some(path) => path,
197 None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
198 };
199
200 let cli_path = match cli_path {
201 Some(path) => path,
202 None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
203 };
204
205 let log_level = log_level.unwrap_or_default();
206
207 TestGateway::setup(TestConfig {
208 gateway_path,
209 cli_path,
210 toml_config: toml_config.unwrap_or_default(),
211 extension_path,
212 enable_stdout,
213 enable_stderr,
214 mock_subgraphs,
215 enable_networking,
216 enable_environment_variables,
217 max_pool_size,
218 log_level,
219 stream_stdout_stderr,
220 })
221 .await
222 }
223}
224
225#[allow(clippy::panic)]
226impl TestGateway {
227 async fn setup(mut config: TestConfig) -> anyhow::Result<Self> {
228 let test_specific_temp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
229 let gateway_listen_address = listen_address()?;
230 let gateway_endpoint = Url::parse(&format!("http://{gateway_listen_address}/graphql"))?;
231
232 let extension_toml_path = std::env::current_dir()?.join("extension.toml");
233 let extension_toml = std::fs::read_to_string(&extension_toml_path)?;
234 let extension_toml: ExtensionToml = toml::from_str(&extension_toml)?;
235 let extension_name = extension_toml.extension.name;
236
237 let mut mock_subgraphs = Vec::new();
238 let mut subgraphs = Subgraphs::default();
239
240 let extension_path = match config.extension_path {
241 Some(ref path) => path.to_path_buf(),
242 None => std::env::current_dir()?.join("build"),
243 };
244
245 let extension_url = url::Url::from_file_path(&extension_path).unwrap();
246 subgraphs.ingest_loaded_extensions(std::iter::once(LoadedExtension::new(
247 extension_url.to_string(),
248 extension_name.clone(),
249 )));
250
251 let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
252 let rep = format!(r#"@link(url: "{extension_url}""#);
253 for subgraph in config.mock_subgraphs.drain(..) {
254 match subgraph {
255 Subgraph::Graphql(subgraph) => {
256 let mock_graph = subgraph.start().await;
257 let sdl = re.replace_all(mock_graph.schema(), &rep);
258 subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
259 mock_subgraphs.push(mock_graph);
260 }
261 Subgraph::Virtual(subgraph) => {
262 let sdl = re.replace_all(subgraph.schema(), &rep);
263 subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
264 }
265 }
266 }
267
268 let federated_graph = match graphql_composition::compose(&subgraphs)
269 .warnings_are_fatal()
270 .into_result()
271 {
272 Ok(graph) => graph,
273 Err(diagnostics) => {
274 return Err(anyhow::anyhow!(
275 "Failed to compose subgraphs:\n{}\n",
276 diagnostics
277 .iter_messages()
278 .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
279 ));
280 }
281 };
282 let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
283
284 let mut this = Self {
285 http_client: reqwest::Client::new(),
286 config,
287 gateway_handle: None,
288 gateway_listen_address,
289 gateway_endpoint,
290 test_specific_temp_dir,
291 _mock_subgraphs: mock_subgraphs,
292 federated_sdl,
293 };
294
295 if this.config.extension_path.is_none() {
296 this.build_extension(&extension_path)?;
297 }
298
299 this.start_servers(&extension_name, &extension_path)
300 .await
301 .map_err(|err| anyhow::anyhow!("Failed to start servers: {err}"))?;
302
303 Ok(this)
304 }
305
306 async fn start_servers(&mut self, extension_name: &str, extension_path: &Path) -> anyhow::Result<()> {
307 let extension_path = extension_path.display();
308 let config_path = self.test_specific_temp_dir.path().join("grafbase.toml");
309 let schema_path = self.test_specific_temp_dir.path().join("federated-schema.graphql");
310
311 let config = {
312 let max_pool_size = self.config.max_pool_size.unwrap_or(100);
313 let mut config = indoc::formatdoc! {r#"
314 [extensions.{extension_name}]
315 path = "{extension_path}"
316 max_pool_size = {max_pool_size}
317 "#};
318 if let Some(enabled) = self.config.enable_stderr {
319 writeln!(config, "stderr = {enabled}").unwrap();
320 }
321 if let Some(enabled) = self.config.enable_stdout {
322 writeln!(config, "stdout = {enabled}").unwrap();
323 }
324 if let Some(enabled) = self.config.enable_networking {
325 writeln!(config, "networking = {enabled}").unwrap();
326 }
327 if let Some(enabled) = self.config.enable_environment_variables {
328 writeln!(config, "environment_variables = {enabled}").unwrap();
329 }
330 config.push_str("\n\n");
331 config.push_str(&self.config.toml_config);
332 config
333 };
334 println!("{config}");
335
336 std::fs::write(&config_path, config.as_bytes())
337 .map_err(|err| anyhow::anyhow!("Failed to write config at {:?}: {err}", config_path))?;
338 std::fs::write(&schema_path, self.federated_sdl.as_bytes())
339 .map_err(|err| anyhow::anyhow!("Failed to write schema at {:?}: {err}", schema_path))?;
340
341 let args = &[
342 "--listen-address",
343 &self.gateway_listen_address.to_string(),
344 "--config",
345 &config_path.to_string_lossy(),
346 "--schema",
347 &schema_path.to_string_lossy(),
348 "--log",
349 self.config.log_level.as_ref(),
350 ];
351
352 let gateway_handle = {
353 let cmd = duct::cmd(&self.config.gateway_path, args);
354 if self.config.stream_stdout_stderr.unwrap_or(false) {
355 cmd
356 } else {
357 cmd.stdout_capture().stderr_capture()
358 }
359 }
360 .unchecked()
361 .stderr_to_stdout()
362 .start()
363 .map_err(|err| anyhow::anyhow!("Failed to start the gateway: {err}"))?;
364
365 let mut i = 0;
366 while !self.check_gateway_health().await? {
367 if i % 10 == 0 {
369 match gateway_handle.try_wait() {
370 Ok(Some(output)) => panic!(
371 "Gateway process exited unexpectedly: {}\n{}\n{}",
372 output.status,
373 String::from_utf8_lossy(&output.stdout),
374 String::from_utf8_lossy(&output.stderr)
375 ),
376 Ok(None) => (),
377 Err(err) => panic!("Error waiting for gateway process: {err}"),
378 }
379 println!("Waiting for gateway to be ready...");
380 }
381 i += 1;
382 std::thread::sleep(Duration::from_millis(100));
383 }
384
385 self.gateway_handle = Some(gateway_handle);
386
387 Ok(())
388 }
389
390 async fn check_gateway_health(&self) -> anyhow::Result<bool> {
391 let url = self.gateway_endpoint.join("/health")?;
392
393 let Ok(result) = self.http_client.get(url).send().await else {
394 return Ok(false);
395 };
396
397 let result = result.error_for_status().is_ok();
398
399 Ok(result)
400 }
401
402 fn build_extension(&mut self, extension_path: &Path) -> anyhow::Result<()> {
403 let extension_path = extension_path.to_string_lossy();
404
405 let mut lock_file = fslock::LockFile::open(".build.lock")?;
408 lock_file.lock()?;
409
410 let args = &["extension", "build", "--debug", "--output-dir", &*extension_path];
411 let output = {
412 let cmd = duct::cmd(&self.config.cli_path, args);
413 if self.config.stream_stdout_stderr.unwrap_or(false) {
414 cmd
415 } else {
416 cmd.stdout_capture().stderr_capture()
417 }
418 }
419 .unchecked()
420 .stderr_to_stdout()
421 .run()?;
422 if !output.status.success() {
423 panic!(
424 "Failed to build extension: {}\n{}\n{}",
425 output.status,
426 String::from_utf8_lossy(&output.stdout),
427 String::from_utf8_lossy(&output.stderr)
428 );
429 }
430
431 lock_file.unlock()?;
432
433 Ok(())
434 }
435}
436
437pub(crate) fn free_port() -> anyhow::Result<u16> {
438 const INITIAL_PORT: u16 = 14712;
439
440 let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
441 std::fs::create_dir_all(&test_dir)?;
442
443 let lock_file_path = test_dir.join("port-number.lock");
444 let port_number_file_path = test_dir.join("port-number.txt");
445
446 let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
447 lock_file.lock()?;
448
449 let port = if port_number_file_path.exists() {
450 std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
451 } else {
452 INITIAL_PORT
453 };
454
455 std::fs::write(&port_number_file_path, port.to_string())?;
456 lock_file.unlock()?;
457
458 Ok(port)
459}
460
461pub(crate) fn listen_address() -> anyhow::Result<SocketAddr> {
462 let port = free_port()?;
463 Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
464}
465
466impl Drop for TestGateway {
467 fn drop(&mut self) {
468 let Some(handle) = self.gateway_handle.take() else {
469 return;
470 };
471
472 if let Err(err) = handle.kill() {
473 eprintln!("Failed to kill grafbase-gateway: {err}")
474 }
475 }
476}