containerd_shimkit/sandbox/shim/
local.rs

1use std::collections::HashMap;
2use std::fs::create_dir_all;
3use std::ops::Not;
4use std::path::Path;
5use std::sync::Arc;
6
7use anyhow::ensure;
8use containerd_shim::api::{
9    ConnectRequest, ConnectResponse, CreateTaskRequest, CreateTaskResponse, DeleteRequest, Empty,
10    KillRequest, ShutdownRequest, StartRequest, StartResponse, StateRequest, StateResponse,
11    StatsRequest, StatsResponse, WaitRequest, WaitResponse,
12};
13use containerd_shim::error::Error as ShimError;
14use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, TaskIO, TaskStart};
15use containerd_shim::protos::shim::shim_ttrpc::Task;
16use containerd_shim::protos::types::task::Status;
17use containerd_shim::util::IntoOption;
18use containerd_shim::{DeleteResponse, TtrpcContext, TtrpcResult};
19use futures::FutureExt as _;
20use log::debug;
21use oci_spec::runtime::Spec;
22use prost::Message;
23use protobuf::well_known_types::any::Any;
24use serde::{Deserialize, Serialize};
25use tokio::sync::RwLock;
26#[cfg(feature = "opentelemetry")]
27use tracing_opentelemetry::OpenTelemetrySpanExt as _;
28
29#[cfg(feature = "opentelemetry")]
30use super::otel::extract_context;
31use crate::sandbox::async_utils::AmbientRuntime as _;
32use crate::sandbox::instance::{Instance, InstanceConfig};
33use crate::sandbox::shim::events::{EventSender, RemoteEventSender, ToTimestamp};
34use crate::sandbox::shim::instance_data::InstanceData;
35use crate::sandbox::sync::WaitableCell;
36use crate::sandbox::{Error, Result, oci};
37use crate::sys::metrics::get_metrics;
38
39#[cfg(test)]
40mod tests;
41
/// containerd runtime options
///
/// `Config::get_from_options` decodes this message from the `value` bytes of the
/// `Any` attached to a create request, after checking that the `Any` carries the
/// type URL `runtimeoptions.v1.Options`.
///
/// NOTE(review): the fields carry no explicit `tag`, so their declaration order
/// presumably determines the protobuf field numbers — confirm against the prost
/// derive documentation before reordering or inserting fields.
#[derive(Message, Clone, PartialEq)]
struct Options {
    /// Type URL carried inside the message; not read by the code visible in this file.
    #[prost(string)]
    type_url: String,
    /// Not read by the code visible in this file; presumably the path of a
    /// configuration file — TODO confirm against the containerd definition.
    #[prost(string)]
    config_path: String,
    /// Configuration text that `Config::get_from_options` parses as TOML.
    #[prost(string)]
    config_body: String,
}
52
53/// This is generated by decoding the `options` field of a `CreateTaskRequest` to get an `Options` struct,
54/// interpreting the `config_body` field as TOML,
55/// and deserializing it.
56#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Debug)]
57pub struct Config {
58    /// Enables systemd cgroup.
59    #[serde(alias = "SystemdCgroup")]
60    pub systemd_cgroup: bool,
61}
62
63impl Config {
64    fn get_from_options(options: Option<&Any>) -> anyhow::Result<Self> {
65        let Some(opts) = options else {
66            return Ok(Default::default());
67        };
68
69        ensure!(
70            opts.type_url == "runtimeoptions.v1.Options",
71            "Invalid options type {}",
72            opts.type_url
73        );
74
75        let opts = Options::decode(opts.value.as_slice())?;
76
77        let config = toml::from_str(opts.config_body.as_str())
78            .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?;
79
80        Ok(config)
81    }
82}
83
/// Instances known to a `Local` service, keyed by the id string they were
/// created with and guarded by an async read/write lock. Values are
/// reference-counted so a lookup can hand out a handle without keeping the
/// lock held.
type LocalInstances<T> = RwLock<HashMap<String, Arc<InstanceData<T>>>>;
85
/// Local implements the Task service for a containerd shim.
/// It defers all task operations to the `Instance` implementation.
///
/// `T` is the `Instance` implementation the service manages and `E` is the
/// channel used to publish task events (`RemoteEventSender` unless overridden).
pub struct Local<T: Instance + Send + Sync, E: EventSender = RemoteEventSender> {
    /// Instances created through this service and not yet deleted.
    pub(super) instances: LocalInstances<T>,
    /// Sender used to publish task create/start/exit/delete events.
    events: E,
    /// Cell that `shutdown` sets once no instances remain.
    exit: WaitableCell<()>,
    /// Namespace copied into the `InstanceConfig` of every created instance.
    namespace: String,
    /// containerd address copied into the `InstanceConfig` of every created instance.
    containerd_address: String,
}
95
96impl<T: Instance + Send + Sync, E: EventSender> Local<T, E> {
97    /// Creates a new local task service.
98    #[cfg_attr(
99        feature = "tracing",
100        tracing::instrument(skip(events, exit), level = "Debug")
101    )]
102    pub fn new(
103        events: E,
104        exit: WaitableCell<()>,
105        namespace: impl AsRef<str> + std::fmt::Debug,
106        containerd_address: impl AsRef<str> + std::fmt::Debug,
107    ) -> Self {
108        let instances = RwLock::default();
109        let namespace = namespace.as_ref().to_string();
110        let containerd_address = containerd_address.as_ref().to_string();
111        Self {
112            instances,
113            events,
114            exit,
115            namespace,
116            containerd_address,
117        }
118    }
119
120    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
121    pub(super) async fn get_instance(&self, id: &str) -> Result<Arc<InstanceData<T>>> {
122        let instance = self.instances.read().await.get(id).cloned();
123        instance.ok_or_else(|| Error::NotFound(id.to_string()))
124    }
125
126    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
127    async fn has_instance(&self, id: &str) -> bool {
128        self.instances.read().await.contains_key(id)
129    }
130
131    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
132    async fn is_empty(&self) -> bool {
133        self.instances.read().await.is_empty()
134    }
135}
136
// These are the same functions as in Task, but without the TtrpcContext, which is useful for testing
138impl<T: Instance + Send + Sync, E: EventSender> Local<T, E> {
139    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
140    async fn task_create(&self, req: CreateTaskRequest) -> Result<CreateTaskResponse> {
141        let config = Config::get_from_options(req.options.as_ref())
142            .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?;
143
144        if !req.checkpoint().is_empty() || !req.parent_checkpoint().is_empty() {
145            return Err(ShimError::Unimplemented("checkpoint is not supported".to_string()).into());
146        }
147
148        if req.terminal {
149            return Err(Error::InvalidArgument(
150                "terminal is not supported".to_string(),
151            ));
152        }
153
154        if self.has_instance(&req.id).await {
155            return Err(Error::AlreadyExists(req.id));
156        }
157
158        let mut spec = Spec::load(Path::new(&req.bundle).join("config.json"))
159            .map_err(|err| Error::InvalidArgument(format!("could not load runtime spec: {err}")))?;
160
161        spec.canonicalize_rootfs(req.bundle()).map_err(|err| {
162            ShimError::InvalidArgument(format!("could not canonicalize rootfs: {}", err))
163        })?;
164
165        let rootfs = spec
166            .root()
167            .as_ref()
168            .ok_or_else(|| Error::InvalidArgument("rootfs is not set in runtime spec".to_string()))?
169            .path();
170
171        let _ = create_dir_all(rootfs);
172        let rootfs_mounts = req.rootfs().to_vec();
173        if !rootfs_mounts.is_empty() {
174            for m in rootfs_mounts {
175                let _mount_type = m.type_().none_if(|&x| x.is_empty());
176                let _source = m.source.as_str().none_if(|&x| x.is_empty());
177
178                #[cfg(unix)]
179                containerd_shim::mount::mount_rootfs(
180                    _mount_type,
181                    _source,
182                    &m.options.to_vec(),
183                    rootfs,
184                )?;
185            }
186        }
187
188        let cfg = InstanceConfig {
189            namespace: self.namespace.clone(),
190            containerd_address: self.containerd_address.clone(),
191            bundle: req.bundle.as_str().into(),
192            stdout: req.stdout.as_str().into(),
193            stderr: req.stderr.as_str().into(),
194            stdin: req.stdin.as_str().into(),
195            config,
196        };
197
198        // Check if this is a cri container
199        let instance = InstanceData::new(req.id(), cfg).await?;
200
201        self.instances
202            .write()
203            .await
204            .insert(req.id().to_string(), Arc::new(instance));
205
206        self.events.send(TaskCreate {
207            container_id: req.id,
208            bundle: req.bundle,
209            rootfs: req.rootfs,
210            io: Some(TaskIO {
211                stdin: req.stdin,
212                stdout: req.stdout,
213                stderr: req.stderr,
214                ..Default::default()
215            })
216            .into(),
217            ..Default::default()
218        });
219
220        debug!("create done");
221
222        // Per the spec, the prestart hook must be called as part of the create operation
223        debug!("call prehook before the start");
224        oci::setup_prestart_hooks(spec.hooks())?;
225
226        Ok(CreateTaskResponse {
227            pid: std::process::id(),
228            ..Default::default()
229        })
230    }
231
232    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
233    async fn task_start(&self, req: StartRequest) -> Result<StartResponse> {
234        if req.exec_id().is_empty().not() {
235            return Err(ShimError::Unimplemented("exec is not supported".to_string()).into());
236        }
237
238        let i = self.get_instance(req.id()).await?;
239        let pid = i.start().await?;
240
241        self.events.send(TaskStart {
242            container_id: req.id().into(),
243            pid,
244            ..Default::default()
245        });
246
247        let events = self.events.clone();
248
249        let id = req.id().to_string();
250
251        async move {
252            let (exit_code, timestamp) = i.wait().await;
253            events.send(TaskExit {
254                container_id: id.clone(),
255                exit_status: exit_code,
256                exited_at: Some(timestamp.to_timestamp()).into(),
257                pid,
258                id,
259                ..Default::default()
260            });
261        }
262        .spawn();
263
264        debug!("started: {:?}", req);
265
266        Ok(StartResponse {
267            pid,
268            ..Default::default()
269        })
270    }
271
272    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
273    async fn task_kill(&self, req: KillRequest) -> Result<Empty> {
274        if !req.exec_id().is_empty() {
275            return Err(Error::InvalidArgument("exec is not supported".to_string()));
276        }
277        self.get_instance(req.id())
278            .await?
279            .kill(req.signal())
280            .await?;
281        Ok(Empty::new())
282    }
283
284    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
285    async fn task_delete(&self, req: DeleteRequest) -> Result<DeleteResponse> {
286        if !req.exec_id().is_empty() {
287            return Err(Error::InvalidArgument("exec is not supported".to_string()));
288        }
289
290        let i = self.get_instance(req.id()).await?;
291
292        i.delete().await?;
293
294        let pid = i.pid().unwrap_or_default();
295        let (exit_code, timestamp) = i.wait().now_or_never().unzip();
296        let timestamp = timestamp.map(ToTimestamp::to_timestamp);
297
298        self.instances.write().await.remove(req.id());
299
300        self.events.send(TaskDelete {
301            container_id: req.id().into(),
302            pid,
303            exit_status: exit_code.unwrap_or_default(),
304            exited_at: timestamp.clone().into(),
305            ..Default::default()
306        });
307
308        Ok(DeleteResponse {
309            pid,
310            exit_status: exit_code.unwrap_or_default(),
311            exited_at: timestamp.into(),
312            ..Default::default()
313        })
314    }
315
316    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
317    async fn task_wait(&self, req: WaitRequest) -> Result<WaitResponse> {
318        if !req.exec_id().is_empty() {
319            return Err(Error::InvalidArgument("exec is not supported".to_string()));
320        }
321
322        let i = self.get_instance(req.id()).await?;
323        let (exit_code, timestamp) = i.wait().await;
324
325        debug!("wait finishes");
326        Ok(WaitResponse {
327            exit_status: exit_code,
328            exited_at: Some(timestamp.to_timestamp()).into(),
329            ..Default::default()
330        })
331    }
332
333    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
334    async fn task_state(&self, req: StateRequest) -> Result<StateResponse> {
335        if !req.exec_id().is_empty() {
336            return Err(Error::InvalidArgument("exec is not supported".to_string()));
337        }
338
339        let i = self.get_instance(req.id()).await?;
340        let pid = i.pid();
341        let (exit_code, timestamp) = i.wait().now_or_never().unzip();
342        let timestamp = timestamp.map(ToTimestamp::to_timestamp);
343
344        let status = if pid.is_none() {
345            Status::CREATED
346        } else if exit_code.is_none() {
347            Status::RUNNING
348        } else {
349            Status::STOPPED
350        };
351
352        Ok(StateResponse {
353            bundle: i.config.bundle.to_string_lossy().to_string(),
354            stdin: i.config.stdin.to_string_lossy().to_string(),
355            stdout: i.config.stdout.to_string_lossy().to_string(),
356            stderr: i.config.stderr.to_string_lossy().to_string(),
357            pid: pid.unwrap_or_default(),
358            exit_status: exit_code.unwrap_or_default(),
359            exited_at: timestamp.into(),
360            status: status.into(),
361            ..Default::default()
362        })
363    }
364
365    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
366    async fn task_stats(&self, req: StatsRequest) -> Result<StatsResponse> {
367        let i = self.get_instance(req.id()).await?;
368        let pid = i
369            .pid()
370            .ok_or_else(|| Error::InvalidArgument("task is not running".to_string()))?;
371
372        let metrics = get_metrics(pid)?;
373
374        Ok(StatsResponse {
375            stats: Some(metrics).into(),
376            ..Default::default()
377        })
378    }
379}
380
381impl<T: Instance + Sync + Send, E: EventSender> Task for Local<T, E> {
382    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
383    fn create(
384        &self,
385        _ctx: &TtrpcContext,
386        req: CreateTaskRequest,
387    ) -> TtrpcResult<CreateTaskResponse> {
388        debug!("create: {:?}", req);
389
390        #[cfg(feature = "opentelemetry")]
391        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
392
393        Ok(self.task_create(req).block_on()?)
394    }
395
396    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
397    fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult<StartResponse> {
398        debug!("start: {:?}", req);
399
400        #[cfg(feature = "opentelemetry")]
401        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
402
403        Ok(self.task_start(req).block_on()?)
404    }
405
406    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
407    fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult<Empty> {
408        debug!("kill: {:?}", req);
409
410        #[cfg(feature = "opentelemetry")]
411        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
412
413        Ok(self.task_kill(req).block_on()?)
414    }
415
416    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
417    fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult<DeleteResponse> {
418        debug!("delete: {:?}", req);
419
420        #[cfg(feature = "opentelemetry")]
421        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
422
423        Ok(self.task_delete(req).block_on()?)
424    }
425
426    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
427    fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult<WaitResponse> {
428        debug!("wait: {:?}", req);
429
430        #[cfg(feature = "opentelemetry")]
431        let span_exporter = {
432            use tracing::{Level, Span, span};
433            let parent_span = Span::current();
434            parent_span.set_parent(extract_context(&_ctx.metadata));
435
436            // This future never completes as it runs an infinite loop.
437            // It will stop executing when dropped.
438            // We need to keep this future's lifetime tied to this
439            // method's lifetime.
440            // This means we shouldn't tokio::spawn it, but rather
441            // tokio::select! it inside of this async method.
442            async move {
443                loop {
444                    let current_span =
445                        span!(parent: &parent_span, Level::INFO, "task wait 60s interval");
446                    let _enter = current_span.enter();
447                    tokio::time::sleep(std::time::Duration::from_secs(60)).await;
448                }
449            }
450        };
451
452        #[cfg(not(feature = "opentelemetry"))]
453        let span_exporter = std::future::pending::<()>();
454
455        let res = async {
456            tokio::select! {
457                _ = span_exporter => unreachable!(),
458                res = self.task_wait(req) => res,
459            }
460        }
461        .block_on()?;
462
463        Ok(res)
464    }
465
466    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
467    fn connect(&self, _ctx: &TtrpcContext, req: ConnectRequest) -> TtrpcResult<ConnectResponse> {
468        debug!("connect: {:?}", req);
469
470        #[cfg(feature = "opentelemetry")]
471        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
472
473        let i = self.get_instance(req.id()).block_on()?;
474        let shim_pid = std::process::id();
475        let task_pid = i.pid().unwrap_or_default();
476        Ok(ConnectResponse {
477            shim_pid,
478            task_pid,
479            ..Default::default()
480        })
481    }
482
483    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
484    fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult<StateResponse> {
485        debug!("state: {:?}", req);
486
487        #[cfg(feature = "opentelemetry")]
488        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
489
490        Ok(self.task_state(req).block_on()?)
491    }
492
493    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
494    fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult<Empty> {
495        debug!("shutdown");
496
497        #[cfg(feature = "opentelemetry")]
498        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
499
500        if self.is_empty().block_on() {
501            let _ = self.exit.set(());
502        }
503        Ok(Empty::new())
504    }
505
506    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
507    fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult<StatsResponse> {
508        debug!("stats: {:?}", req);
509
510        #[cfg(feature = "opentelemetry")]
511        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
512
513        Ok(self.task_stats(req).block_on()?)
514    }
515}