Skip to main content

wfe_containerd/
step.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use async_trait::async_trait;
5use tonic::transport::{Channel, Endpoint, Uri};
6use wfe_core::WfeError;
7use wfe_core::models::ExecutionResult;
8use wfe_core::traits::step::{StepBody, StepExecutionContext};
9
10use wfe_containerd_protos::containerd::services::containers::v1::{
11    Container, CreateContainerRequest, DeleteContainerRequest, container::Runtime,
12    containers_client::ContainersClient,
13};
14use wfe_containerd_protos::containerd::services::content::v1::{
15    ReadContentRequest, content_client::ContentClient,
16};
17use wfe_containerd_protos::containerd::services::images::v1::{
18    GetImageRequest, images_client::ImagesClient,
19};
20use wfe_containerd_protos::containerd::services::snapshots::v1::{
21    MountsRequest, PrepareSnapshotRequest, snapshots_client::SnapshotsClient,
22};
23use wfe_containerd_protos::containerd::services::tasks::v1::{
24    CreateTaskRequest, DeleteTaskRequest, StartRequest, WaitRequest, tasks_client::TasksClient,
25};
26use wfe_containerd_protos::containerd::services::version::v1::version_client::VersionClient;
27
28use crate::config::ContainerdConfig;
29
/// containerd namespace used for every request issued by this module
/// (sent as the `containerd-namespace` gRPC metadata header).
const DEFAULT_NAMESPACE: &str = "default";

