//! Test environment that runs a Restate server in a container (crate `restate_sdk_testcontainers`, `lib.rs`).
use std::time::Duration;

use anyhow::Context;
use futures::FutureExt;
use restate_sdk::prelude::{Endpoint, HttpServer};
use serde::{Deserialize, Serialize};
use testcontainers::core::wait::HttpWaitStrategy;
use testcontainers::{
    ContainerAsync, ContainerRequest, GenericImage, ImageExt,
    core::{IntoContainerPort, WaitFor},
    runners::AsyncRunner,
};
use tokio::{io::AsyncBufReadExt, net::TcpListener, task};
use tracing::{error, info, warn};
12use tracing::{error, info, warn};
13
14#[derive(Serialize, Deserialize, Debug)]
16pub struct RegisterDeploymentRequestHttp {
17 uri: String,
18 additional_headers: Option<Vec<(String, String)>>,
19 use_http_11: bool,
20 force: bool,
21 dry_run: bool,
22}
23
24#[derive(Serialize, Deserialize, Debug)]
25pub struct RegisterDeploymentRequestLambda {
26 arn: String,
27 assume_role_arn: Option<String>,
28 force: bool,
29 dry_run: bool,
30}
31
/// Configuration for a Restate test environment.
///
/// It selects which Restate container image to run and whether the container's output
/// is forwarded to `tracing`. Create one with `TestEnvironment::new()` (or `Default`),
/// adjust it with the `with_*` builder methods, then call `start` with the service endpoint
/// under test.
pub struct TestEnvironment {
    /// Image name to run, e.g. `docker.io/restatedev/restate`.
    container_name: String,
    /// Image tag to run, e.g. `latest`.
    container_tag: String,
    /// When `true`, container stdout/stderr lines are forwarded to `tracing`.
    logging: bool,
}
37
38impl Default for TestEnvironment {
39 fn default() -> Self {
40 Self {
41 container_name: "docker.io/restatedev/restate".to_string(),
42 container_tag: "latest".to_string(),
43 logging: false,
44 }
45 }
46}
47
48impl TestEnvironment {
49 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn with_container_logging(mut self) -> Self {
56 self.logging = true;
57 self
58 }
59
60 pub fn with_container(mut self, container_name: String, container_tag: String) -> Self {
61 self.container_name = container_name;
62 self.container_tag = container_tag;
63
64 self
65 }
66
67 pub async fn start(self, endpoint: Endpoint) -> Result<StartedTestEnvironment, anyhow::Error> {
70 let started_endpoint = StartedEndpoint::serve_endpoint(endpoint).await?;
71 let started_restate_container = StartedRestateContainer::start_container(&self).await?;
72 if let Err(e) = started_restate_container
73 .register_endpoint(&started_endpoint)
74 .await
75 {
76 return Err(anyhow::anyhow!("Failed to register endpoint: {e}"));
77 }
78
79 Ok(StartedTestEnvironment {
80 _started_endpoint: started_endpoint,
81 started_restate_container,
82 })
83 }
84}
85
86struct StartedEndpoint {
87 port: u16,
88 _cancel_tx: tokio::sync::oneshot::Sender<()>,
89}
90
91impl StartedEndpoint {
92 async fn serve_endpoint(endpoint: Endpoint) -> Result<StartedEndpoint, anyhow::Error> {
93 info!("Starting endpoint server...");
94
95 let host_address = "0.0.0.0:0".to_string();
97 let listener = TcpListener::bind(host_address)
98 .await
99 .expect("listener can bind");
100 let listening_addr = listener.local_addr()?;
101 let endpoint_server_url =
102 format!("http://{}:{}", listening_addr.ip(), listening_addr.port());
103
104 let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
106 tokio::spawn(async move {
107 HttpServer::new(endpoint)
108 .serve_with_cancel(listener, cancel_rx)
109 .await;
110 });
111
112 let client = reqwest::Client::builder().http2_prior_knowledge().build()?;
113
114 let mut retries = 0;
116 loop {
117 match client
118 .get(format!("{endpoint_server_url}/health",))
119 .send()
120 .await
121 {
122 Ok(res) if res.status().is_success() => break,
123 Ok(res) => {
124 warn!(
125 "Error when waiting for service endpoint server to be healthy, got response {}",
126 res.status()
127 );
128 retries += 1;
129 if retries > 10 {
130 anyhow::bail!("Service endpoint server failed to start")
131 }
132 }
133 Err(err) => {
134 warn!(
135 "Error when waiting for service endpoint server to be healthy, got error {}",
136 err
137 );
138 retries += 1;
139 if retries > 10 {
140 anyhow::bail!("Service endpoint server failed to start")
141 }
142 }
143 }
144 }
145
146 info!("Service endpoint server listening at: {endpoint_server_url}",);
147
148 Ok(StartedEndpoint {
149 port: listening_addr.port(),
150 _cancel_tx: cancel_tx,
151 })
152 }
153}
154
155struct StartedRestateContainer {
156 _cancel_tx: tokio::sync::oneshot::Sender<()>,
157 container: ContainerAsync<GenericImage>,
158 ingress_url: String,
159}
160
161impl StartedRestateContainer {
162 async fn start_container(
163 test_environment: &TestEnvironment,
164 ) -> Result<StartedRestateContainer, anyhow::Error> {
165 let image = GenericImage::new(
166 &test_environment.container_name,
167 &test_environment.container_tag,
168 )
169 .with_exposed_port(8080.tcp())
170 .with_exposed_port(9070.tcp())
171 .with_wait_for(WaitFor::Http(Box::new(
172 HttpWaitStrategy::new("/restate/health")
173 .with_port(8080.tcp())
174 .with_response_matcher(|res| res.status().is_success()),
175 )))
176 .with_wait_for(WaitFor::Http(Box::new(
177 HttpWaitStrategy::new("/health")
178 .with_port(9070.tcp())
179 .with_response_matcher(|res| res.status().is_success()),
180 )));
181
182 let container = ContainerRequest::from(image)
184 .with_host(
187 "host.docker.internal",
188 testcontainers::core::Host::HostGateway,
189 )
190 .start()
191 .await?;
192
193 let (cancel_tx, cancel_rx) = tokio::sync::oneshot::channel();
194 if test_environment.logging {
195 let container_stdout = container.stdout(true);
196 let mut stdout_lines = container_stdout.lines();
197 let container_stderr = container.stderr(true);
198 let mut stderr_lines = container_stderr.lines();
199
200 task::spawn(async move {
202 tokio::pin!(cancel_rx);
203 loop {
204 tokio::select! {
205 Some(stdout_line) = stdout_lines.next_line().map(|res| res.transpose()) => {
206 match stdout_line {
207 Ok(line) => info!("{}", line),
208 Err(e) => {
209 error!("Error reading stdout from container stream: {}", e);
210 break;
211 }
212 }
213 },
214 Some(stderr_line) = stderr_lines.next_line().map(|res| res.transpose()) => {
215 match stderr_line {
216 Ok(line) => warn!("{}", line),
217 Err(e) => {
218 error!("Error reading stderr from container stream: {}", e);
219 break;
220 }
221 }
222 }
223 _ = &mut cancel_rx => {
224 break;
225 }
226 }
227 }
228 });
229 }
230
231 let host = container.get_host().await?;
233 let ports = container.ports().await?;
234 let ingress_port = ports.map_to_host_port_ipv4(8080.tcp()).unwrap();
235 let ingress_url = format!("http://{}:{}", host, ingress_port);
236
237 info!("Restate container started, listening on requests at {ingress_url}");
238
239 Ok(StartedRestateContainer {
240 _cancel_tx: cancel_tx,
241 container,
242 ingress_url,
243 })
244 }
245
246 async fn register_endpoint(&self, endpoint: &StartedEndpoint) -> Result<(), anyhow::Error> {
247 let host = self.container.get_host().await?;
248 let ports = self.container.ports().await?;
249 let admin_port = ports.map_to_host_port_ipv4(9070.tcp()).unwrap();
250
251 let client = reqwest::Client::builder().http2_prior_knowledge().build()?;
252
253 let deployment_uri: String = format!("http://host.docker.internal:{}/", endpoint.port);
254 let deployment_payload = RegisterDeploymentRequestHttp {
255 uri: deployment_uri,
256 additional_headers: None,
257 use_http_11: false,
258 force: false,
259 dry_run: false,
260 };
261
262 let register_admin_url = format!("http://{}:{}/deployments", host, admin_port);
263
264 let response = client
265 .post(register_admin_url)
266 .json(&deployment_payload)
267 .send()
268 .await
269 .context("Error when trying to register the service endpoint")?;
270
271 if !response.status().is_success() {
272 anyhow::bail!(
273 "Got non success status code when trying to register the service endpoint: {}",
274 response.status()
275 )
276 }
277
278 Ok(())
279 }
280}
281
282pub struct StartedTestEnvironment {
283 _started_endpoint: StartedEndpoint,
284 started_restate_container: StartedRestateContainer,
285}
286
287impl StartedTestEnvironment {
288 pub fn ingress_url(&self) -> String {
289 self.started_restate_container.ingress_url.clone()
290 }
291}