Skip to main content

pdk_test/
composite.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5//! Test composite orchestration
6//!
7//! This module provides the main test environment orchestration functionality.
8//! It manages Docker containers, networks, and services in a coordinated way
9//! for integration testing scenarios.
10//!
11//! ## Primary types
12//!
13//! - [`TestComposite`]: orchestrates test environment lifecycle
14//! - [`TestCompositeBuilder`]: configures and builds test environments
15//!
16
17use std::any::{type_name, Any, TypeId};
18use std::collections::hash_map::Entry;
19use std::collections::HashMap;
20use std::rc::Rc;
21use std::time::Duration;
22
23use crate::cleanup::Cleanup;
24use bollard::Docker;
25use futures::future::{join_all, try_join_all};
26use tokio::runtime::{Handle, RuntimeFlavor};
27
28use crate::config::Config;
29use crate::container::Container;
30use crate::error::TestError;
31use crate::host::Host;
32use crate::network::Network;
33use crate::runner::Test;
34use crate::service::Service;
35use crate::services::httpmock::HttpMockConfig;
36
37struct UntypedConfig {
38    erased: Rc<dyn Any>,
39    source: Rc<dyn Config>,
40}
41
42impl UntypedConfig {
43    fn new<T: Config + 'static>(config: T) -> Self {
44        let erased: Rc<dyn Any> = Rc::new(config);
45        let source: Rc<dyn Config> = erased.clone().downcast::<T>().unwrap();
46        Self { erased, source }
47    }
48
49    fn upcast(&self) -> &dyn Config {
50        self.source.as_ref()
51    }
52
53    fn downcast<T: Config + 'static>(&self) -> &T {
54        self.erased.downcast_ref().unwrap()
55    }
56}
57
58struct Inner {
59    configs: HashMap<TypeId, HashMap<String, UntypedConfig>>,
60    containers: HashMap<String, Container>,
61    network: Network,
62    test: Rc<Test>,
63}
64
65impl Inner {
66    fn configs<T: Service + 'static>(&self) -> Result<&HashMap<String, UntypedConfig>, TestError> {
67        self.configs
68            .get(&TypeId::of::<T::Config>())
69            .ok_or(TestError::UnknownService(type_name::<T>()))
70    }
71
72    fn service<T: Service + 'static>(&self) -> Result<T, TestError> {
73        let config = self.configs::<T>()?.values().next().unwrap();
74        let container = self.containers.get(config.upcast().hostname()).unwrap();
75        Ok(T::new(config.downcast(), container))
76    }
77
78    fn service_by_hostname<T: Service + 'static>(&self, hostname: &str) -> Result<T, TestError> {
79        let config = self
80            .configs::<T>()?
81            .get(hostname)
82            .ok_or_else(|| TestError::UnknownServiceHostname(hostname.to_string()))?;
83        let container = self.containers.get(config.upcast().hostname()).unwrap();
84        Ok(T::new(config.downcast(), container))
85    }
86}
87
88/// Main test environment orchestrator for integration tests.
89pub struct TestComposite {
90    inner: Option<Inner>,
91}
92
93impl TestComposite {
94    /// Creates a new builder for configuring test environments.
95    pub fn builder() -> TestCompositeBuilder {
96        TestCompositeBuilder::new()
97    }
98
99    /// Gets a service instance by its type.
100    /// Returns the first configured service of the specified type.
101    pub fn service<T: Service + 'static>(&self) -> Result<T, TestError> {
102        self.inner().service()
103    }
104
105    /// Gets a service instance by its hostname.
106    /// Returns the service with the specified hostname and type.
107    pub fn service_by_hostname<T: Service + 'static>(&self, name: &str) -> Result<T, TestError> {
108        self.inner().service_by_hostname(name)
109    }
110
111    fn inner(&self) -> &Inner {
112        self.inner.as_ref().unwrap()
113    }
114}
115
116fn check_runtime() -> Result<(), TestError> {
117    let handle = Handle::try_current()?;
118    if !matches!(handle.runtime_flavor(), RuntimeFlavor::MultiThread) {
119        return Err(TestError::UnavailableMultiThread);
120    }
121    Ok(())
122}
123
124async fn check_docker(docker: &Docker) -> Result<(), TestError> {
125    match docker.ping().await {
126        Ok(_) => Ok(()),
127        Err(err) => Err(TestError::UnavailableDocker(err.into())),
128    }
129}
130
131/// Builder for configuring and creating test environments.
132///
133/// This builder is used to configure the test environment with the services
134/// that will be used in the test. It provides an API for adding services
135/// and configuring them.
136pub struct TestCompositeBuilder {
137    configs: HashMap<TypeId, HashMap<String, UntypedConfig>>,
138}
139
140impl TestCompositeBuilder {
141    fn new() -> Self {
142        Self {
143            configs: HashMap::new(),
144        }
145    }
146
147    /// Configures a service with the provided configuration. The service will be
148    /// started when the test environment is built. Each service must have a unique hostname.
149    pub fn with_service<C: Config + 'static>(mut self, config: C) -> Self {
150        let entry = self
151            .configs
152            .entry(TypeId::of::<C>())
153            .or_default()
154            .entry(config.hostname().to_string());
155
156        match entry {
157            Entry::Occupied(_) => panic!("Name {} configured twice", config.hostname()),
158            Entry::Vacant(e) => e.insert(UntypedConfig::new(config)),
159        };
160
161        self
162    }
163
164    /// Builds the test environment with all configured services.
165    /// Starts Docker containers, creates networks and initializes all services.
166    /// Returns a `TestComposite` that can be used to access the services.
167    pub async fn build(self) -> Result<TestComposite, TestError> {
168        check_runtime()?;
169
170        let httpmock_configs_len = self
171            .configs
172            .get(&TypeId::of::<HttpMockConfig>())
173            .map(|f| f.len())
174            .unwrap_or(0);
175
176        if httpmock_configs_len > 1 {
177            return Err(TestError::NotSupportedConfig(
178                "Only 1 HttpMock can be defined per test".to_string(),
179            ));
180        }
181
182        let test = Test::current()?;
183        log::info!(
184            "Framework starting environment module={} test={}",
185            test.module(),
186            test.name()
187        );
188
189        let docker = Docker::connect_with_local_defaults()?;
190        check_docker(&docker).await?;
191        log::debug!("Framework docker ping OK");
192
193        let host = Host::current(&docker).await?;
194        log::debug!("Framework host mode = {:?}", host.mode());
195
196        Cleanup::new(docker.clone()).purge().await?;
197
198        let mut network = Network::new(docker.clone()).await?;
199        log::info!("Framework created docker network id={}", network.id());
200
201        if let Some(host_container) = host.container() {
202            log::info!("Creating testing environment in containerized mode.");
203
204            // Containerized mode connects the current container into the network.
205            network.connect(host_container.id()).await?;
206        } else {
207            log::info!("Creating testing environment in standalone mode.");
208        }
209
210        let starts = self.configs.iter().flat_map(|(_, configs)| {
211            configs.values().map(|config| {
212                log::info!(
213                    "Framework initializing service hostname={}",
214                    config.upcast().hostname()
215                );
216                Container::initialized(
217                    docker.clone(),
218                    test.clone(),
219                    host.mode(),
220                    &network,
221                    config.upcast(),
222                )
223            })
224        });
225
226        let containers = try_join_all(starts)
227            .await?
228            .into_iter()
229            .map(|c| (c.config().hostname().to_string(), c));
230
231        Ok(TestComposite {
232            inner: Some(Inner {
233                configs: self.configs,
234                containers: containers.collect(),
235                network,
236                test: test.clone(),
237            }),
238        })
239    }
240}
241
242impl Drop for TestComposite {
243    fn drop(&mut self) {
244        let Inner {
245            mut network,
246            containers,
247            test,
248            ..
249        } = self.inner.take().unwrap();
250        tokio::task::block_in_place(|| {
251            log::info!("Dropping testing environment.");
252
253            Handle::current().block_on(async {
254                if !test.is_success() {
255                    tokio::time::sleep(Duration::from_secs(1)).await;
256                }
257                join_all(containers.into_values().map(|mut container| async move {
258                    container.dispose().await;
259                }))
260                .await;
261                network.remove().await;
262            })
263        });
264    }
265}
266
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use bollard::container::{CreateContainerOptions, NetworkingConfig};
    use bollard::errors::Error as BollardError;
    use bollard::network::CreateNetworkOptions;
    use bollard::secret::EndpointSettings;
    use bollard::Docker;

    use crate::constants::NETWORK_NAME;
    use crate::error::TestError;
    use crate::image::Image;
    use crate::runner::Test;
    use crate::services::httpbin::HttpBinConfig;

    use super::TestComposite;

    /// `#[tokio::test]` runs on a current-thread runtime, which `build` must reject.
    #[tokio::test]
    async fn multi_thread_required_error() {
        let result = TestComposite::builder().build().await;
        assert!(matches!(result, Err(TestError::UnavailableMultiThread)));
    }

    /// Polling `build` outside any Tokio runtime must return an error.
    #[test]
    fn runtime_required() {
        let result = futures::executor::block_on(TestComposite::builder().build());
        assert!(matches!(result, Err(TestError::UnavailableRuntime(_))));
    }

    /// A `<hostname>.log` file must exist in the test target directory for every
    /// service while the environment is running, and none must remain after a
    /// successful run.
    #[test]
    fn create_container_logs() {
        let test = Test::builder().module("foo").name("bar").build();

        let target_dir = test.target_dir().to_owned();
        test.run(async {
            let s1 = HttpBinConfig::builder().hostname("service-1").build();
            let s2 = HttpBinConfig::builder().hostname("service-2").build();

            // Bind to a named variable: `let _ = …` would drop the composite, tearing
            // its containers down, before the assertions below run.
            let _composite = TestComposite::builder()
                .with_service(s1)
                .with_service(s2)
                .build()
                .await?;

            assert!(target_dir.join("service-1.log").exists());
            assert!(target_dir.join("service-2.log").exists());

            Ok::<_, TestError>(())
        })
        // Propagate failures: ignoring the result would let the checks below pass
        // vacuously when the environment never started.
        .expect("test environment should start");

        assert!(!target_dir.join("service-1.log").exists());
        assert!(!target_dir.join("service-2.log").exists());
    }

    /// The framework network exists while the environment is alive and is removed
    /// once it is dropped.
    #[test]
    fn drop_network() {
        let docker = Docker::connect_with_local_defaults().unwrap();
        let test = Test::builder().module("foo").name("bar").build();

        test.run(async {
            let s1 = HttpBinConfig::builder().hostname("service-1").build();
            let s2 = HttpBinConfig::builder().hostname("service-2").build();

            let _tc = TestComposite::builder()
                .with_service(s1)
                .with_service(s2)
                .build()
                .await?;

            // Check created network
            let result = docker.inspect_network::<String>(NETWORK_NAME, None).await;
            assert!(result.is_ok());

            Ok::<_, TestError>(())
        })
        // Without this, a failed `build` would leave no network and the 404 check
        // below would pass vacuously.
        .expect("test environment should start");

        // Check network deletion
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let result = runtime.block_on(docker.inspect_network::<String>(NETWORK_NAME, None));

        assert!(matches!(
            result,
            Err(BollardError::DockerResponseServerError {
                status_code: 404,
                ..
            })
        ));
    }

    /// Assets labeled `CreatedBy=pdk-test` by a previous run (here a hello-world
    /// container on a labeled network) are removed when a new environment is built.
    #[test]
    fn purge_test_assets() -> Result<(), TestError> {
        let test = Test::builder().module("foo").name("bar").build();

        test.run(async {
            let docker = Docker::connect_with_local_defaults()?;

            // Ensure hello-world image
            let hello_world_image = Image::from_repository("hello-world").with_version("linux");
            hello_world_image.pull(&docker).await?;

            // Create a network that shares the name and is properly labeled.
            let network = docker
                .create_network(CreateNetworkOptions {
                    name: "pdk-test-network",
                    driver: "bridge",
                    labels: HashMap::from([("CreatedBy", "pdk-test")]),
                    ..Default::default()
                })
                .await?;

            let net_id = network.id;

            let hello_world_locator = hello_world_image.locator();
            let hello_world_name = "hello-world";

            // Create a container that uses the network and is properly labeled.
            let container = docker
                .create_container(
                    Some(CreateContainerOptions {
                        name: hello_world_name,
                        platform: None,
                    }),
                    bollard::container::Config {
                        image: Some(hello_world_locator.as_str()),
                        hostname: Some("helloWorld"),
                        network_disabled: Some(false),
                        networking_config: Some(NetworkingConfig {
                            endpoints_config: HashMap::from([(
                                net_id.as_str(),
                                EndpointSettings::default(),
                            )]),
                        }),
                        labels: Some(HashMap::from([("CreatedBy", "pdk-test")])),
                        ..Default::default()
                    },
                )
                .await?;

            // Start the container to connect it to the network.
            docker.start_container::<&str>(&container.id, None).await?;

            let hello_world_inspect = docker.inspect_container(hello_world_name, None).await;

            // Assert that hello-world container exists
            assert!(hello_world_inspect.is_ok());

            let httpbin_config = HttpBinConfig::builder().hostname("httpbin").build();

            let _composite = TestComposite::builder()
                .with_service(httpbin_config)
                .build()
                .await?;

            let hello_world_inspect = docker.inspect_container(hello_world_name, None).await;

            // Assert that hello-world container no longer exists
            assert!(matches!(
                hello_world_inspect,
                Err(BollardError::DockerResponseServerError {
                    status_code: 404,
                    ..
                })
            ));

            Ok(())
        })
    }
}