/// Snapshotter named in container-create and snapshot requests.
/// NOTE(review): originally described as the rootless-containerd default —
/// confirm against the target daemon's configured snapshotters.
const DEFAULT_SNAPSHOTTER: &str = "overlayfs";
35
36/// Containerdstep.
37pub struct ContainerdStep {
38    config: ContainerdConfig,
39}
40
41impl ContainerdStep {
42    pub fn new(config: ContainerdConfig) -> Self {
43        Self { config }
44    }
45
46    /// Connect to the containerd daemon and return a raw tonic `Channel`.
47    ///
48    /// Supports Unix socket paths (bare `/path` or `unix:///path`) and
49    /// TCP/HTTP endpoints.
50    pub(crate) async fn connect(addr: &str) -> Result<Channel, WfeError> {
51        let channel = if addr.starts_with('/') || addr.starts_with("unix://") {
52            let socket_path = addr.strip_prefix("unix://").unwrap_or(addr).to_string();
53
54            if !Path::new(&socket_path).exists() {
55                return Err(WfeError::StepExecution(format!(
56                    "containerd socket not found: {socket_path}"
57                )));
58            }
59
60            Endpoint::try_from("http://[::]:50051")
61                .map_err(|e| WfeError::StepExecution(format!("failed to create endpoint: {e}")))?
62                .connect_with_connector(tower::service_fn(move |_: Uri| {
63                    let path = socket_path.clone();
64                    async move {
65                        tokio::net::UnixStream::connect(path)
66                            .await
67                            .map(hyper_util::rt::TokioIo::new)
68                    }
69                }))
70                .await
71                .map_err(|e| {
72                    WfeError::StepExecution(format!(
73                        "failed to connect to containerd via Unix socket at {addr}: {e}"
74                    ))
75                })?
76        } else {
77            let connect_addr = if addr.starts_with("tcp://") {
78                addr.replacen("tcp://", "http://", 1)
79            } else {
80                addr.to_string()
81            };
82
83            Endpoint::from_shared(connect_addr.clone())
84                .map_err(|e| {
85                    WfeError::StepExecution(format!(
86                        "invalid containerd endpoint {connect_addr}: {e}"
87                    ))
88                })?
89                .timeout(std::time::Duration::from_secs(30))
90                .connect()
91                .await
92                .map_err(|e| {
93                    WfeError::StepExecution(format!(
94                        "failed to connect to containerd at {connect_addr}: {e}"
95                    ))
96                })?
97        };
98
99        Ok(channel)
100    }
101
102    /// Check whether an image exists in containerd's image store.
103    ///
104    /// Image pulling via raw containerd gRPC is complex (content store +
105    /// snapshots + transfer). For now we only verify the image exists and
106    /// return an error if it does not. Images must be pre-pulled via
107    /// `ctr image pull` or `nerdctl pull`.
108    ///
109    /// TODO: implement full image pull via TransferService or content ingest.
110    async fn ensure_image(channel: &Channel, image: &str, namespace: &str) -> Result<(), WfeError> {
111        let mut client = ImagesClient::new(channel.clone());
112
113        let mut req = tonic::Request::new(GetImageRequest {
114            name: image.to_string(),
115        });
116        req.metadata_mut()
117            .insert("containerd-namespace", namespace.parse().unwrap());
118
119        match client.get(req).await {
120            Ok(_) => Ok(()),
121            Err(status) => Err(WfeError::StepExecution(format!(
122                "image '{image}' not found in containerd (namespace={namespace}). \
123                 Pre-pull it with: ctr -n {namespace} image pull {image}  \
124                 (gRPC status: {status})"
125            ))),
126        }
127    }
128
129    /// Resolve the snapshot chain ID for an image.
130    ///
131    /// This reads the image manifest and config from the content store to
132    /// compute the chain ID of the topmost layer. The chain ID is used as
133    /// the parent snapshot when preparing a writable rootfs for a container.
134    ///
135    /// Chain ID computation follows the OCI image spec:
136    ///   chain_id[0] = diff_id[0]
137    ///   chain_id[n] = sha256(chain_id[n-1] + " " + diff_id[n])
138    async fn resolve_image_chain_id(
139        channel: &Channel,
140        image: &str,
141        namespace: &str,
142    ) -> Result<String, WfeError> {
143        use sha2::{Digest, Sha256};
144
145        // 1. Get the image record to find the manifest digest.
146        let mut images_client = ImagesClient::new(channel.clone());
147        let req = Self::with_namespace(
148            GetImageRequest {
149                name: image.to_string(),
150            },
151            namespace,
152        );
153        let image_resp = images_client
154            .get(req)
155            .await
156            .map_err(|e| WfeError::StepExecution(format!("failed to get image '{image}': {e}")))?;
157        let img = image_resp
158            .into_inner()
159            .image
160            .ok_or_else(|| WfeError::StepExecution(format!("image '{image}' has no record")))?;
161        let target = img.target.ok_or_else(|| {
162            WfeError::StepExecution(format!("image '{image}' has no target descriptor"))
163        })?;
164
165        // The target might be an index (multi-platform) or a manifest.
166        // Read the content and determine based on mediaType.
167        let manifest_digest = target.digest.clone();
168        let manifest_bytes = Self::read_content(channel, &manifest_digest, namespace).await?;
169        let manifest_json: serde_json::Value = serde_json::from_slice(&manifest_bytes)
170            .map_err(|e| WfeError::StepExecution(format!("failed to parse manifest: {e}")))?;
171
172        // 2. If it's an index, pick the matching platform manifest.
173        let manifest_json = if manifest_json.get("manifests").is_some() {
174            // OCI image index — find the platform-matching manifest.
175            let arch = std::env::consts::ARCH;
176            let oci_arch = match arch {
177                "aarch64" => "arm64",
178                "x86_64" => "amd64",
179                other => other,
180            };
181            let manifests = manifest_json["manifests"].as_array().ok_or_else(|| {
182                WfeError::StepExecution("image index has no manifests array".to_string())
183            })?;
184            let platform_manifest = manifests
185                .iter()
186                .find(|m| {
187                    m.get("platform")
188                        .and_then(|p| p.get("architecture"))
189                        .and_then(|a| a.as_str())
190                        == Some(oci_arch)
191                })
192                .ok_or_else(|| {
193                    WfeError::StepExecution(format!(
194                        "no manifest for architecture '{oci_arch}' in image index"
195                    ))
196                })?;
197            let digest = platform_manifest["digest"].as_str().ok_or_else(|| {
198                WfeError::StepExecution("platform manifest has no digest".to_string())
199            })?;
200            let bytes = Self::read_content(channel, digest, namespace).await?;
201            serde_json::from_slice(&bytes).map_err(|e| {
202                WfeError::StepExecution(format!("failed to parse platform manifest: {e}"))
203            })?
204        } else {
205            manifest_json
206        };
207
208        // 3. Get the config digest from the manifest.
209        let config_digest = manifest_json["config"]["digest"]
210            .as_str()
211            .ok_or_else(|| WfeError::StepExecution("manifest has no config.digest".to_string()))?;
212
213        // 4. Read the image config.
214        let config_bytes = Self::read_content(channel, config_digest, namespace).await?;
215        let config_json: serde_json::Value = serde_json::from_slice(&config_bytes)
216            .map_err(|e| WfeError::StepExecution(format!("failed to parse image config: {e}")))?;
217
218        // 5. Extract diff_ids and compute chain ID.
219        let diff_ids = config_json["rootfs"]["diff_ids"]
220            .as_array()
221            .ok_or_else(|| {
222                WfeError::StepExecution("image config has no rootfs.diff_ids".to_string())
223            })?;
224
225        if diff_ids.is_empty() {
226            return Err(WfeError::StepExecution(
227                "image has no layers (empty diff_ids)".to_string(),
228            ));
229        }
230
231        let mut chain_id = diff_ids[0]
232            .as_str()
233            .ok_or_else(|| WfeError::StepExecution("diff_id is not a string".to_string()))?
234            .to_string();
235
236        for diff_id in &diff_ids[1..] {
237            let diff = diff_id
238                .as_str()
239                .ok_or_else(|| WfeError::StepExecution("diff_id is not a string".to_string()))?;
240            let mut hasher = Sha256::new();
241            hasher.update(format!("{chain_id} {diff}"));
242            chain_id = format!("sha256:{:x}", hasher.finalize());
243        }
244
245        tracing::debug!(image = image, chain_id = %chain_id, "resolved image chain ID");
246        Ok(chain_id)
247    }
248
249    /// Read content from the containerd content store by digest.
250    async fn read_content(
251        channel: &Channel,
252        digest: &str,
253        namespace: &str,
254    ) -> Result<Vec<u8>, WfeError> {
255        use tokio_stream::StreamExt;
256
257        let mut client = ContentClient::new(channel.clone());
258        let req = Self::with_namespace(
259            ReadContentRequest {
260                digest: digest.to_string(),
261                offset: 0,
262                size: 0, // read all
263            },
264            namespace,
265        );
266
267        let mut stream = client
268            .read(req)
269            .await
270            .map_err(|e| WfeError::StepExecution(format!("failed to read content {digest}: {e}")))?
271            .into_inner();
272
273        let mut data = Vec::new();
274        while let Some(chunk) = stream.next().await {
275            let chunk = chunk.map_err(|e| {
276                WfeError::StepExecution(format!("error reading content {digest}: {e}"))
277            })?;
278            data.extend_from_slice(&chunk.data);
279        }
280
281        Ok(data)
282    }
283
284    /// Build a minimal OCI runtime spec as a `prost_types::Any`.
285    ///
286    /// The spec is serialized as JSON and wrapped in a protobuf Any with
287    /// the containerd OCI spec type URL.
288    pub(crate) fn build_oci_spec(&self, merged_env: &HashMap<String, String>) -> prost_types::Any {
289        // Build the args array for the process.
290        let args: Vec<String> = if let Some(ref run) = self.config.run {
291            vec!["/bin/sh".to_string(), "-c".to_string(), run.clone()]
292        } else if let Some(ref command) = self.config.command {
293            command.clone()
294        } else {
295            vec![]
296        };
297
298        // Build env in KEY=VALUE form.
299        let env: Vec<String> = merged_env.iter().map(|(k, v)| format!("{k}={v}")).collect();
300
301        // Build mounts.
302        let mut mounts = vec![
303            serde_json::json!({
304                "destination": "/proc",
305                "type": "proc",
306                "source": "proc",
307                "options": ["nosuid", "noexec", "nodev"]
308            }),
309            serde_json::json!({
310                "destination": "/dev",
311                "type": "tmpfs",
312                "source": "tmpfs",
313                "options": ["nosuid", "strictatime", "mode=755", "size=65536k"]
314            }),
315            serde_json::json!({
316                "destination": "/sys",
317                "type": "sysfs",
318                "source": "sysfs",
319                "options": ["nosuid", "noexec", "nodev", "ro"]
320            }),
321        ];
322
323        for vol in &self.config.volumes {
324            let mut opts = vec!["rbind".to_string()];
325            if vol.readonly {
326                opts.push("ro".to_string());
327            }
328            mounts.push(serde_json::json!({
329                "destination": vol.target,
330                "type": "bind",
331                "source": vol.source,
332                "options": opts,
333            }));
334        }
335
336        // Parse user / group.
337        let (uid, gid) = parse_user_spec(&self.config.user);
338
339        let mut process = serde_json::json!({
340            "terminal": false,
341            "user": {
342                "uid": uid,
343                "gid": gid,
344            },
345            "args": args,
346            "env": env,
347            "cwd": self.config.working_dir.as_deref().unwrap_or("/"),
348        });
349
350        // Add capabilities. When running as root, grant the default Docker
351        // capability set so tools like apt-get work. Non-root gets nothing.
352        let caps = if uid == 0 {
353            serde_json::json!([
354                "CAP_AUDIT_WRITE",
355                "CAP_CHOWN",
356                "CAP_DAC_OVERRIDE",
357                "CAP_FOWNER",
358                "CAP_FSETID",
359                "CAP_KILL",
360                "CAP_MKNOD",
361                "CAP_NET_BIND_SERVICE",
362                "CAP_NET_RAW",
363                "CAP_SETFCAP",
364                "CAP_SETGID",
365                "CAP_SETPCAP",
366                "CAP_SETUID",
367                "CAP_SYS_CHROOT",
368            ])
369        } else {
370            serde_json::json!([])
371        };
372        process["capabilities"] = serde_json::json!({
373            "bounding": caps,
374            "effective": caps,
375            "inheritable": caps,
376            "permitted": caps,
377            "ambient": caps,
378        });
379
380        let spec = serde_json::json!({
381            "ociVersion": "1.0.2",
382            "process": process,
383            "root": {
384                "path": "rootfs",
385                "readonly": false,
386            },
387            "mounts": mounts,
388            "linux": {
389                "namespaces": [
390                    { "type": "pid" },
391                    { "type": "ipc" },
392                    { "type": "uts" },
393                    { "type": "mount" },
394                ],
395            },
396        });
397
398        let json_bytes = serde_json::to_vec(&spec).expect("OCI spec serialization cannot fail");
399
400        prost_types::Any {
401            type_url: "types.containerd.io/opencontainers/runtime-spec/1/Spec".to_string(),
402            value: json_bytes,
403        }
404    }
405
406    /// Inject a `containerd-namespace` header into a tonic request.
407    pub(crate) fn with_namespace<T>(req: T, namespace: &str) -> tonic::Request<T> {
408        let mut request = tonic::Request::new(req);
409        request
410            .metadata_mut()
411            .insert("containerd-namespace", namespace.parse().unwrap());
412        request
413    }
414
415    /// Start a long-running service container via the containerd gRPC API.
416    ///
417    /// Used by `ContainerdServiceProvider` to provision infrastructure services.
418    /// The container runs on the host network so its ports are accessible on 127.0.0.1.
419    /// Unlike step execution, this does NOT wait for the container to exit.
420    pub async fn run_service(
421        addr: &str,
422        container_id: &str,
423        image: &str,
424        env: &std::collections::HashMap<String, String>,
425    ) -> Result<(), WfeError> {
426        let namespace = DEFAULT_NAMESPACE;
427        let channel = Self::connect(addr).await?;
428
429        // Verify image exists.
430        Self::ensure_image(&channel, image, namespace).await?;
431
432        // Build a config for host-network service container.
433        let config = ContainerdConfig {
434            image: image.to_string(),
435            command: None,
436            run: None,
437            env: env.clone(),
438            volumes: vec![],
439            working_dir: None,
440            user: "0:0".to_string(),
441            network: "host".to_string(),
442            memory: None,
443            cpu: None,
444            pull: "if-not-present".to_string(),
445            containerd_addr: addr.to_string(),
446            cli: "nerdctl".to_string(),
447            tls: Default::default(),
448            registry_auth: Default::default(),
449            timeout_ms: None,
450        };
451
452        let step = Self::new(config);
453        let oci_spec = step.build_oci_spec(env);
454
455        // Create container.
456        let mut containers_client = ContainersClient::new(channel.clone());
457        let create_req = Self::with_namespace(
458            CreateContainerRequest {
459                container: Some(Container {
460                    id: container_id.to_string(),
461                    image: image.to_string(),
462                    runtime: Some(Runtime {
463                        name: "io.containerd.runc.v2".to_string(),
464                        options: None,
465                    }),
466                    spec: Some(oci_spec),
467                    snapshotter: DEFAULT_SNAPSHOTTER.to_string(),
468                    snapshot_key: container_id.to_string(),
469                    labels: HashMap::new(),
470                    created_at: None,
471                    updated_at: None,
472                    extensions: HashMap::new(),
473                    sandbox: String::new(),
474                }),
475            },
476            namespace,
477        );
478        containers_client.create(create_req).await.map_err(|e| {
479            WfeError::StepExecution(format!("failed to create service container: {e}"))
480        })?;
481
482        // Prepare snapshot.
483        let mut snapshots_client = SnapshotsClient::new(channel.clone());
484        let mounts = {
485            let mounts_req = Self::with_namespace(
486                MountsRequest {
487                    snapshotter: DEFAULT_SNAPSHOTTER.to_string(),
488                    key: container_id.to_string(),
489                },
490                namespace,
491            );
492            match snapshots_client.mounts(mounts_req).await {
493                Ok(resp) => resp.into_inner().mounts,
494                Err(_) => {
495                    let parent = Self::resolve_image_chain_id(&channel, image, namespace).await?;
496                    let prepare_req = Self::with_namespace(
497                        PrepareSnapshotRequest {
498                            snapshotter: DEFAULT_SNAPSHOTTER.to_string(),
499                            key: container_id.to_string(),
500                            parent,
501                            labels: HashMap::new(),
502                        },
503                        namespace,
504                    );
505                    snapshots_client
506                        .prepare(prepare_req)
507                        .await
508                        .map_err(|e| {
509                            WfeError::StepExecution(format!("failed to prepare snapshot: {e}"))
510                        })?
511                        .into_inner()
512                        .mounts
513                }
514            }
515        };
516
517        // Create and start task (no stdout/stderr capture for services).
518        let mut tasks_client = TasksClient::new(channel.clone());
519        let create_task_req = Self::with_namespace(
520            CreateTaskRequest {
521                container_id: container_id.to_string(),
522                rootfs: mounts,
523                stdin: String::new(),
524                stdout: String::new(),
525                stderr: String::new(),
526                terminal: false,
527                checkpoint: None,
528                options: None,
529                runtime_path: String::new(),
530            },
531            namespace,
532        );
533        tasks_client
534            .create(create_task_req)
535            .await
536            .map_err(|e| WfeError::StepExecution(format!("failed to create service task: {e}")))?;
537
538        let start_req = Self::with_namespace(
539            StartRequest {
540                container_id: container_id.to_string(),
541                exec_id: String::new(),
542            },
543            namespace,
544        );
545        tasks_client
546            .start(start_req)
547            .await
548            .map_err(|e| WfeError::StepExecution(format!("failed to start service task: {e}")))?;
549
550        tracing::info!(container_id = %container_id, image = %image, "service container started");
551        Ok(())
552    }
553
554    /// Stop and clean up a service container via the containerd gRPC API.
555    pub async fn cleanup_service(addr: &str, container_id: &str) -> Result<(), WfeError> {
556        let channel = Self::connect(addr).await?;
557        Self::cleanup(&channel, container_id, DEFAULT_NAMESPACE).await
558    }
559
560    /// Parse `##wfe[output key=value]` lines from stdout.
561    pub fn parse_outputs(stdout: &str) -> HashMap<String, String> {
562        let mut outputs = HashMap::new();
563        for line in stdout.lines() {
564            if let Some(rest) = line.strip_prefix("##wfe[output ")
565                && let Some(rest) = rest.strip_suffix(']')
566                && let Some(eq_pos) = rest.find('=')
567            {
568                let name = rest[..eq_pos].trim().to_string();
569                let value = rest[eq_pos + 1..].to_string();
570                outputs.insert(name, value);
571            }
572        }
573        outputs
574    }
575
576    /// Build the output data JSON value from step execution results.
577    pub fn build_output_data(
578        step_name: &str,
579        stdout: &str,
580        stderr: &str,
581        exit_code: i32,
582        parsed_outputs: &HashMap<String, String>,
583    ) -> serde_json::Value {
584        let mut outputs = serde_json::Map::new();
585        for (key, value) in parsed_outputs {
586            outputs.insert(key.clone(), serde_json::Value::String(value.clone()));
587        }
588        outputs.insert(
589            format!("{step_name}.stdout"),
590            serde_json::Value::String(stdout.to_string()),
591        );
592        outputs.insert(
593            format!("{step_name}.stderr"),
594            serde_json::Value::String(stderr.to_string()),
595        );
596        outputs.insert(
597            format!("{step_name}.exit_code"),
598            serde_json::Value::Number(serde_json::Number::from(exit_code)),
599        );
600        serde_json::Value::Object(outputs)
601    }
602}
603
/// Turn a `"uid:gid"` user spec into numeric `(uid, gid)`.
///
/// The spec must contain exactly one `:`. Any other shape yields
/// `(65534, 65534)` (conventionally the `nobody` ids). When the shape is
/// right but one side is not a valid `u32`, only that side falls back to 65534.
fn parse_user_spec(user: &str) -> (u32, u32) {
    const NOBODY: u32 = 65534;

    match user.split_once(':') {
        Some((uid_part, gid_part)) if !gid_part.contains(':') => (
            uid_part.parse().unwrap_or(NOBODY),
            gid_part.parse().unwrap_or(NOBODY),
        ),
        _ => (NOBODY, NOBODY),
    }
}
615
616#[async_trait]
617impl StepBody for ContainerdStep {
618    async fn run(
619        &mut self,
620        context: &StepExecutionContext<'_>,
621    ) -> wfe_core::Result<ExecutionResult> {
622        let step_name = context.step.name.as_deref().unwrap_or("unknown");
623        let namespace = DEFAULT_NAMESPACE;
624
625        // 1. Connect to containerd.
626        let addr = &self.config.containerd_addr;
627        tracing::info!(addr = %addr, "connecting to containerd daemon");
628        let channel = Self::connect(addr).await?;
629
630        // Verify connectivity.
631        {
632            let mut version_client = VersionClient::new(channel.clone());
633            let req = Self::with_namespace((), namespace);
634            match version_client.version(req).await {
635                Ok(resp) => {
636                    let v = resp.into_inner();
637                    tracing::info!(
638                        version = %v.version,
639                        revision = %v.revision,
640                        "connected to containerd"
641                    );
642                }
643                Err(e) => {
644                    return Err(WfeError::StepExecution(format!(
645                        "containerd version check failed: {e}"
646                    )));
647                }
648            }
649        }
650
651        // 2. Ensure image exists (based on pull policy).
652        let should_check = !matches!(self.config.pull.as_str(), "never");
653        if should_check {
654            Self::ensure_image(&channel, &self.config.image, namespace).await?;
655        }
656
657        // Generate a unique container ID.
658        let container_id = format!("wfe-{}", uuid::Uuid::new_v4());
659
660        // 3. Merge environment variables.
661        let mut merged_env: HashMap<String, String> = HashMap::new();
662        if let Some(data_obj) = context.workflow.data.as_object() {
663            for (key, value) in data_obj {
664                let env_key = key.to_uppercase();
665                let env_val = match value {
666                    serde_json::Value::String(s) => s.clone(),
667                    other => other.to_string(),
668                };
669                merged_env.insert(env_key, env_val);
670            }
671        }
672        // Config env overrides workflow data.
673        for (key, value) in &self.config.env {
674            merged_env.insert(key.clone(), value.clone());
675        }
676
677        // 4. Build OCI spec.
678        let oci_spec = self.build_oci_spec(&merged_env);
679
680        // 5. Create container.
681        tracing::info!(container_id = %container_id, image = %self.config.image, "creating container");
682        let mut containers_client = ContainersClient::new(channel.clone());
683        let create_req = Self::with_namespace(
684            CreateContainerRequest {
685                container: Some(Container {
686                    id: container_id.clone(),
687                    image: self.config.image.clone(),
688                    runtime: Some(Runtime {
689                        name: "io.containerd.runc.v2".to_string(),
690                        options: None,
691                    }),
692                    spec: Some(oci_spec),
693                    snapshotter: DEFAULT_SNAPSHOTTER.to_string(),
694                    snapshot_key: container_id.clone(),
695                    labels: HashMap::new(),
696                    created_at: None,
697                    updated_at: None,
698                    extensions: HashMap::new(),
699                    sandbox: String::new(),
700                }),
701            },
702            namespace,
703        );
704
705        containers_client
706            .create(create_req)
707            .await
708            .map_err(|e| WfeError::StepExecution(format!("failed to create container: {e}")))?;
709
710        // 6. Prepare snapshot with the image's layers as parent.
711        let mut snapshots_client = SnapshotsClient::new(channel.clone());
712
713        let mounts = {
714            // First try: see if a snapshot was already prepared for this container.
715            let mounts_req = Self::with_namespace(
716                MountsRequest {
717                    snapshotter: DEFAULT_SNAPSHOTTER.to_string(),
718                    key: container_id.clone(),
719                },
720                namespace,
721            );
722
723            match snapshots_client.mounts(mounts_req).await {
724                Ok(resp) => resp.into_inner().mounts,
725                Err(_) => {
726                    // Resolve the image's chain ID to use as snapshot parent.
727                    let parent = if should_check {
728                        Self::resolve_image_chain_id(&channel, &self.config.image, namespace)
729                            .await?
730                    } else {
731                        String::new()
732                    };
733
734                    let prepare_req = Self::with_namespace(
735                        PrepareSnapshotRequest {
736                            snapshotter: DEFAULT_SNAPSHOTTER.to_string(),
737                            key: container_id.clone(),
738                            parent,
739                            labels: HashMap::new(),
740                        },
741                        namespace,
742                    );
743                    snapshots_client
744                        .prepare(prepare_req)
745                        .await
746                        .map_err(|e| {
747                            WfeError::StepExecution(format!("failed to prepare snapshot: {e}"))
748                        })?
749                        .into_inner()
750                        .mounts
751                }
752            }
753        };
754
755        // 7. Create FIFO paths for stdout/stderr capture.
756        // Use WFE_IO_DIR if set (e.g., a shared mount with a remote containerd daemon),
757        // otherwise fall back to the system temp directory.
758        let io_base = std::env::var("WFE_IO_DIR")
759            .map(std::path::PathBuf::from)
760            .unwrap_or_else(|_| std::env::temp_dir());
761        let tmp_dir = io_base.join(format!("wfe-io-{container_id}"));
762        std::fs::create_dir_all(&tmp_dir)
763            .map_err(|e| WfeError::StepExecution(format!("failed to create IO temp dir: {e}")))?;
764
765        let stdout_path = tmp_dir.join("stdout");
766        let stderr_path = tmp_dir.join("stderr");
767
768        // Create empty files for the shim to write stdout/stderr to.
769        // We use regular files instead of FIFOs because FIFOs don't work
770        // across filesystem boundaries (e.g., virtiofs mounts with Lima VMs).
771        for path in [&stdout_path, &stderr_path] {
772            let _ = std::fs::remove_file(path);
773            std::fs::File::create(path).map_err(|e| {
774                WfeError::StepExecution(format!("failed to create IO file {}: {e}", path.display()))
775            })?;
776            // Ensure the remote shim can write to it.
777            #[cfg(unix)]
778            {
779                use std::os::unix::fs::PermissionsExt;
780                std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o666)).ok();
781            }
782        }
783
784        let stdout_str = stdout_path.to_string_lossy().to_string();
785        let stderr_str = stderr_path.to_string_lossy().to_string();
786
787        // 8. Create task.
788        let mut tasks_client = TasksClient::new(channel.clone());
789
790        let create_task_req = Self::with_namespace(
791            CreateTaskRequest {
792                container_id: container_id.clone(),
793                rootfs: mounts,
794                stdin: String::new(),
795                stdout: stdout_str.clone(),
796                stderr: stderr_str.clone(),
797                terminal: false,
798                checkpoint: None,
799                options: None,
800                runtime_path: String::new(),
801            },
802            namespace,
803        );
804
805        tasks_client
806            .create(create_task_req)
807            .await
808            .map_err(|e| WfeError::StepExecution(format!("failed to create task: {e}")))?;
809
810        // Start the task.
811        let start_req = Self::with_namespace(
812            StartRequest {
813                container_id: container_id.clone(),
814                exec_id: String::new(),
815            },
816            namespace,
817        );
818
819        tasks_client
820            .start(start_req)
821            .await
822            .map_err(|e| WfeError::StepExecution(format!("failed to start task: {e}")))?;
823
824        tracing::info!(container_id = %container_id, "task started");
825
826        // 9. Wait for task completion (with optional timeout).
827        let wait_req = Self::with_namespace(
828            WaitRequest {
829                container_id: container_id.clone(),
830                exec_id: String::new(),
831            },
832            namespace,
833        );
834
835        let wait_result = if let Some(timeout_ms) = self.config.timeout_ms {
836            let duration = std::time::Duration::from_millis(timeout_ms);
837            match tokio::time::timeout(duration, tasks_client.wait(wait_req)).await {
838                Ok(result) => result,
839                Err(_) => {
840                    // Attempt cleanup before returning timeout error.
841                    let _ = Self::cleanup(&channel, &container_id, namespace).await;
842                    let _ = std::fs::remove_dir_all(&tmp_dir);
843                    return Err(WfeError::StepExecution(format!(
844                        "container execution timed out after {timeout_ms}ms"
845                    )));
846                }
847            }
848        } else {
849            tasks_client.wait(wait_req).await
850        };
851
852        let exit_status = match wait_result {
853            Ok(resp) => resp.into_inner().exit_status,
854            Err(e) => {
855                let _ = Self::cleanup(&channel, &container_id, namespace).await;
856                let _ = std::fs::remove_dir_all(&tmp_dir);
857                return Err(WfeError::StepExecution(format!(
858                    "failed waiting for task: {e}"
859                )));
860            }
861        };
862
863        // 10. Read captured output from files.
864        let stdout_content = tokio::fs::read_to_string(&stdout_path)
865            .await
866            .unwrap_or_default();
867        let stderr_content = tokio::fs::read_to_string(&stderr_path)
868            .await
869            .unwrap_or_default();
870
871        // 11. Cleanup: delete task, then container.
872        if let Err(e) = Self::cleanup(&channel, &container_id, namespace).await {
873            tracing::warn!(container_id = %container_id, error = %e, "cleanup failed");
874        }
875        let _ = std::fs::remove_dir_all(&tmp_dir);
876
877        // 12. Check exit status.
878        let exit_code = exit_status as i32;
879        if exit_code != 0 {
880            return Err(WfeError::StepExecution(format!(
881                "container exited with code {exit_code}\nstdout: {stdout_content}\nstderr: {stderr_content}"
882            )));
883        }
884
885        // 13. Parse outputs and build result.
886        let parsed = Self::parse_outputs(&stdout_content);
887        let output_data = Self::build_output_data(
888            step_name,
889            &stdout_content,
890            &stderr_content,
891            exit_code,
892            &parsed,
893        );
894
895        Ok(ExecutionResult {
896            proceed: true,
897            output_data: Some(output_data),
898            ..Default::default()
899        })
900    }
901}
902
903impl ContainerdStep {
904    /// Delete the task and container, best-effort.
905    pub(crate) async fn cleanup(
906        channel: &Channel,
907        container_id: &str,
908        namespace: &str,
909    ) -> Result<(), WfeError> {
910        let mut tasks_client = TasksClient::new(channel.clone());
911        let mut containers_client = ContainersClient::new(channel.clone());
912
913        // Delete task (ignore errors — it may already be gone).
914        let del_task_req = Self::with_namespace(
915            DeleteTaskRequest {
916                container_id: container_id.to_string(),
917            },
918            namespace,
919        );
920        let _ = tasks_client.delete(del_task_req).await;
921
922        // Delete container.
923        let del_container_req = Self::with_namespace(
924            DeleteContainerRequest {
925                id: container_id.to_string(),
926            },
927            namespace,
928        );
929        containers_client
930            .delete(del_container_req)
931            .await
932            .map_err(|e| WfeError::StepExecution(format!("failed to delete container: {e}")))?;
933
934        Ok(())
935    }
936}
937
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{TlsConfig, VolumeMountConfig};
    use pretty_assertions::assert_eq;

    /// Baseline configuration shared by the tests: runs `echo hello` in
    /// `alpine:3.18` as `65534:65534` with networking set to `none`.
    fn minimal_config() -> ContainerdConfig {
        ContainerdConfig {
            image: String::from("alpine:3.18"),
            command: None,
            run: Some(String::from("echo hello")),
            env: HashMap::new(),
            volumes: Vec::new(),
            working_dir: None,
            user: String::from("65534:65534"),
            network: String::from("none"),
            memory: None,
            cpu: None,
            pull: String::from("if-not-present"),
            containerd_addr: String::from("/run/containerd/containerd.sock"),
            cli: String::from("nerdctl"),
            tls: TlsConfig::default(),
            registry_auth: HashMap::new(),
            timeout_ms: None,
        }
    }

    /// Decodes the JSON payload of a generated OCI spec.
    fn decode_spec(payload: &[u8]) -> serde_json::Value {
        serde_json::from_slice(payload).unwrap()
    }

    /// Collects a JSON array of strings into owned `String`s.
    fn string_items(array: &serde_json::Value) -> Vec<String> {
        array
            .as_array()
            .unwrap()
            .iter()
            .map(|item| item.as_str().unwrap().to_owned())
            .collect()
    }

    /// Connects to `addr`, expecting failure, and returns the rendered error.
    async fn connect_failure(addr: &str) -> String {
        let err = ContainerdStep::connect(addr).await.unwrap_err();
        err.to_string()
    }

    // ── parse_outputs ──────────────────────────────────────────────────

    #[test]
    fn parse_outputs_empty() {
        assert!(ContainerdStep::parse_outputs("").is_empty());
    }

    #[test]
    fn parse_outputs_single() {
        let parsed =
            ContainerdStep::parse_outputs("some log line\n##wfe[output version=1.2.3]\nmore logs\n");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.get("version").unwrap(), "1.2.3");
    }

    #[test]
    fn parse_outputs_multiple() {
        let parsed =
            ContainerdStep::parse_outputs("##wfe[output foo=bar]\n##wfe[output baz=qux]\n");
        assert_eq!(parsed.len(), 2);
        for (key, expected) in [("foo", "bar"), ("baz", "qux")] {
            assert_eq!(parsed.get(key).unwrap(), expected);
        }
    }

    #[test]
    fn parse_outputs_mixed_with_regular_stdout() {
        let stdout = concat!(
            "Starting container...\n",
            "Pulling image...\n",
            "##wfe[output digest=sha256:abc123]\n",
            "Running tests...\n",
            "##wfe[output result=pass]\n",
            "Done.\n"
        );
        let parsed = ContainerdStep::parse_outputs(stdout);
        assert_eq!(parsed.len(), 2);
        for (key, expected) in [("digest", "sha256:abc123"), ("result", "pass")] {
            assert_eq!(parsed.get(key).unwrap(), expected);
        }
    }

    #[test]
    fn parse_outputs_no_wfe_lines() {
        assert!(ContainerdStep::parse_outputs("line 1\nline 2\nline 3\n").is_empty());
    }

    #[test]
    fn parse_outputs_value_with_equals_sign() {
        let parsed = ContainerdStep::parse_outputs("##wfe[output url=https://example.com?a=1&b=2]\n");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.get("url").unwrap(), "https://example.com?a=1&b=2");
    }

    #[test]
    fn parse_outputs_ignores_malformed_lines() {
        let stdout = concat!(
            "##wfe[output ]\n",
            "##wfe[output no_equals]\n",
            "##wfe[output valid=yes]\n",
            "##wfe[output_extra bad=val]\n"
        );
        let parsed = ContainerdStep::parse_outputs(stdout);
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.get("valid").unwrap(), "yes");
    }

    #[test]
    fn parse_outputs_overwrites_duplicate_keys() {
        let parsed =
            ContainerdStep::parse_outputs("##wfe[output key=first]\n##wfe[output key=second]\n");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed.get("key").unwrap(), "second");
    }

    // ── build_output_data ──────────────────────────────────────────────

    #[test]
    fn build_output_data_basic() {
        let parsed = HashMap::from([(String::from("result"), String::from("success"))]);
        let data = ContainerdStep::build_output_data("my_step", "hello world\n", "", 0, &parsed);
        let fields = data.as_object().unwrap();

        assert_eq!(fields["result"], "success");
        assert_eq!(fields["my_step.stdout"], "hello world\n");
        assert_eq!(fields["my_step.stderr"], "");
        assert_eq!(fields["my_step.exit_code"], 0);
    }

    #[test]
    fn build_output_data_no_parsed_outputs() {
        let data = ContainerdStep::build_output_data("step1", "out", "err", 1, &HashMap::new());
        let fields = data.as_object().unwrap();

        // Only the standard stdout / stderr / exit_code entries.
        assert_eq!(fields.len(), 3);
        assert_eq!(fields["step1.stdout"], "out");
        assert_eq!(fields["step1.stderr"], "err");
        assert_eq!(fields["step1.exit_code"], 1);
    }

    #[test]
    fn build_output_data_with_multiple_parsed_outputs() {
        let pairs = [("a", "1"), ("b", "2"), ("c", "3")];
        let parsed: HashMap<String, String> = pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();
        let data = ContainerdStep::build_output_data("s", "", "", 0, &parsed);
        let fields = data.as_object().unwrap();

        for (key, expected) in pairs {
            assert_eq!(fields[key], expected);
        }
        // Three parsed outputs plus the three standard entries.
        assert_eq!(fields.len(), 6);
    }

    #[test]
    fn build_output_data_negative_exit_code() {
        let data = ContainerdStep::build_output_data("s", "", "", -1, &HashMap::new());
        let fields = data.as_object().unwrap();
        assert_eq!(fields["s.exit_code"], -1);
    }

    // ── parse_user_spec ────────────────────────────────────────────────

    #[test]
    fn parse_user_spec_normal() {
        let ids = parse_user_spec("1000:1000");
        assert_eq!(ids, (1000, 1000));
    }

    #[test]
    fn parse_user_spec_root() {
        let ids = parse_user_spec("0:0");
        assert_eq!(ids, (0, 0));
    }

    #[test]
    fn parse_user_spec_default() {
        let ids = parse_user_spec("65534:65534");
        assert_eq!(ids, (65534, 65534));
    }

    #[test]
    fn parse_user_spec_invalid_falls_back() {
        let ids = parse_user_spec("abc");
        assert_eq!(ids, (65534, 65534));
    }

    // ── build_oci_spec ─────────────────────────────────────────────────

    #[test]
    fn build_oci_spec_minimal() {
        let step = ContainerdStep::new(minimal_config());
        let env = HashMap::new();
        let spec = step.build_oci_spec(&env);

        assert_eq!(
            spec.type_url,
            "types.containerd.io/opencontainers/runtime-spec/1/Spec"
        );
        assert!(!spec.value.is_empty());

        let json = decode_spec(&spec.value);
        let process = &json["process"];
        assert_eq!(json["ociVersion"], "1.0.2");
        for (idx, expected) in ["/bin/sh", "-c", "echo hello"].into_iter().enumerate() {
            assert_eq!(process["args"][idx], expected);
        }
        assert_eq!(process["user"]["uid"], 65534);
        assert_eq!(process["user"]["gid"], 65534);
        assert_eq!(process["cwd"], "/");
    }

    #[test]
    fn build_oci_spec_with_command() {
        let mut config = minimal_config();
        config.run = None;
        config.command = Some(["echo", "hello", "world"].iter().map(|s| s.to_string()).collect());

        let spec = ContainerdStep::new(config).build_oci_spec(&HashMap::new());
        let json = decode_spec(&spec.value);
        for (idx, expected) in ["echo", "hello", "world"].into_iter().enumerate() {
            assert_eq!(json["process"]["args"][idx], expected);
        }
    }

    #[test]
    fn build_oci_spec_with_env() {
        let step = ContainerdStep::new(minimal_config());
        let env = HashMap::from([
            (String::from("FOO"), String::from("bar")),
            (String::from("BAZ"), String::from("qux")),
        ]);

        let json = decode_spec(&step.build_oci_spec(&env).value);
        let entries = string_items(&json["process"]["env"]);
        for wanted in ["FOO=bar", "BAZ=qux"] {
            assert!(entries.contains(&wanted.to_string()));
        }
    }

    #[test]
    fn build_oci_spec_with_working_dir() {
        let mut config = minimal_config();
        config.working_dir = Some(String::from("/app"));

        let spec = ContainerdStep::new(config).build_oci_spec(&HashMap::new());
        assert_eq!(decode_spec(&spec.value)["process"]["cwd"], "/app");
    }

    #[test]
    fn build_oci_spec_with_user() {
        let mut config = minimal_config();
        config.user = String::from("1000:2000");

        let spec = ContainerdStep::new(config).build_oci_spec(&HashMap::new());
        let user = &decode_spec(&spec.value)["process"]["user"];
        assert_eq!(user["uid"], 1000);
        assert_eq!(user["gid"], 2000);
    }

    #[test]
    fn build_oci_spec_with_volumes() {
        let mut config = minimal_config();
        config.volumes = vec![
            VolumeMountConfig {
                source: String::from("/host/data"),
                target: String::from("/container/data"),
                readonly: false,
            },
            VolumeMountConfig {
                source: String::from("/host/config"),
                target: String::from("/etc/config"),
                readonly: true,
            },
        ];

        let spec = ContainerdStep::new(config).build_oci_spec(&HashMap::new());
        let json = decode_spec(&spec.value);
        let mounts = json["mounts"].as_array().unwrap();
        // 3 default + 2 user = 5
        assert_eq!(mounts.len(), 5);

        let binds: Vec<&serde_json::Value> = mounts
            .iter()
            .filter(|mount| mount["type"] == "bind")
            .collect();
        assert_eq!(binds.len(), 2);

        let readonly = binds
            .iter()
            .find(|mount| mount["destination"] == "/etc/config")
            .unwrap();
        let options = string_items(&readonly["options"]);
        assert!(options.contains(&String::from("ro")));
    }

    #[test]
    fn build_oci_spec_no_command_no_run() {
        let mut config = minimal_config();
        config.run = None;
        config.command = None;

        let spec = ContainerdStep::new(config).build_oci_spec(&HashMap::new());
        let json = decode_spec(&spec.value);
        assert!(json["process"]["args"].as_array().unwrap().is_empty());
    }

    // ── connect ────────────────────────────────────────────────────────

    #[tokio::test]
    async fn connect_to_missing_unix_socket_returns_error() {
        let msg = connect_failure("/tmp/nonexistent-wfe-containerd-test.sock").await;
        assert!(
            msg.contains("socket not found"),
            "expected 'socket not found' error, got: {msg}"
        );
    }

    #[tokio::test]
    async fn connect_to_missing_unix_socket_with_scheme_returns_error() {
        let msg = connect_failure("unix:///tmp/nonexistent-wfe-containerd-test.sock").await;
        assert!(
            msg.contains("socket not found"),
            "expected 'socket not found' error, got: {msg}"
        );
    }

    #[tokio::test]
    async fn connect_to_invalid_tcp_returns_error() {
        let msg = connect_failure("tcp://127.0.0.1:1").await;
        assert!(
            msg.contains("failed to connect"),
            "expected connection error, got: {msg}"
        );
    }

    // ── ContainerdStep::new ────────────────────────────────────────────

    #[test]
    fn new_creates_step_with_config() {
        let step = ContainerdStep::new(minimal_config());
        let config = &step.config;
        assert_eq!(config.image, "alpine:3.18");
        assert_eq!(config.containerd_addr, "/run/containerd/containerd.sock");
    }
}
1299
/// Integration tests that require a live containerd daemon.
///
/// When no local containerd socket can be found, each test prints a `SKIP`
/// message and returns without failing.
#[cfg(test)]
mod e2e_tests {
    use super::*;

