Skip to main content

restate_sdk_testcontainers/
lib.rs

1use anyhow::Context;
2use futures::FutureExt;
3use restate_sdk::prelude::{Endpoint, HttpServer};
4use serde::{Deserialize, Serialize};
5use testcontainers::core::wait::HttpWaitStrategy;
6use testcontainers::{
7    ContainerAsync, ContainerRequest, GenericImage, ImageExt,
8    core::{IntoContainerPort, WaitFor},
9    runners::AsyncRunner,
10};
11use tokio::{io::AsyncBufReadExt, net::TcpListener, task};
12use tracing::{error, info, warn};
13
14// From restate-admin-rest-model
15#[derive(Serialize, Deserialize, Debug)]
16pub struct RegisterDeploymentRequestHttp {
17    uri: String,
18    additional_headers: Option<Vec<(String, String)>>,
19    use_http_11: bool,
20    force: bool,
21    dry_run: bool,
22}
23
24#[derive(Serialize, Deserialize, Debug)]
25pub struct RegisterDeploymentRequestLambda {
26    arn: String,
27    assume_role_arn: Option<String>,
28    force: bool,
29    dry_run: bool,
30}
31
/// Builder describing how to start a Restate test environment.
///
/// Create one with [`TestEnvironment::new`] (or `Default`), customise it with the
/// `with_*` methods, then call [`TestEnvironment::start`].
#[derive(Debug, Clone)]
pub struct TestEnvironment {
    /// Container image name, e.g. `docker.io/restatedev/restate`.
    container_name: String,
    /// Container image tag, e.g. `latest`.
    container_tag: String,
    /// When `true`, the container's stdout/stderr are forwarded to `tracing`.
    logging: bool,
}
37
38impl Default for TestEnvironment {
39    fn default() -> Self {
40        Self {
41            container_name: "docker.io/restatedev/restate".to_string(),
42            container_tag: "latest".to_string(),
43            logging: false,
44        }
45    }
46}
47
48impl TestEnvironment {
49    // --- Builder methods
50
51    pub fn new() -> Self {
52        Self::default()
53    }
54
55    pub fn with_container_logging(mut self) -> Self {
56        self.logging = true;
57        self
58    }
59
60    pub fn with_container(mut self, container_name: String, container_tag: String) -> Self {
61        self.container_name = container_name;
62        self.container_tag = container_tag;
63
64        self
65    }
66
67    // --- Start method
68
69    pub async fn start(self, endpoint: Endpoint) -> Result<StartedTestEnvironment, anyhow::Error> {
70        let started_endpoint = StartedEndpoint::serve_endpoint(endpoint).await?;
71        let started_restate_container = StartedRestateContainer::start_container(&self).await?;
72        if let Err(e) = started_restate_container
73            .register_endpoint(&started_endpoint)
74            .await
75        {
76            return Err(anyhow::anyhow!("Failed to register endpoint: {e}"));
77        }
78
79        Ok(StartedTestEnvironment {
80            _started_endpoint: started_endpoint,
81            started_restate_container,
82        })
83    }
84}
85
86struct StartedEndpoint {
87    port: u16,
88    _cancel_tx: tokio::sync::oneshot::Sender<()>,
89}
90
91impl StartedEndpoint {
92    async fn serve_endpoint(endpoint: Endpoint) -> Result<StartedEndpoint, anyhow::Error> {
93        info!("Starting endpoint server...");
94
95        // 0.0.0.0:0 will listen on a random port, both IPv4 and IPv6
96        let host_address = "0.0.0.0:0".to_string();
97        let listener = TcpListener::bind(host_address)
98            .await
99            .expect("listener can bind");
100        let listening_addr = listener.local_addr()?;
101        let endpoint_server_url =
102            format!("http://{}:{}", listening_addr.ip(), listening_addr.port());
103
104        // Start endpoint server
105        let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
106        tokio::spawn(async move {
107            HttpServer::new(endpoint)
108                .serve_with_cancel(listener, cancel_rx)
109                .await;
110        });
111
112        let client = reqwest::Client::builder().http2_prior_knowledge().build()?;
113
114        // wait for endpoint server to respond
115        let mut retries = 0;
116        loop {
117            match client
118                .get(format!("{endpoint_server_url}/health",))
119                .send()
120                .await
121            {
122                Ok(res) if res.status().is_success() => break,
123                Ok(res) => {
124                    warn!(
125                        "Error when waiting for service endpoint server to be healthy, got response {}",
126                        res.status()
127                    );
128                    retries += 1;
129                    if retries > 10 {
130                        anyhow::bail!("Service endpoint server failed to start")
131                    }
132                }
133                Err(err) => {
134                    warn!(
135                        "Error when waiting for service endpoint server to be healthy, got error {}",
136                        err
137                    );
138                    retries += 1;
139                    if retries > 10 {
140                        anyhow::bail!("Service endpoint server failed to start")
141                    }
142                }
143            }
144        }
145
146        info!("Service endpoint server listening at: {endpoint_server_url}",);
147
148        Ok(StartedEndpoint {
149            port: listening_addr.port(),
150            _cancel_tx: cancel_tx,
151        })
152    }
153}
154
155struct StartedRestateContainer {
156    _cancel_tx: tokio::sync::oneshot::Sender<()>,
157    container: ContainerAsync<GenericImage>,
158    ingress_url: String,
159}
160
161impl StartedRestateContainer {
162    async fn start_container(
163        test_environment: &TestEnvironment,
164    ) -> Result<StartedRestateContainer, anyhow::Error> {
165        let image = GenericImage::new(
166            &test_environment.container_name,
167            &test_environment.container_tag,
168        )
169        .with_exposed_port(8080.tcp())
170        .with_exposed_port(9070.tcp())
171        .with_wait_for(WaitFor::Http(Box::new(
172            HttpWaitStrategy::new("/restate/health")
173                .with_port(8080.tcp())
174                .with_response_matcher(|res| res.status().is_success()),
175        )))
176        .with_wait_for(WaitFor::Http(Box::new(
177            HttpWaitStrategy::new("/health")
178                .with_port(9070.tcp())
179                .with_response_matcher(|res| res.status().is_success()),
180        )));
181
182        // Start container
183        let container = ContainerRequest::from(image)
184            // have to expose entire host network because testcontainer-rs doesn't implement selective SSH port forward from host
185            // see https://github.com/testcontainers/testcontainers-rs/issues/535
186            .with_host(
187                "host.docker.internal",
188                testcontainers::core::Host::HostGateway,
189            )
190            .start()
191            .await?;
192
193        let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
194        if test_environment.logging {
195            let container_stdout = container.stdout(true);
196            let mut stdout_lines = container_stdout.lines();
197            let container_stderr = container.stderr(true);
198            let mut stderr_lines = container_stderr.lines();
199
200            // Spawn a task to copy data from the AsyncBufRead to stdout
201            task::spawn(async move {
202                tokio::pin!(cancel_rx);
203                loop {
204                    tokio::select! {
205                        Some(stdout_line) = stdout_lines.next_line().map(|res| res.transpose()) => {
206                            match stdout_line {
207                                Ok(line) => info!("{}", line),
208                                Err(e) => {
209                                    error!("Error reading stdout from container stream: {}", e);
210                                    break;
211                                }
212                            }
213                        },
214                        Some(stderr_line) = stderr_lines.next_line().map(|res| res.transpose()) => {
215                            match stderr_line {
216                                Ok(line) => warn!("{}", line),
217                                Err(e) => {
218                                    error!("Error reading stderr from container stream: {}", e);
219                                    break;
220                                }
221                            }
222                        }
223                        _ = &mut cancel_rx => {
224                            break;
225                        }
226                    }
227                }
228            });
229        }
230
231        // Resolve ingress url
232        let host = container.get_host().await?;
233        let ports = container.ports().await?;
234        let ingress_port = ports.map_to_host_port_ipv4(8080.tcp()).unwrap();
235        let ingress_url = format!("http://{}:{}", host, ingress_port);
236
237        info!("Restate container started, listening on requests at {ingress_url}");
238
239        Ok(StartedRestateContainer {
240            _cancel_tx: cancel_tx,
241            container,
242            ingress_url,
243        })
244    }
245
246    async fn register_endpoint(&self, endpoint: &StartedEndpoint) -> Result<(), anyhow::Error> {
247        let host = self.container.get_host().await?;
248        let ports = self.container.ports().await?;
249        let admin_port = ports.map_to_host_port_ipv4(9070.tcp()).unwrap();
250
251        let client = reqwest::Client::builder().http2_prior_knowledge().build()?;
252
253        let deployment_uri: String = format!("http://host.docker.internal:{}/", endpoint.port);
254        let deployment_payload = RegisterDeploymentRequestHttp {
255            uri: deployment_uri,
256            additional_headers: None,
257            use_http_11: false,
258            force: false,
259            dry_run: false,
260        };
261
262        let register_admin_url = format!("http://{}:{}/deployments", host, admin_port);
263
264        let response = client
265            .post(register_admin_url)
266            .json(&deployment_payload)
267            .send()
268            .await
269            .context("Error when trying to register the service endpoint")?;
270
271        if !response.status().is_success() {
272            anyhow::bail!(
273                "Got non success status code when trying to register the service endpoint: {}",
274                response.status()
275            )
276        }
277
278        Ok(())
279    }
280}
281
282pub struct StartedTestEnvironment {
283    _started_endpoint: StartedEndpoint,
284    started_restate_container: StartedRestateContainer,
285}
286
287impl StartedTestEnvironment {
288    pub fn ingress_url(&self) -> String {
289        self.started_restate_container.ingress_url.clone()
290    }
291}