1use std::collections::HashMap;
2use std::fs::create_dir_all;
3use std::ops::Not;
4use std::path::Path;
5use std::sync::Arc;
6
7use anyhow::ensure;
8use containerd_shim::api::{
9 ConnectRequest, ConnectResponse, CreateTaskRequest, CreateTaskResponse, DeleteRequest, Empty,
10 KillRequest, ShutdownRequest, StartRequest, StartResponse, StateRequest, StateResponse,
11 StatsRequest, StatsResponse, WaitRequest, WaitResponse,
12};
13use containerd_shim::error::Error as ShimError;
14use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, TaskIO, TaskStart};
15use containerd_shim::protos::shim::shim_ttrpc::Task;
16use containerd_shim::protos::types::task::Status;
17use containerd_shim::util::IntoOption;
18use containerd_shim::{DeleteResponse, TtrpcContext, TtrpcResult};
19use futures::FutureExt as _;
20use log::debug;
21use oci_spec::runtime::Spec;
22use prost::Message;
23use protobuf::well_known_types::any::Any;
24use serde::{Deserialize, Serialize};
25use tokio::sync::RwLock;
26#[cfg(feature = "opentelemetry")]
27use tracing_opentelemetry::OpenTelemetrySpanExt as _;
28
29#[cfg(feature = "opentelemetry")]
30use super::otel::extract_context;
31use crate::sandbox::async_utils::AmbientRuntime as _;
32use crate::sandbox::instance::{Instance, InstanceConfig};
33use crate::sandbox::shim::events::{EventSender, RemoteEventSender, ToTimestamp};
34use crate::sandbox::shim::instance_data::InstanceData;
35use crate::sandbox::sync::WaitableCell;
36use crate::sandbox::{Error, Result, oci};
37use crate::sys::metrics::get_metrics;
38
39#[cfg(test)]
40mod tests;
41
42#[derive(Message, Clone, PartialEq)]
44struct Options {
45 #[prost(string)]
46 type_url: String,
47 #[prost(string)]
48 config_path: String,
49 #[prost(string)]
50 config_body: String,
51}
52
53#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Debug)]
57pub struct Config {
58 #[serde(alias = "SystemdCgroup")]
60 pub systemd_cgroup: bool,
61}
62
63impl Config {
64 fn get_from_options(options: Option<&Any>) -> anyhow::Result<Self> {
65 let Some(opts) = options else {
66 return Ok(Default::default());
67 };
68
69 ensure!(
70 opts.type_url == "runtimeoptions.v1.Options",
71 "Invalid options type {}",
72 opts.type_url
73 );
74
75 let opts = Options::decode(opts.value.as_slice())?;
76
77 let config = toml::from_str(opts.config_body.as_str())
78 .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?;
79
80 Ok(config)
81 }
82}
83
/// Async read-write-locked map from container id to its shared instance state.
type LocalInstances<T> = RwLock<HashMap<String, Arc<InstanceData<T>>>>;
85
86pub struct Local<T: Instance + Send + Sync, E: EventSender = RemoteEventSender> {
89 pub(super) instances: LocalInstances<T>,
90 events: E,
91 exit: WaitableCell<()>,
92 namespace: String,
93 containerd_address: String,
94}
95
96impl<T: Instance + Send + Sync, E: EventSender> Local<T, E> {
97 #[cfg_attr(
99 feature = "tracing",
100 tracing::instrument(skip(events, exit), level = "Debug")
101 )]
102 pub fn new(
103 events: E,
104 exit: WaitableCell<()>,
105 namespace: impl AsRef<str> + std::fmt::Debug,
106 containerd_address: impl AsRef<str> + std::fmt::Debug,
107 ) -> Self {
108 let instances = RwLock::default();
109 let namespace = namespace.as_ref().to_string();
110 let containerd_address = containerd_address.as_ref().to_string();
111 Self {
112 instances,
113 events,
114 exit,
115 namespace,
116 containerd_address,
117 }
118 }
119
120 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
121 pub(super) async fn get_instance(&self, id: &str) -> Result<Arc<InstanceData<T>>> {
122 let instance = self.instances.read().await.get(id).cloned();
123 instance.ok_or_else(|| Error::NotFound(id.to_string()))
124 }
125
126 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
127 async fn has_instance(&self, id: &str) -> bool {
128 self.instances.read().await.contains_key(id)
129 }
130
131 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
132 async fn is_empty(&self) -> bool {
133 self.instances.read().await.is_empty()
134 }
135}
136
137impl<T: Instance + Send + Sync, E: EventSender> Local<T, E> {
139 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
140 async fn task_create(&self, req: CreateTaskRequest) -> Result<CreateTaskResponse> {
141 let config = Config::get_from_options(req.options.as_ref())
142 .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?;
143
144 if !req.checkpoint().is_empty() || !req.parent_checkpoint().is_empty() {
145 return Err(ShimError::Unimplemented("checkpoint is not supported".to_string()).into());
146 }
147
148 if req.terminal {
149 return Err(Error::InvalidArgument(
150 "terminal is not supported".to_string(),
151 ));
152 }
153
154 if self.has_instance(&req.id).await {
155 return Err(Error::AlreadyExists(req.id));
156 }
157
158 let mut spec = Spec::load(Path::new(&req.bundle).join("config.json"))
159 .map_err(|err| Error::InvalidArgument(format!("could not load runtime spec: {err}")))?;
160
161 spec.canonicalize_rootfs(req.bundle()).map_err(|err| {
162 ShimError::InvalidArgument(format!("could not canonicalize rootfs: {}", err))
163 })?;
164
165 let rootfs = spec
166 .root()
167 .as_ref()
168 .ok_or_else(|| Error::InvalidArgument("rootfs is not set in runtime spec".to_string()))?
169 .path();
170
171 let _ = create_dir_all(rootfs);
172 let rootfs_mounts = req.rootfs().to_vec();
173 if !rootfs_mounts.is_empty() {
174 for m in rootfs_mounts {
175 let _mount_type = m.type_().none_if(|&x| x.is_empty());
176 let _source = m.source.as_str().none_if(|&x| x.is_empty());
177
178 #[cfg(unix)]
179 containerd_shim::mount::mount_rootfs(
180 _mount_type,
181 _source,
182 &m.options.to_vec(),
183 rootfs,
184 )?;
185 }
186 }
187
188 let cfg = InstanceConfig {
189 namespace: self.namespace.clone(),
190 containerd_address: self.containerd_address.clone(),
191 bundle: req.bundle.as_str().into(),
192 stdout: req.stdout.as_str().into(),
193 stderr: req.stderr.as_str().into(),
194 stdin: req.stdin.as_str().into(),
195 config,
196 };
197
198 let instance = InstanceData::new(req.id(), cfg).await?;
200
201 self.instances
202 .write()
203 .await
204 .insert(req.id().to_string(), Arc::new(instance));
205
206 self.events.send(TaskCreate {
207 container_id: req.id,
208 bundle: req.bundle,
209 rootfs: req.rootfs,
210 io: Some(TaskIO {
211 stdin: req.stdin,
212 stdout: req.stdout,
213 stderr: req.stderr,
214 ..Default::default()
215 })
216 .into(),
217 ..Default::default()
218 });
219
220 debug!("create done");
221
222 debug!("call prehook before the start");
224 oci::setup_prestart_hooks(spec.hooks())?;
225
226 Ok(CreateTaskResponse {
227 pid: std::process::id(),
228 ..Default::default()
229 })
230 }
231
232 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
233 async fn task_start(&self, req: StartRequest) -> Result<StartResponse> {
234 if req.exec_id().is_empty().not() {
235 return Err(ShimError::Unimplemented("exec is not supported".to_string()).into());
236 }
237
238 let i = self.get_instance(req.id()).await?;
239 let pid = i.start().await?;
240
241 self.events.send(TaskStart {
242 container_id: req.id().into(),
243 pid,
244 ..Default::default()
245 });
246
247 let events = self.events.clone();
248
249 let id = req.id().to_string();
250
251 async move {
252 let (exit_code, timestamp) = i.wait().await;
253 events.send(TaskExit {
254 container_id: id.clone(),
255 exit_status: exit_code,
256 exited_at: Some(timestamp.to_timestamp()).into(),
257 pid,
258 id,
259 ..Default::default()
260 });
261 }
262 .spawn();
263
264 debug!("started: {:?}", req);
265
266 Ok(StartResponse {
267 pid,
268 ..Default::default()
269 })
270 }
271
272 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
273 async fn task_kill(&self, req: KillRequest) -> Result<Empty> {
274 if !req.exec_id().is_empty() {
275 return Err(Error::InvalidArgument("exec is not supported".to_string()));
276 }
277 self.get_instance(req.id())
278 .await?
279 .kill(req.signal())
280 .await?;
281 Ok(Empty::new())
282 }
283
284 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
285 async fn task_delete(&self, req: DeleteRequest) -> Result<DeleteResponse> {
286 if !req.exec_id().is_empty() {
287 return Err(Error::InvalidArgument("exec is not supported".to_string()));
288 }
289
290 let i = self.get_instance(req.id()).await?;
291
292 i.delete().await?;
293
294 let pid = i.pid().unwrap_or_default();
295 let (exit_code, timestamp) = i.wait().now_or_never().unzip();
296 let timestamp = timestamp.map(ToTimestamp::to_timestamp);
297
298 self.instances.write().await.remove(req.id());
299
300 self.events.send(TaskDelete {
301 container_id: req.id().into(),
302 pid,
303 exit_status: exit_code.unwrap_or_default(),
304 exited_at: timestamp.clone().into(),
305 ..Default::default()
306 });
307
308 Ok(DeleteResponse {
309 pid,
310 exit_status: exit_code.unwrap_or_default(),
311 exited_at: timestamp.into(),
312 ..Default::default()
313 })
314 }
315
316 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
317 async fn task_wait(&self, req: WaitRequest) -> Result<WaitResponse> {
318 if !req.exec_id().is_empty() {
319 return Err(Error::InvalidArgument("exec is not supported".to_string()));
320 }
321
322 let i = self.get_instance(req.id()).await?;
323 let (exit_code, timestamp) = i.wait().await;
324
325 debug!("wait finishes");
326 Ok(WaitResponse {
327 exit_status: exit_code,
328 exited_at: Some(timestamp.to_timestamp()).into(),
329 ..Default::default()
330 })
331 }
332
333 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
334 async fn task_state(&self, req: StateRequest) -> Result<StateResponse> {
335 if !req.exec_id().is_empty() {
336 return Err(Error::InvalidArgument("exec is not supported".to_string()));
337 }
338
339 let i = self.get_instance(req.id()).await?;
340 let pid = i.pid();
341 let (exit_code, timestamp) = i.wait().now_or_never().unzip();
342 let timestamp = timestamp.map(ToTimestamp::to_timestamp);
343
344 let status = if pid.is_none() {
345 Status::CREATED
346 } else if exit_code.is_none() {
347 Status::RUNNING
348 } else {
349 Status::STOPPED
350 };
351
352 Ok(StateResponse {
353 bundle: i.config.bundle.to_string_lossy().to_string(),
354 stdin: i.config.stdin.to_string_lossy().to_string(),
355 stdout: i.config.stdout.to_string_lossy().to_string(),
356 stderr: i.config.stderr.to_string_lossy().to_string(),
357 pid: pid.unwrap_or_default(),
358 exit_status: exit_code.unwrap_or_default(),
359 exited_at: timestamp.into(),
360 status: status.into(),
361 ..Default::default()
362 })
363 }
364
365 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
366 async fn task_stats(&self, req: StatsRequest) -> Result<StatsResponse> {
367 let i = self.get_instance(req.id()).await?;
368 let pid = i
369 .pid()
370 .ok_or_else(|| Error::InvalidArgument("task is not running".to_string()))?;
371
372 let metrics = get_metrics(pid)?;
373
374 Ok(StatsResponse {
375 stats: Some(metrics).into(),
376 ..Default::default()
377 })
378 }
379}
380
381impl<T: Instance + Sync + Send, E: EventSender> Task for Local<T, E> {
382 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
383 fn create(
384 &self,
385 _ctx: &TtrpcContext,
386 req: CreateTaskRequest,
387 ) -> TtrpcResult<CreateTaskResponse> {
388 debug!("create: {:?}", req);
389
390 #[cfg(feature = "opentelemetry")]
391 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
392
393 Ok(self.task_create(req).block_on()?)
394 }
395
396 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
397 fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult<StartResponse> {
398 debug!("start: {:?}", req);
399
400 #[cfg(feature = "opentelemetry")]
401 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
402
403 Ok(self.task_start(req).block_on()?)
404 }
405
406 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
407 fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult<Empty> {
408 debug!("kill: {:?}", req);
409
410 #[cfg(feature = "opentelemetry")]
411 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
412
413 Ok(self.task_kill(req).block_on()?)
414 }
415
416 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
417 fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult<DeleteResponse> {
418 debug!("delete: {:?}", req);
419
420 #[cfg(feature = "opentelemetry")]
421 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
422
423 Ok(self.task_delete(req).block_on()?)
424 }
425
426 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
427 fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult<WaitResponse> {
428 debug!("wait: {:?}", req);
429
430 #[cfg(feature = "opentelemetry")]
431 let span_exporter = {
432 use tracing::{Level, Span, span};
433 let parent_span = Span::current();
434 parent_span.set_parent(extract_context(&_ctx.metadata));
435
436 async move {
443 loop {
444 let current_span =
445 span!(parent: &parent_span, Level::INFO, "task wait 60s interval");
446 let _enter = current_span.enter();
447 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
448 }
449 }
450 };
451
452 #[cfg(not(feature = "opentelemetry"))]
453 let span_exporter = std::future::pending::<()>();
454
455 let res = async {
456 tokio::select! {
457 _ = span_exporter => unreachable!(),
458 res = self.task_wait(req) => res,
459 }
460 }
461 .block_on()?;
462
463 Ok(res)
464 }
465
466 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
467 fn connect(&self, _ctx: &TtrpcContext, req: ConnectRequest) -> TtrpcResult<ConnectResponse> {
468 debug!("connect: {:?}", req);
469
470 #[cfg(feature = "opentelemetry")]
471 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
472
473 let i = self.get_instance(req.id()).block_on()?;
474 let shim_pid = std::process::id();
475 let task_pid = i.pid().unwrap_or_default();
476 Ok(ConnectResponse {
477 shim_pid,
478 task_pid,
479 ..Default::default()
480 })
481 }
482
483 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
484 fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult<StateResponse> {
485 debug!("state: {:?}", req);
486
487 #[cfg(feature = "opentelemetry")]
488 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
489
490 Ok(self.task_state(req).block_on()?)
491 }
492
493 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
494 fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult<Empty> {
495 debug!("shutdown");
496
497 #[cfg(feature = "opentelemetry")]
498 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
499
500 if self.is_empty().block_on() {
501 let _ = self.exit.set(());
502 }
503 Ok(Empty::new())
504 }
505
506 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
507 fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult<StatsResponse> {
508 debug!("stats: {:?}", req);
509
510 #[cfg(feature = "opentelemetry")]
511 tracing::Span::current().set_parent(extract_context(&_ctx.metadata));
512
513 Ok(self.task_stats(req).block_on()?)
514 }
515}