    /// Returns the containerd address to test against, or `None` to skip.
    ///
    /// The address comes from `WFE_CONTAINERD_ADDR`. If that is unset, it
    /// defaults to the Lima VM socket at
    /// `unix://$HOME/.lima/wfe-test/sock/containerd.sock`, with `HOME` falling
    /// back to `/root`.
    ///
    /// Unix-socket addresses (bare `/path` or `unix:///path`, the same forms
    /// `ContainerdStep::connect` treats as sockets) are returned only if the
    /// socket file exists. Any other address (e.g. an explicitly configured
    /// TCP endpoint) cannot be probed on disk, so it is returned as-is.
    fn containerd_addr() -> Option<String> {
        let addr = std::env::var("WFE_CONTAINERD_ADDR").unwrap_or_else(|_| {
            format!(
                "unix://{}/.lima/wfe-test/sock/containerd.sock",
                std::env::var("HOME").unwrap_or_else(|_| "/root".to_string())
            )
        });

        let is_unix_socket = addr.starts_with('/') || addr.starts_with("unix://");
        if !is_unix_socket {
            // Nothing to check on the filesystem; let `connect` decide.
            return Some(addr);
        }

        let socket_path = addr.strip_prefix("unix://").unwrap_or(addr.as_str());
        if Path::new(socket_path).exists() {
            Some(addr)
        } else {
            None
        }
    }

    /// Queries the daemon's version service and checks that the version and
    /// revision strings are populated.
    #[tokio::test]
    async fn e2e_version_check() {
        let Some(addr) = containerd_addr() else {
            eprintln!("SKIP: containerd socket not available");
            return;
        };

        let channel = ContainerdStep::connect(&addr).await.unwrap();
        let mut client = VersionClient::new(channel);

        let req = ContainerdStep::with_namespace((), DEFAULT_NAMESPACE);
        let resp = client.version(req).await.unwrap();
        let version = resp.into_inner();

        assert!(!version.version.is_empty(), "version should not be empty");
        assert!(!version.revision.is_empty(), "revision should not be empty");
        eprintln!(
            "containerd version={} revision={}",
            version.version, version.revision
        );
    }
}