1use std::{
2 collections::hash_map::Entry,
3 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
4 path::{Path, PathBuf},
5 str::FromStr,
6 time::Duration,
7};
8
9use crate::test::{
10 GraphqlRequest, LogLevel,
11 config::{
12 CLI_BINARY_NAME, ExtensionConfig, ExtensionToml, GATEWAY_BINARY_NAME, GatewayToml, StructuredExtensionConfig,
13 },
14 request::{Body, IntrospectionRequest},
15};
16
17use anyhow::{Context, anyhow};
18use grafbase_sdk_mock::{MockGraphQlServer, Subgraph};
19use graphql_composition::{LoadedExtension, Subgraphs};
20use itertools::Itertools;
21use regex::Regex;
22use tempfile::TempDir;
23use url::Url;
24
25pub struct TestGateway {
27 http_client: reqwest::Client,
28 handle: duct::Handle,
29 url: Url,
30 federated_sdl: String,
31 #[allow(unused)]
33 tmp_dir: TempDir,
34 #[allow(unused)]
35 mock_subgraphs: Vec<MockGraphQlServer>,
36}
37
38impl TestGateway {
39 pub fn builder() -> TestGatewayBuilder {
41 TestGatewayBuilder::new()
42 }
43
44 pub fn url(&self) -> &Url {
46 &self.url
47 }
48
49 pub fn query(&self, query: impl Into<Body>) -> GraphqlRequest {
59 let builder = self.http_client.post(self.url.clone());
60 GraphqlRequest {
61 builder,
62 body: query.into(),
63 }
64 }
65
66 pub fn federated_sdl(&self) -> &str {
68 &self.federated_sdl
69 }
70
71 pub fn introspect(&self) -> IntrospectionRequest {
78 let operation = cynic_introspection::IntrospectionQuery::with_capabilities(
79 cynic_introspection::SpecificationVersion::October2021.capabilities(),
80 );
81 IntrospectionRequest(self.query(Body {
82 query: Some(operation.query),
83 variables: None,
84 }))
85 }
86
87 pub async fn health(&self) -> anyhow::Result<()> {
89 let url = self.url.join("/health")?;
90 let _ = self.http_client.get(url).send().await?.error_for_status()?;
91 Ok(())
92 }
93}
94
95#[derive(Debug, Default, Clone)]
96pub struct TestGatewayBuilder {
98 gateway_path: Option<PathBuf>,
99 cli_path: Option<PathBuf>,
100 toml_config: Option<String>,
101 subgraphs: Vec<Subgraph>,
102 stream_stdout_stderr: Option<bool>,
103 log_level: Option<LogLevel>,
104}
105
106impl TestGatewayBuilder {
107 pub(crate) fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn subgraph(mut self, subgraph: impl Into<Subgraph>) -> Self {
114 self.subgraphs.push(subgraph.into());
115 self
116 }
117
118 pub fn with_gateway(mut self, gateway_path: impl Into<PathBuf>) -> Self {
120 self.gateway_path = Some(gateway_path.into());
121 self
122 }
123
124 pub fn with_cli(mut self, cli_path: impl Into<PathBuf>) -> Self {
126 self.cli_path = Some(cli_path.into());
127 self
128 }
129
130 pub fn toml_config(mut self, cfg: impl ToString) -> Self {
133 self.toml_config = Some(cfg.to_string());
134 self
135 }
136
137 pub fn log_level(mut self, level: impl Into<LogLevel>) -> Self {
139 self.log_level = Some(level.into());
140 self
141 }
142
143 pub fn stream_stdout_stderr(mut self) -> Self {
147 self.stream_stdout_stderr = Some(true);
148 self
149 }
150
151 pub async fn build(self) -> anyhow::Result<TestGateway> {
153 println!("Building the gateway:");
154
155 let gateway_path = match self.gateway_path {
156 Some(path) => path,
157 None => which::which(GATEWAY_BINARY_NAME).context("Could not fild grafbase-gateway binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
158 };
159
160 let cli_path = match self.cli_path {
161 Some(path) => path,
162 None => which::which(CLI_BINARY_NAME).context("Could not fild grafbase binary in the PATH. Either install it or specify the gateway path in the test configuration.")?,
163 };
164
165 let log_level = self.log_level.unwrap_or_default();
166
167 let extension_path = std::env::current_dir()?;
168 let extension_name =
169 toml::from_str::<ExtensionToml>(&std::fs::read_to_string(extension_path.join("extension.toml"))?)?
170 .extension
171 .name;
172
173 {
175 println!("* Building current extension.");
176 let lock_path = extension_path.join(".build.lock");
177 let mut lock_file = fslock::LockFile::open(&lock_path)?;
178 lock_file.lock()?;
179
180 let output = {
181 let cmd = duct::cmd(&cli_path, &["extension", "build", "--debug"]).dir(&extension_path);
182 if self.stream_stdout_stderr.unwrap_or(false) {
183 cmd
184 } else {
185 cmd.stdout_capture().stderr_capture()
186 }
187 }
188 .unchecked()
189 .stderr_to_stdout()
190 .run()?;
191
192 if !output.status.success() {
193 return Err(anyhow!(
194 "Failed to build extension: {}\n{}\n{}",
195 output.status,
196 String::from_utf8_lossy(&output.stdout),
197 String::from_utf8_lossy(&output.stderr)
198 ));
199 }
200
201 lock_file.unlock()?;
202 anyhow::Ok(())
203 }?;
204
205 println!("* Preparing the grafbase.toml & schema.graphql files.");
206 let mut toml_config: GatewayToml = toml::from_str(&self.toml_config.unwrap_or_default())?;
208 match toml_config.extensions.entry(extension_name.clone()) {
209 Entry::Occupied(mut entry) => match entry.get_mut() {
210 ExtensionConfig::Version(_) => {
211 return Err(anyhow!(
212 "Current extension {extension_name} cannot be specified with a version"
213 ));
214 }
215 ExtensionConfig::Structured(config) => {
216 config
217 .path
218 .get_or_insert_with(|| extension_path.join("build").to_string_lossy().into_owned());
219 }
220 },
221 Entry::Vacant(entry) => {
222 entry.insert(ExtensionConfig::Structured(StructuredExtensionConfig {
223 path: Some(extension_path.join("build").to_string_lossy().into_owned()),
224 version: None,
225 rest: Default::default(),
226 }));
227 }
228 }
229
230 let (federated_sdl, mock_subgraphs) = {
232 let extensions = toml_config
233 .extensions
234 .iter()
235 .map(|(name, config)| match config {
236 ExtensionConfig::Version(version) => Ok(new_loaded_extension(
237 format!("https://extensions.grafbase.com/{extension_name}/{version}")
238 .parse()
239 .unwrap(),
240 name.clone(),
241 )),
242 ExtensionConfig::Structured(StructuredExtensionConfig {
243 path: None, version, ..
244 }) => Ok(new_loaded_extension(
245 format!(
246 "https://extensions.grafbase.com/{extension_name}/{}",
247 version
248 .as_ref()
249 .ok_or_else(|| anyhow!("Missing path or version for extension '{name}'"))?
250 )
251 .parse()
252 .unwrap(),
253 name.clone(),
254 )),
255 ExtensionConfig::Structured(StructuredExtensionConfig { path: Some(path), .. }) => {
256 let mut path = PathBuf::from_str(path.as_str())
257 .context(format!("Invalid path for extension {name}: {path}"))?;
258 if path.is_relative() {
259 path = extension_path.join(path);
260 }
261 anyhow::Ok(new_loaded_extension(
262 Url::from_file_path(&path).unwrap(),
263 name.to_owned(),
264 ))
265 }
266 })
267 .collect::<anyhow::Result<Vec<_>>>()?;
268
269 compose(self.subgraphs, &extension_path, extensions).await
270 }?;
271
272 if toml_config.wasm.cache_path.is_none() {
273 toml_config.wasm.cache_path = Some(extension_path.join("build").join("wasm-cache"));
274 }
275
276 let tmp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
278 let config_path = tmp_dir.path().join("grafbase.toml");
279 let schema_path = tmp_dir.path().join("schema.graphql");
280
281 std::fs::write(&config_path, toml::to_string(&toml_config)?).context("Failed to write grafbase.toml")?;
282 std::fs::write(&schema_path, &federated_sdl).context("Failed to write schema.graphql")?;
283
284 if toml_config.extensions.len() > 1 {
286 println!("* Installing other extensions.");
287 let output = {
288 let cmd = duct::cmd(&cli_path, &["extension", "install"]).dir(tmp_dir.path());
289 if self.stream_stdout_stderr.unwrap_or(false) {
290 cmd
291 } else {
292 cmd.stdout_capture().stderr_capture()
293 }
294 }
295 .unchecked()
296 .stderr_to_stdout()
297 .run()?;
298
299 if !output.status.success() {
300 return Err(anyhow!(
301 "Failed to install extensions: {}\n{}\n{}",
302 output.status,
303 String::from_utf8_lossy(&output.stdout),
304 String::from_utf8_lossy(&output.stderr)
305 ));
306 }
307 }
308
309 println!("* Starting the gateway.");
310 let listen_address = new_listen_address()?;
311 let url = Url::parse(&format!("http://{listen_address}/graphql")).unwrap();
312
313 let handle = {
314 let cmd = duct::cmd(
315 &gateway_path,
316 &[
317 "--listen-address",
318 &listen_address.to_string(),
319 "--config",
320 &config_path.to_string_lossy(),
321 "--schema",
322 &schema_path.to_string_lossy(),
323 "--log",
324 log_level.as_ref(),
325 ],
326 )
327 .dir(tmp_dir.path());
328 if self.stream_stdout_stderr.unwrap_or(false) {
329 cmd
330 } else {
331 cmd.stdout_capture().stderr_capture()
332 }
333 }
334 .unchecked()
335 .stderr_to_stdout()
336 .start()
337 .map_err(|err| anyhow!("Failed to start the gateway: {err}"))?;
338
339 let gateway = TestGateway {
340 http_client: reqwest::Client::new(),
341 handle,
342 url,
343 tmp_dir,
344 mock_subgraphs,
345 federated_sdl,
346 };
347
348 let mut i = 0;
349 while gateway.health().await.is_err() {
350 if i % 10 == 0 {
352 match gateway.handle.try_wait() {
353 Ok(Some(output)) => {
354 return Err(anyhow!(
355 "Gateway process exited unexpectedly: {}\n{}\n{}",
356 output.status,
357 String::from_utf8_lossy(&output.stdout),
358 String::from_utf8_lossy(&output.stderr)
359 ));
360 }
361 Ok(None) => (),
362 Err(err) => return Err(anyhow!("Error waiting for gateway process: {err}")),
363 }
364 println!("Waiting for gateway to be ready...");
365 }
366 i += 1;
367 tokio::time::sleep(Duration::from_millis(100)).await;
368 }
369
370 Ok(gateway)
371 }
372}
373
374pub(crate) fn new_listen_address() -> anyhow::Result<SocketAddr> {
375 let port = free_port()?;
376 Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
377}
378
379pub(crate) fn free_port() -> anyhow::Result<u16> {
380 const INITIAL_PORT: u16 = 14712;
381
382 let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
383 std::fs::create_dir_all(&test_dir)?;
384
385 let lock_file_path = test_dir.join("port-number.lock");
386 let port_number_file_path = test_dir.join("port-number.txt");
387
388 let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
389 lock_file.lock()?;
390
391 let port = if port_number_file_path.exists() {
392 std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
393 } else {
394 INITIAL_PORT
395 };
396
397 std::fs::write(&port_number_file_path, port.to_string())?;
398 lock_file.unlock()?;
399
400 Ok(port)
401}
402
403async fn compose(
404 subgraphs: impl IntoIterator<Item = Subgraph>,
405 extension_path: &Path,
406 extensions: impl IntoIterator<Item = LoadedExtension>,
407) -> anyhow::Result<(String, Vec<MockGraphQlServer>)> {
408 let mut mock_subgraphs = Vec::new();
409 let mut composition_subgraphs = Subgraphs::default();
410
411 composition_subgraphs.ingest_loaded_extensions(extensions);
412
413 let extension_url = url::Url::from_file_path(extension_path.join("build")).unwrap();
414 let re = Regex::new(r#"@link\(\s*url\s*:\s*"(<self>)""#).unwrap();
415 let rep = format!(r#"@link(url: "{extension_url}""#);
416
417 for subgraph in subgraphs {
418 match subgraph {
419 Subgraph::Graphql(subgraph) => {
420 let mock_graph = subgraph.start().await;
421 let sdl = re.replace_all(mock_graph.schema(), &rep);
422 composition_subgraphs.ingest_str(sdl.as_ref(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
423 mock_subgraphs.push(mock_graph);
424 }
425 Subgraph::Virtual(subgraph) => {
426 let sdl = re.replace_all(subgraph.schema(), &rep);
427 composition_subgraphs.ingest_str(sdl.as_ref(), subgraph.name(), None)?;
428 }
429 }
430 }
431
432 let federated_graph = match graphql_composition::compose(&mut composition_subgraphs)
433 .warnings_are_fatal()
434 .into_result()
435 {
436 Ok(graph) => graph,
437 Err(diagnostics) => {
438 return Err(anyhow!(
439 "Failed to compose subgraphs:\n{}\n",
440 diagnostics
441 .iter_messages()
442 .format_with("\n", |msg, f| f(&format_args!("- {msg}")))
443 ));
444 }
445 };
446 let federated_sdl = graphql_composition::render_federated_sdl(&federated_graph)?;
447
448 Ok((federated_sdl, mock_subgraphs))
449}
450
451impl Drop for TestGateway {
452 fn drop(&mut self) {
453 if let Err(err) = self.handle.kill() {
454 eprintln!("Failed to kill grafbase-gateway: {err}")
455 }
456 }
457}
458
459fn new_loaded_extension(url: Url, name: String) -> LoadedExtension {
460 LoadedExtension {
461 link_url: url.to_string(),
462 url,
463 name,
464 }
465}