1use std::{
2 future::IntoFuture,
3 marker::PhantomData,
4 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
5 path::Path,
6 time::Duration,
7};
8
9use super::TestConfig;
10use async_tungstenite::tungstenite::handshake::client::Request;
11use futures_util::{StreamExt, stream::BoxStream};
12use grafbase_sdk_mock::{MockGraphQlServer, MockSubgraph};
13use graphql_composition::{LoadedExtension, Subgraphs};
14use graphql_ws_client::graphql::GraphqlOperation;
15use http::{
16 HeaderValue,
17 header::{IntoHeaderName, SEC_WEBSOCKET_PROTOCOL},
18};
19use serde::de::DeserializeOwned;
20use tempfile::TempDir;
21use tungstenite::client::IntoClientRequest;
22use url::Url;
23
24pub struct TestRunner {
26 http_client: reqwest::Client,
27 config: TestConfig,
28 gateway_handle: Option<duct::Handle>,
29 gateway_listen_address: SocketAddr,
30 gateway_endpoint: Url,
31 test_specific_temp_dir: TempDir,
32 _mock_subgraphs: Vec<MockGraphQlServer>,
33 federated_graph: String,
34}
35
36#[derive(Debug, serde::Deserialize)]
37struct ExtensionToml {
38 extension: ExtensionDefinition,
39}
40
41#[derive(Debug, serde::Deserialize)]
42struct ExtensionDefinition {
43 name: String,
44}
45
46#[allow(clippy::panic)]
47impl TestRunner {
48 pub async fn new(mut config: TestConfig) -> anyhow::Result<Self> {
50 let test_specific_temp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
51 let gateway_listen_address = listen_address()?;
52 let gateway_endpoint = Url::parse(&format!("http://{}/graphql", gateway_listen_address))?;
53
54 let extension_toml_path = std::env::current_dir()?.join("extension.toml");
55 let extension_toml = std::fs::read_to_string(&extension_toml_path)?;
56 let extension_toml: ExtensionToml = toml::from_str(&extension_toml)?;
57 let extension_name = extension_toml.extension.name;
58
59 let mut mock_subgraphs = Vec::new();
60 let mut subgraphs = Subgraphs::default();
61
62 let extension_path = match config.extension_path {
63 Some(ref path) => path.to_path_buf(),
64 None => std::env::current_dir()?.join("build"),
65 };
66
67 subgraphs.ingest_loaded_extensions(std::iter::once(LoadedExtension::new(
68 format!("file://{}", extension_path.display()),
69 extension_name.clone(),
70 )));
71
72 for subgraph in config.mock_subgraphs.drain(..) {
73 match subgraph {
74 MockSubgraph::Dynamic(subgraph) => {
75 let mock_graph = subgraph.start().await;
76 subgraphs.ingest_str(mock_graph.sdl(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
77 mock_subgraphs.push(mock_graph);
78 }
79 MockSubgraph::ExtensionOnly(subgraph) => {
80 subgraphs.ingest_str(subgraph.sdl(), subgraph.name(), None)?;
81 }
82 }
83 }
84
85 let federated_graph = graphql_composition::compose(&subgraphs)
86 .warnings_are_fatal()
87 .into_result()
88 .unwrap();
89 let federated_graph = graphql_composition::render_federated_sdl(&federated_graph)?;
90
91 let mut this = Self {
92 http_client: reqwest::Client::new(),
93 config,
94 gateway_handle: None,
95 gateway_listen_address,
96 gateway_endpoint,
97 test_specific_temp_dir,
98 _mock_subgraphs: mock_subgraphs,
99 federated_graph,
100 };
101
102 if this.config.extension_path.is_none() {
103 this.build_extension(&extension_path)?;
104 }
105
106 this.start_servers(&extension_name, &extension_path)
107 .await
108 .map_err(|err| anyhow::anyhow!("Failed to start servers: {err}"))?;
109
110 Ok(this)
111 }
112
113 async fn start_servers(&mut self, extension_name: &str, extension_path: &Path) -> anyhow::Result<()> {
114 let extension_path = extension_path.display();
115 let config_path = self.test_specific_temp_dir.path().join("grafbase.toml");
116 let schema_path = self.test_specific_temp_dir.path().join("federated-schema.graphql");
117 let config = &self.config.gateway_configuration;
118 let enable_stdout = self.config.enable_stdout;
119 let enable_stderr = self.config.enable_stdout;
120 let enable_networking = self.config.enable_networking;
121 let enable_environment_variables = self.config.enable_environment_variables;
122 let max_pool_size = self.config.max_pool_size.unwrap_or(100);
123
124 let config = indoc::formatdoc! {r#"
125 [extensions.{extension_name}]
126 path = "{extension_path}"
127 stdout = {enable_stdout}
128 stderr = {enable_stderr}
129 networking = {enable_networking}
130 environment_variables = {enable_environment_variables}
131 max_pool_size = {max_pool_size}
132
133 {config}
134 "#};
135
136 println!("{config}");
137
138 std::fs::write(&config_path, config.as_bytes())
139 .map_err(|err| anyhow::anyhow!("Failed to write config at {:?}: {err}", config_path))?;
140 std::fs::write(&schema_path, self.federated_graph.as_bytes())
141 .map_err(|err| anyhow::anyhow!("Failed to write schema at {:?}: {err}", schema_path))?;
142
143 let args = &[
144 "--listen-address",
145 &self.gateway_listen_address.to_string(),
146 "--config",
147 &config_path.to_string_lossy(),
148 "--schema",
149 &schema_path.to_string_lossy(),
150 "--log",
151 self.config.log_level.as_ref(),
152 ];
153
154 let mut expr = duct::cmd(&self.config.gateway_path, args);
155
156 if !self.config.enable_stderr {
157 expr = expr.stderr_capture();
158 }
159
160 if !self.config.enable_stdout {
161 expr = expr.stdout_capture();
162 }
163
164 let gateway_handle = expr
165 .unchecked()
166 .start()
167 .map_err(|err| anyhow::anyhow!("Failed to start the gateway: {err}"))?;
168
169 let mut i = 0;
170 while !self.check_gateway_health().await? {
171 if i % 10 == 0 {
173 match gateway_handle.try_wait() {
174 Ok(Some(output)) => panic!(
175 "Gateway process exited unexpectedly: {:?}\n{}\n{}",
176 output.status,
177 String::from_utf8_lossy(&output.stdout),
178 String::from_utf8_lossy(&output.stderr)
179 ),
180 Ok(None) => (),
181 Err(err) => panic!("Error waiting for gateway process: {}", err),
182 }
183 println!("Waiting for gateway to be ready...");
184 }
185 i += 1;
186 std::thread::sleep(Duration::from_millis(100));
187 }
188
189 self.gateway_handle = Some(gateway_handle);
190
191 Ok(())
192 }
193
194 async fn check_gateway_health(&self) -> anyhow::Result<bool> {
195 let url = self.gateway_endpoint.join("/health")?;
196
197 let Ok(result) = self.http_client.get(url).send().await else {
198 return Ok(false);
199 };
200
201 let result = result.error_for_status().is_ok();
202
203 Ok(result)
204 }
205
206 fn build_extension(&mut self, extension_path: &Path) -> anyhow::Result<()> {
207 let extension_path = extension_path.to_string_lossy();
208
209 let mut lock_file = fslock::LockFile::open(".build.lock")?;
212 lock_file.lock()?;
213
214 let args = &["extension", "build", "--debug", "--output-dir", &*extension_path];
215 let mut expr = duct::cmd(&self.config.cli_path, args);
216
217 if !self.config.enable_stdout {
218 expr = expr.stdout_capture();
219 }
220
221 if !self.config.enable_stderr {
222 expr = expr.stderr_capture();
223 }
224
225 let output = expr.unchecked().run()?;
226 if !output.status.success() {
227 panic!(
228 "Failed to build extension: {}\n{}\n{}",
229 output.status,
230 String::from_utf8_lossy(&output.stdout),
231 String::from_utf8_lossy(&output.stderr)
232 );
233 }
234
235 lock_file.unlock()?;
236
237 Ok(())
238 }
239
240 pub fn graphql_query<Response>(&self, query: impl Into<String>) -> QueryBuilder<Response> {
250 let reqwest_builder = self
251 .http_client
252 .post(self.gateway_endpoint.clone())
253 .header(http::header::ACCEPT, "application/json");
254
255 QueryBuilder {
256 query: query.into(),
257 variables: None,
258 phantom: PhantomData,
259 reqwest_builder,
260 }
261 }
262
263 pub fn graphql_subscription<Response>(
272 &self,
273 query: impl Into<String>,
274 ) -> anyhow::Result<SubscriptionBuilder<Response>> {
275 let mut url = self.gateway_endpoint.clone();
276
277 url.set_path("/ws");
278 url.set_scheme("ws").unwrap();
279
280 let mut request_builder = url.as_ref().into_client_request()?;
281
282 request_builder
283 .headers_mut()
284 .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("graphql-transport-ws"));
285
286 let operation = Operation {
287 query: query.into(),
288 variables: None,
289 phantom: PhantomData,
290 };
291
292 Ok(SubscriptionBuilder {
293 operation,
294 request_builder,
295 })
296 }
297
298 pub fn federated_graph(&self) -> &str {
300 &self.federated_graph
301 }
302}
303
304pub(crate) fn free_port() -> anyhow::Result<u16> {
305 const INITIAL_PORT: u16 = 14712;
306
307 let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
308 std::fs::create_dir_all(&test_dir)?;
309
310 let lock_file_path = test_dir.join("port-number.lock");
311 let port_number_file_path = test_dir.join("port-number.txt");
312
313 let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
314 lock_file.lock()?;
315
316 let port = if port_number_file_path.exists() {
317 std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
318 } else {
319 INITIAL_PORT
320 };
321
322 std::fs::write(&port_number_file_path, port.to_string())?;
323 lock_file.unlock()?;
324
325 Ok(port)
326}
327
328pub(crate) fn listen_address() -> anyhow::Result<SocketAddr> {
329 let port = free_port()?;
330 Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
331}
332
333impl Drop for TestRunner {
334 fn drop(&mut self) {
335 let Some(handle) = self.gateway_handle.take() else {
336 return;
337 };
338
339 if let Err(err) = handle.kill() {
340 eprintln!("Failed to kill grafbase-gateway: {}", err)
341 }
342 }
343}
344
345#[derive(serde::Serialize)]
346#[must_use]
347pub struct QueryBuilder<Response> {
349 query: String,
351 #[serde(skip_serializing_if = "Option::is_none")]
352 variables: Option<serde_json::Value>,
353
354 #[serde(skip)]
356 phantom: PhantomData<fn() -> Response>,
357 #[serde(skip)]
358 reqwest_builder: reqwest::RequestBuilder,
359}
360
361impl<Response> QueryBuilder<Response> {
362 pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
368 self.variables = Some(serde_json::to_value(variables).unwrap());
369 self
370 }
371
372 pub fn with_header(self, name: &str, value: &str) -> Self {
374 let Self {
375 phantom,
376 query,
377 mut reqwest_builder,
378 variables,
379 } = self;
380
381 reqwest_builder = reqwest_builder.header(name, value);
382
383 Self {
384 query,
385 variables,
386 phantom,
387 reqwest_builder,
388 }
389 }
390
391 pub async fn send(self) -> anyhow::Result<Response>
404 where
405 Response: for<'de> serde::Deserialize<'de>,
406 {
407 let json = serde_json::to_value(&self)?;
408 Ok(self.reqwest_builder.json(&json).send().await?.json().await?)
409 }
410}
411
412#[must_use]
413pub struct SubscriptionBuilder<Response> {
415 operation: Operation<Response>,
416 request_builder: Request,
417}
418
419#[derive(serde::Serialize)]
420struct Operation<Response> {
421 query: String,
422 #[serde(skip_serializing_if = "Option::is_none")]
423 variables: Option<serde_json::Value>,
424 #[serde(skip)]
425 phantom: PhantomData<fn() -> Response>,
426}
427
428impl<Response> GraphqlOperation for Operation<Response>
429where
430 Response: DeserializeOwned,
431{
432 type Response = Response;
433 type Error = serde_json::Error;
434
435 fn decode(&self, data: serde_json::Value) -> Result<Self::Response, Self::Error> {
436 serde_json::from_value(data)
437 }
438}
439
440impl<Response> SubscriptionBuilder<Response>
441where
442 Response: DeserializeOwned + 'static,
443{
444 pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
450 self.operation.variables = Some(serde_json::to_value(variables).unwrap());
451 self
452 }
453
454 pub fn with_header<K>(mut self, name: K, value: HeaderValue) -> Self
461 where
462 K: IntoHeaderName,
463 {
464 self.request_builder.headers_mut().insert(name, value);
465 self
466 }
467
468 pub async fn subscribe(self) -> anyhow::Result<BoxStream<'static, Response>> {
480 let (connection, _) = async_tungstenite::tokio::connect_async(self.request_builder).await?;
481 let (client, actor) = graphql_ws_client::Client::build(connection).await?;
482
483 tokio::spawn(actor.into_future());
484
485 let stream = client
486 .subscribe(self.operation)
487 .await?
488 .map(move |item| -> Response { item.unwrap() });
489
490 Ok(Box::pin(stream))
491 }
492}