1use std::{
2 future::IntoFuture,
3 marker::PhantomData,
4 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
5 path::Path,
6 time::Duration,
7};
8
9use super::TestConfig;
10use async_tungstenite::tungstenite::handshake::client::Request;
11use futures_util::{stream::BoxStream, StreamExt};
12use grafbase_sdk_mock::{MockGraphQlServer, MockSubgraph};
13use graphql_composition::{LoadedExtension, Subgraphs};
14use graphql_ws_client::graphql::GraphqlOperation;
15use http::{
16 header::{IntoHeaderName, SEC_WEBSOCKET_PROTOCOL},
17 HeaderValue,
18};
19use serde::de::DeserializeOwned;
20use tempfile::TempDir;
21use tungstenite::client::IntoClientRequest;
22use url::Url;
23
24pub struct TestRunner {
26 http_client: reqwest::Client,
27 config: TestConfig,
28 gateway_handle: Option<duct::Handle>,
29 gateway_listen_address: SocketAddr,
30 gateway_endpoint: Url,
31 test_specific_temp_dir: TempDir,
32 _mock_subgraphs: Vec<MockGraphQlServer>,
33 federated_graph: String,
34}
35
36#[derive(Debug, serde::Deserialize)]
37struct ExtensionToml {
38 extension: ExtensionDefinition,
39}
40
41#[derive(Debug, serde::Deserialize)]
42struct ExtensionDefinition {
43 name: String,
44}
45
46impl TestRunner {
47 pub async fn new(mut config: TestConfig) -> anyhow::Result<Self> {
49 let test_specific_temp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
50 let gateway_listen_address = listen_address()?;
51 let gateway_endpoint = Url::parse(&format!("http://{}/graphql", gateway_listen_address))?;
52
53 let extension_toml_path = std::env::current_dir()?.join("extension.toml");
54 let extension_toml = std::fs::read_to_string(&extension_toml_path)?;
55 let extension_toml: ExtensionToml = toml::from_str(&extension_toml)?;
56 let extension_name = extension_toml.extension.name;
57
58 let mut mock_subgraphs = Vec::new();
59 let mut subgraphs = Subgraphs::default();
60
61 let extension_path = match config.extension_path {
62 Some(ref path) => path.to_path_buf(),
63 None => std::env::current_dir()?.join("build"),
64 };
65
66 subgraphs.ingest_loaded_extensions(std::iter::once(LoadedExtension::new(
67 format!("file://{}", extension_path.display()),
68 extension_name.clone(),
69 )));
70
71 for subgraph in config.mock_subgraphs.drain(..) {
72 match subgraph {
73 MockSubgraph::Dynamic(subgraph) => {
74 let mock_graph = subgraph.start().await;
75 subgraphs.ingest_str(mock_graph.sdl(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
76 mock_subgraphs.push(mock_graph);
77 }
78 MockSubgraph::ExtensionOnly(subgraph) => {
79 subgraphs.ingest_str(subgraph.sdl(), subgraph.name(), None)?;
80 }
81 }
82 }
83
84 let federated_graph = graphql_composition::compose(&subgraphs).into_result().unwrap();
85 let federated_graph = graphql_federated_graph::render_federated_sdl(&federated_graph)?;
86
87 let mut this = Self {
88 http_client: reqwest::Client::new(),
89 config,
90 gateway_handle: None,
91 gateway_listen_address,
92 gateway_endpoint,
93 test_specific_temp_dir,
94 _mock_subgraphs: mock_subgraphs,
95 federated_graph,
96 };
97
98 this.build_extension(&extension_path)?;
99 this.start_servers(&extension_name, &extension_path).await?;
100
101 Ok(this)
102 }
103
104 async fn start_servers(&mut self, extension_name: &str, extension_path: &Path) -> anyhow::Result<()> {
105 let extension_path = extension_path.display();
106 let config_path = self.test_specific_temp_dir.path().join("grafbase.toml");
107 let schema_path = self.test_specific_temp_dir.path().join("federated-schema.graphql");
108 let config = &self.config.gateway_configuration;
109 let enable_stdout = self.config.enable_stdout;
110 let enable_stderr = self.config.enable_stdout;
111 let enable_networking = self.config.enable_networking;
112 let enable_environment_variables = self.config.enable_environment_variables;
113 let max_pool_size = self.config.max_pool_size.unwrap_or(100);
114
115 let config = indoc::formatdoc! {r#"
116 [extensions.{extension_name}]
117 path = "{extension_path}"
118 stdout = {enable_stdout}
119 stderr = {enable_stderr}
120 networking = {enable_networking}
121 environment_variables = {enable_environment_variables}
122 max_pool_size = {max_pool_size}
123
124 {config}
125 "#};
126
127 println!("{config}");
128
129 std::fs::write(&config_path, config.as_bytes())?;
130 std::fs::write(&schema_path, self.federated_graph.as_bytes())?;
131
132 let args = &[
133 "--listen-address",
134 &self.gateway_listen_address.to_string(),
135 "--config",
136 &config_path.to_string_lossy(),
137 "--schema",
138 &schema_path.to_string_lossy(),
139 "--log",
140 self.config.log_level.as_ref(),
141 ];
142
143 let mut expr = duct::cmd(&self.config.gateway_path, args);
144
145 if !dbg!(self.config.enable_stderr) {
146 expr = expr.stderr_null();
147 }
148
149 if !self.config.enable_stdout {
150 expr = expr.stdout_null();
151 }
152
153 self.gateway_handle = Some(expr.start()?);
154
155 let mut i = 0;
156 while !self.check_gateway_health().await? {
157 if i % 10 == 0 {
159 println!("Waiting for gateway to be ready...");
160 }
161 i += 1;
162 std::thread::sleep(Duration::from_millis(100));
163 }
164
165 Ok(())
166 }
167
168 async fn check_gateway_health(&self) -> anyhow::Result<bool> {
169 let url = self.gateway_endpoint.join("/health")?;
170
171 let Ok(result) = self.http_client.get(url).send().await else {
172 return Ok(false);
173 };
174
175 let result = result.error_for_status().is_ok();
176
177 Ok(result)
178 }
179
180 fn build_extension(&mut self, extension_path: &Path) -> anyhow::Result<()> {
181 let extension_path = extension_path.to_string_lossy();
182
183 let mut lock_file = fslock::LockFile::open(".build.lock")?;
186 lock_file.lock()?;
187
188 let args = &["extension", "build", "--debug", "--output-dir", &*extension_path];
189 let mut expr = duct::cmd(&self.config.cli_path, args);
190
191 if !self.config.enable_stdout {
192 expr = expr.stdout_null();
193 }
194
195 if !self.config.enable_stderr {
196 expr = expr.stderr_null();
197 }
198
199 expr.run()?;
200 lock_file.unlock()?;
201
202 Ok(())
203 }
204
205 pub fn graphql_query<Response>(&self, query: impl Into<String>) -> QueryBuilder<Response> {
215 let reqwest_builder = self
216 .http_client
217 .post(self.gateway_endpoint.clone())
218 .header(http::header::ACCEPT, "application/json");
219
220 QueryBuilder {
221 query: query.into(),
222 variables: None,
223 phantom: PhantomData,
224 reqwest_builder,
225 }
226 }
227
228 pub fn graphql_subscription<Response>(
237 &self,
238 query: impl Into<String>,
239 ) -> anyhow::Result<SubscriptionBuilder<Response>> {
240 let mut url = self.gateway_endpoint.clone();
241
242 url.set_path("/ws");
243 url.set_scheme("ws").unwrap();
244
245 let mut request_builder = url.as_ref().into_client_request()?;
246
247 request_builder
248 .headers_mut()
249 .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("graphql-transport-ws"));
250
251 let operation = Operation {
252 query: query.into(),
253 variables: None,
254 phantom: PhantomData,
255 };
256
257 Ok(SubscriptionBuilder {
258 operation,
259 request_builder,
260 })
261 }
262
263 pub fn federated_graph(&self) -> &str {
265 &self.federated_graph
266 }
267}
268
269pub(crate) fn free_port() -> anyhow::Result<u16> {
270 const INITIAL_PORT: u16 = 14712;
271
272 let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
273 std::fs::create_dir_all(&test_dir)?;
274
275 let lock_file_path = test_dir.join("port-number.lock");
276 let port_number_file_path = test_dir.join("port-number.txt");
277
278 let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
279 lock_file.lock()?;
280
281 let port = if port_number_file_path.exists() {
282 std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
283 } else {
284 INITIAL_PORT
285 };
286
287 std::fs::write(&port_number_file_path, port.to_string())?;
288 lock_file.unlock()?;
289
290 Ok(port)
291}
292
293pub(crate) fn listen_address() -> anyhow::Result<SocketAddr> {
294 let port = free_port()?;
295 Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
296}
297
298impl Drop for TestRunner {
299 fn drop(&mut self) {
300 let Some(handle) = self.gateway_handle.take() else {
301 return;
302 };
303
304 if let Err(err) = handle.kill() {
305 eprintln!("Failed to kill grafbase-gateway: {}", err)
306 }
307 }
308}
309
310#[derive(serde::Serialize)]
311#[must_use]
312pub struct QueryBuilder<Response> {
314 query: String,
316 #[serde(skip_serializing_if = "Option::is_none")]
317 variables: Option<serde_json::Value>,
318
319 #[serde(skip)]
321 phantom: PhantomData<fn() -> Response>,
322 #[serde(skip)]
323 reqwest_builder: reqwest::RequestBuilder,
324}
325
326impl<Response> QueryBuilder<Response> {
327 pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
333 self.variables = Some(serde_json::to_value(variables).unwrap());
334 self
335 }
336
337 pub fn with_header(self, name: &str, value: &str) -> Self {
339 let Self {
340 phantom,
341 query,
342 mut reqwest_builder,
343 variables,
344 } = self;
345
346 reqwest_builder = reqwest_builder.header(name, value);
347
348 Self {
349 query,
350 variables,
351 phantom,
352 reqwest_builder,
353 }
354 }
355
356 pub async fn send(self) -> anyhow::Result<Response>
369 where
370 Response: for<'de> serde::Deserialize<'de>,
371 {
372 let json = serde_json::to_value(&self)?;
373 Ok(self.reqwest_builder.json(&json).send().await?.json().await?)
374 }
375}
376
377#[must_use]
378pub struct SubscriptionBuilder<Response> {
380 operation: Operation<Response>,
381 request_builder: Request,
382}
383
384#[derive(serde::Serialize)]
385struct Operation<Response> {
386 query: String,
387 #[serde(skip_serializing_if = "Option::is_none")]
388 variables: Option<serde_json::Value>,
389 #[serde(skip)]
390 phantom: PhantomData<fn() -> Response>,
391}
392
393impl<Response> GraphqlOperation for Operation<Response>
394where
395 Response: DeserializeOwned,
396{
397 type Response = Response;
398 type Error = serde_json::Error;
399
400 fn decode(&self, data: serde_json::Value) -> Result<Self::Response, Self::Error> {
401 serde_json::from_value(data)
402 }
403}
404
405impl<Response> SubscriptionBuilder<Response>
406where
407 Response: DeserializeOwned + 'static,
408{
409 pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
415 self.operation.variables = Some(serde_json::to_value(variables).unwrap());
416 self
417 }
418
419 pub fn with_header<K>(mut self, name: K, value: HeaderValue) -> Self
426 where
427 K: IntoHeaderName,
428 {
429 self.request_builder.headers_mut().insert(name, value);
430 self
431 }
432
433 pub async fn subscribe(self) -> anyhow::Result<BoxStream<'static, Response>> {
445 let (connection, _) = async_tungstenite::tokio::connect_async(self.request_builder).await?;
446 let (client, actor) = graphql_ws_client::Client::build(connection).await?;
447
448 tokio::spawn(actor.into_future());
449
450 let stream = client
451 .subscribe(self.operation)
452 .await?
453 .map(move |item| -> Response { item.unwrap() });
454
455 Ok(Box::pin(stream))
456 }
457}