grafbase_sdk/test/
runner.rs

1use std::{
2    future::IntoFuture,
3    marker::PhantomData,
4    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
5    path::Path,
6    time::Duration,
7};
8
9use super::TestConfig;
10use async_tungstenite::tungstenite::handshake::client::Request;
11use futures_util::{stream::BoxStream, StreamExt};
12use grafbase_sdk_mock::{MockGraphQlServer, MockSubgraph};
13use graphql_composition::{LoadedExtension, Subgraphs};
14use graphql_ws_client::graphql::GraphqlOperation;
15use http::{
16    header::{IntoHeaderName, SEC_WEBSOCKET_PROTOCOL},
17    HeaderValue,
18};
19use serde::de::DeserializeOwned;
20use tempfile::TempDir;
21use tungstenite::client::IntoClientRequest;
22use url::Url;
23
24/// A test runner that can start a gateway and execute GraphQL queries against it.
25pub struct TestRunner {
26    http_client: reqwest::Client,
27    config: TestConfig,
28    gateway_handle: Option<duct::Handle>,
29    gateway_listen_address: SocketAddr,
30    gateway_endpoint: Url,
31    test_specific_temp_dir: TempDir,
32    _mock_subgraphs: Vec<MockGraphQlServer>,
33    federated_graph: String,
34}
35
36#[derive(Debug, serde::Deserialize)]
37struct ExtensionToml {
38    extension: ExtensionDefinition,
39}
40
41#[derive(Debug, serde::Deserialize)]
42struct ExtensionDefinition {
43    name: String,
44}
45
46impl TestRunner {
47    /// Creates a new [`TestRunner`] with the given [`TestConfig`].
48    pub async fn new(mut config: TestConfig) -> anyhow::Result<Self> {
49        let test_specific_temp_dir = tempfile::Builder::new().prefix("sdk-tests").tempdir()?;
50        let gateway_listen_address = listen_address()?;
51        let gateway_endpoint = Url::parse(&format!("http://{}/graphql", gateway_listen_address))?;
52
53        let extension_toml_path = std::env::current_dir()?.join("extension.toml");
54        let extension_toml = std::fs::read_to_string(&extension_toml_path)?;
55        let extension_toml: ExtensionToml = toml::from_str(&extension_toml)?;
56        let extension_name = extension_toml.extension.name;
57
58        let mut mock_subgraphs = Vec::new();
59        let mut subgraphs = Subgraphs::default();
60
61        let extension_path = match config.extension_path {
62            Some(ref path) => path.to_path_buf(),
63            None => std::env::current_dir()?.join("build"),
64        };
65
66        subgraphs.ingest_loaded_extensions(std::iter::once(LoadedExtension::new(
67            format!("file://{}", extension_path.display()),
68            extension_name.clone(),
69        )));
70
71        for subgraph in config.mock_subgraphs.drain(..) {
72            match subgraph {
73                MockSubgraph::Dynamic(subgraph) => {
74                    let mock_graph = subgraph.start().await;
75                    subgraphs.ingest_str(mock_graph.sdl(), mock_graph.name(), Some(mock_graph.url().as_str()))?;
76                    mock_subgraphs.push(mock_graph);
77                }
78                MockSubgraph::ExtensionOnly(subgraph) => {
79                    subgraphs.ingest_str(subgraph.sdl(), subgraph.name(), None)?;
80                }
81            }
82        }
83
84        let federated_graph = graphql_composition::compose(&subgraphs).into_result().unwrap();
85        let federated_graph = graphql_federated_graph::render_federated_sdl(&federated_graph)?;
86
87        let mut this = Self {
88            http_client: reqwest::Client::new(),
89            config,
90            gateway_handle: None,
91            gateway_listen_address,
92            gateway_endpoint,
93            test_specific_temp_dir,
94            _mock_subgraphs: mock_subgraphs,
95            federated_graph,
96        };
97
98        this.build_extension(&extension_path)?;
99        this.start_servers(&extension_name, &extension_path).await?;
100
101        Ok(this)
102    }
103
104    async fn start_servers(&mut self, extension_name: &str, extension_path: &Path) -> anyhow::Result<()> {
105        let extension_path = extension_path.display();
106        let config_path = self.test_specific_temp_dir.path().join("grafbase.toml");
107        let schema_path = self.test_specific_temp_dir.path().join("federated-schema.graphql");
108        let config = &self.config.gateway_configuration;
109        let enable_stdout = self.config.enable_stdout;
110        let enable_stderr = self.config.enable_stdout;
111        let enable_networking = self.config.enable_networking;
112        let enable_environment_variables = self.config.enable_environment_variables;
113        let max_pool_size = self.config.max_pool_size.unwrap_or(100);
114
115        let config = indoc::formatdoc! {r#"
116            [extensions.{extension_name}]
117            path = "{extension_path}"
118            stdout = {enable_stdout}
119            stderr = {enable_stderr}
120            networking = {enable_networking}
121            environment_variables = {enable_environment_variables}
122            max_pool_size = {max_pool_size}
123
124            {config}
125        "#};
126
127        println!("{config}");
128
129        std::fs::write(&config_path, config.as_bytes())?;
130        std::fs::write(&schema_path, self.federated_graph.as_bytes())?;
131
132        let args = &[
133            "--listen-address",
134            &self.gateway_listen_address.to_string(),
135            "--config",
136            &config_path.to_string_lossy(),
137            "--schema",
138            &schema_path.to_string_lossy(),
139            "--log",
140            self.config.log_level.as_ref(),
141        ];
142
143        let mut expr = duct::cmd(&self.config.gateway_path, args);
144
145        if !self.config.enable_stderr {
146            expr = expr.stderr_null();
147        }
148
149        if !self.config.enable_stdout {
150            expr = expr.stdout_null();
151        }
152
153        self.gateway_handle = Some(expr.start()?);
154
155        let mut i = 0;
156        while !self.check_gateway_health().await? {
157            // printing every second only
158            if i % 10 == 0 {
159                println!("Waiting for gateway to be ready...");
160            }
161            i += 1;
162            std::thread::sleep(Duration::from_millis(100));
163        }
164
165        Ok(())
166    }
167
168    async fn check_gateway_health(&self) -> anyhow::Result<bool> {
169        let url = self.gateway_endpoint.join("/health")?;
170
171        let Ok(result) = self.http_client.get(url).send().await else {
172            return Ok(false);
173        };
174
175        let result = result.error_for_status().is_ok();
176
177        Ok(result)
178    }
179
180    fn build_extension(&mut self, extension_path: &Path) -> anyhow::Result<()> {
181        let extension_path = extension_path.to_string_lossy();
182
183        // Only one test can build the extension at a time. The others must
184        // wait.
185        let mut lock_file = fslock::LockFile::open(".build.lock")?;
186        lock_file.lock()?;
187
188        let args = &["extension", "build", "--debug", "--output-dir", &*extension_path];
189        let mut expr = duct::cmd(&self.config.cli_path, args);
190
191        if !self.config.enable_stdout {
192            expr = expr.stdout_null();
193        }
194
195        if !self.config.enable_stderr {
196            expr = expr.stderr_null();
197        }
198
199        expr.run()?;
200        lock_file.unlock()?;
201
202        Ok(())
203    }
204
205    /// Creates a new GraphQL query builder with the given query.
206    ///
207    /// # Arguments
208    ///
209    /// * `query` - The GraphQL query string to execute
210    ///
211    /// # Returns
212    ///
213    /// A [`QueryBuilder`] that can be used to customize and execute the query
214    pub fn graphql_query<Response>(&self, query: impl Into<String>) -> QueryBuilder<Response> {
215        let reqwest_builder = self
216            .http_client
217            .post(self.gateway_endpoint.clone())
218            .header(http::header::ACCEPT, "application/json");
219
220        QueryBuilder {
221            query: query.into(),
222            variables: None,
223            phantom: PhantomData,
224            reqwest_builder,
225        }
226    }
227
228    ///
229    /// # Arguments
230    ///
231    /// * `query` - The GraphQL subscription query string to execute
232    ///
233    /// # Returns
234    ///
235    /// A [`SubscriptionBuilder`] that can be used to customize and execute the subscription
236    pub fn graphql_subscription<Response>(
237        &self,
238        query: impl Into<String>,
239    ) -> anyhow::Result<SubscriptionBuilder<Response>> {
240        let mut url = self.gateway_endpoint.clone();
241
242        url.set_path("/ws");
243        url.set_scheme("ws").unwrap();
244
245        let mut request_builder = url.as_ref().into_client_request()?;
246
247        request_builder
248            .headers_mut()
249            .insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("graphql-transport-ws"));
250
251        let operation = Operation {
252            query: query.into(),
253            variables: None,
254            phantom: PhantomData,
255        };
256
257        Ok(SubscriptionBuilder {
258            operation,
259            request_builder,
260        })
261    }
262
263    /// Returns the federated schema as a string.
264    pub fn federated_graph(&self) -> &str {
265        &self.federated_graph
266    }
267}
268
269pub(crate) fn free_port() -> anyhow::Result<u16> {
270    const INITIAL_PORT: u16 = 14712;
271
272    let test_dir = std::env::temp_dir().join("grafbase/sdk-tests");
273    std::fs::create_dir_all(&test_dir)?;
274
275    let lock_file_path = test_dir.join("port-number.lock");
276    let port_number_file_path = test_dir.join("port-number.txt");
277
278    let mut lock_file = fslock::LockFile::open(&lock_file_path)?;
279    lock_file.lock()?;
280
281    let port = if port_number_file_path.exists() {
282        std::fs::read_to_string(&port_number_file_path)?.trim().parse::<u16>()? + 1
283    } else {
284        INITIAL_PORT
285    };
286
287    std::fs::write(&port_number_file_path, port.to_string())?;
288    lock_file.unlock()?;
289
290    Ok(port)
291}
292
293pub(crate) fn listen_address() -> anyhow::Result<SocketAddr> {
294    let port = free_port()?;
295    Ok(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), port)))
296}
297
298impl Drop for TestRunner {
299    fn drop(&mut self) {
300        let Some(handle) = self.gateway_handle.take() else {
301            return;
302        };
303
304        if let Err(err) = handle.kill() {
305            eprintln!("Failed to kill grafbase-gateway: {}", err)
306        }
307    }
308}
309
310#[derive(serde::Serialize)]
311#[must_use]
312/// A builder for constructing GraphQL queries with customizable parameters and headers.
313pub struct QueryBuilder<Response> {
314    // These two will be serialized into the request
315    query: String,
316    #[serde(skip_serializing_if = "Option::is_none")]
317    variables: Option<serde_json::Value>,
318
319    // These won't
320    #[serde(skip)]
321    phantom: PhantomData<fn() -> Response>,
322    #[serde(skip)]
323    reqwest_builder: reqwest::RequestBuilder,
324}
325
326impl<Response> QueryBuilder<Response> {
327    /// Adds variables to the GraphQL query.
328    ///
329    /// # Arguments
330    ///
331    /// * `variables` - The variables to include with the query, serializable to JSON
332    pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
333        self.variables = Some(serde_json::to_value(variables).unwrap());
334        self
335    }
336
337    /// Adds a header to the GraphQL request.
338    pub fn with_header(self, name: &str, value: &str) -> Self {
339        let Self {
340            phantom,
341            query,
342            mut reqwest_builder,
343            variables,
344        } = self;
345
346        reqwest_builder = reqwest_builder.header(name, value);
347
348        Self {
349            query,
350            variables,
351            phantom,
352            reqwest_builder,
353        }
354    }
355
356    /// Sends the GraphQL query and returns the response.
357    ///
358    /// # Returns
359    ///
360    /// The deserialized response from the GraphQL server
361    ///
362    /// # Errors
363    ///
364    /// Will return an error if:
365    /// - Request serialization fails
366    /// - Network request fails
367    /// - Response deserialization fails
368    pub async fn send(self) -> anyhow::Result<Response>
369    where
370        Response: for<'de> serde::Deserialize<'de>,
371    {
372        let json = serde_json::to_value(&self)?;
373        Ok(self.reqwest_builder.json(&json).send().await?.json().await?)
374    }
375}
376
377#[must_use]
378/// A builder for constructing GraphQL queries with customizable parameters and headers.
379pub struct SubscriptionBuilder<Response> {
380    operation: Operation<Response>,
381    request_builder: Request,
382}
383
384#[derive(serde::Serialize)]
385struct Operation<Response> {
386    query: String,
387    #[serde(skip_serializing_if = "Option::is_none")]
388    variables: Option<serde_json::Value>,
389    #[serde(skip)]
390    phantom: PhantomData<fn() -> Response>,
391}
392
393impl<Response> GraphqlOperation for Operation<Response>
394where
395    Response: DeserializeOwned,
396{
397    type Response = Response;
398    type Error = serde_json::Error;
399
400    fn decode(&self, data: serde_json::Value) -> Result<Self::Response, Self::Error> {
401        serde_json::from_value(data)
402    }
403}
404
405impl<Response> SubscriptionBuilder<Response>
406where
407    Response: DeserializeOwned + 'static,
408{
409    /// Adds variables to the GraphQL subscription.
410    ///
411    /// # Arguments
412    ///
413    /// * `variables` - The variables to include with the subscription, serializable to JSON
414    pub fn with_variables(mut self, variables: impl serde::Serialize) -> Self {
415        self.operation.variables = Some(serde_json::to_value(variables).unwrap());
416        self
417    }
418
419    /// Adds a header to the GraphQL request.
420    ///
421    /// # Arguments
422    ///
423    /// * `name` - The header name
424    /// * `value` - The header value
425    pub fn with_header<K, V>(mut self, name: K, value: HeaderValue) -> Self
426    where
427        K: IntoHeaderName,
428    {
429        self.request_builder.headers_mut().insert(name, value);
430        self
431    }
432
433    /// Subscribes to the GraphQL subscription and returns a stream of responses.
434    ///
435    /// # Returns
436    ///
437    /// A pinned stream that yields the deserialized subscription responses
438    ///
439    /// # Errors
440    ///
441    /// Will return an error if:
442    /// - WebSocket connection fails
443    /// - GraphQL subscription initialization fails
444    pub async fn subscribe(self) -> anyhow::Result<BoxStream<'static, Response>> {
445        let (connection, _) = async_tungstenite::tokio::connect_async(self.request_builder).await?;
446        let (client, actor) = graphql_ws_client::Client::build(connection).await?;
447
448        tokio::spawn(actor.into_future());
449
450        let stream = client
451            .subscribe(self.operation)
452            .await?
453            .map(move |item| -> Response { item.unwrap() });
454
455        Ok(Box::pin(stream))
456    }
457}