use std::{
env,
io::Read,
os::unix::{fs::FileTypeExt, net::UnixListener},
path::Path,
process::{self, Command as StdCommand, Stdio},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{ready, Poll},
};
use async_trait::async_trait;
use containerd_shim_protos::{
api::DeleteResponse,
protobuf::{well_known_types::any::Any, Message, MessageField},
shim::oci::Options,
shim_async::{create_task, Client, Task},
ttrpc::r#async::Server,
types::introspection::{self, RuntimeInfo},
};
use futures::stream::{poll_fn, BoxStream, SelectAll, StreamExt};
use libc::{SIGCHLD, SIGINT, SIGPIPE, SIGTERM};
use log::{debug, error, info, warn};
use nix::{
errno::Errno,
sys::{
signal::Signal,
wait::{self, WaitPidFlag, WaitStatus},
},
unistd::Pid,
};
use oci_spec::runtime::Features;
use tokio::{io::AsyncWriteExt, process::Command, sync::Notify};
use which::which;
/// Name of the OCI runtime binary that `run_info` queries for features when
/// the runtime `Options` read from stdin do not specify a `binary_name`.
const DEFAULT_BINARY_NAME: &str = "runc";
use crate::{
args,
asynchronous::{monitor::monitor_notify_by_pid, publisher::RemotePublisher},
error::{Error, Result},
logger, parse_sockaddr, reap, socket_address,
util::{asyncify, read_file_to_str, write_str_to_file},
Config, Flags, StartOpts, TTRPC_ADDRESS,
};
pub mod monitor;
pub mod publisher;
pub mod util;
/// Runtime-specific behaviour plugged into the generic shim entry point.
///
/// [`run`] drives an implementation of this trait: it constructs the shim with
/// [`Shim::new`] and then dispatches on the `action` flag to
/// [`Shim::start_shim`], [`Shim::delete_shim`], or serving the task API over
/// ttrpc until [`Shim::wait`] returns.
#[async_trait]
pub trait Shim {
    /// Task service implementation that gets registered on the ttrpc server.
    type T: Task + Send + Sync;
    /// Creates the shim instance. Called once per process, for every action,
    /// right after the command-line flags are parsed. `config` is mutable and
    /// is read again afterwards by the caller (e.g. `no_setup_logger`).
    async fn new(runtime_id: &str, args: &Flags, config: &mut Config) -> Self;
    /// Handles the `start` action. The returned string is written verbatim to
    /// stdout by the caller (presumably the address the shim serves on).
    async fn start_shim(&mut self, opts: StartOpts) -> Result<String>;
    /// Handles the `delete` action. The caller serializes the response with
    /// protobuf and writes it to stdout.
    async fn delete_shim(&mut self) -> Result<DeleteResponse>;
    /// Completes when the serving shim should exit; afterwards the caller
    /// shuts the ttrpc server down and removes the socket.
    async fn wait(&mut self);
    /// Builds the task service registered on the ttrpc server. `publisher` is
    /// connected to the address taken from the `TTRPC_ADDRESS` environment variable.
    async fn create_task_service(&self, publisher: RemotePublisher) -> Self::T;
}
/// Entry point for a shim binary.
///
/// Runs the shim lifecycle for `runtime_id`. On failure, prints the error
/// (prefixed with `runtime_id`) to stderr and terminates the process with
/// exit status 1.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
pub async fn run<T>(runtime_id: &str, opts: Option<Config>)
where
    T: Shim + Send + Sync + 'static,
{
    let outcome = bootstrap::<T>(runtime_id, opts).await;
    if let Err(failure) = outcome {
        eprintln!("{}: {:?}", runtime_id, failure);
        process::exit(1);
    }
}
pub fn run_info() -> Result<RuntimeInfo> {
let mut info = introspection::RuntimeInfo {
name: "containerd-shim-runc-v2-rs".to_string(),
version: MessageField::some(introspection::RuntimeVersion {
version: env!("CARGO_PKG_VERSION").to_string(),
revision: String::default(),
..Default::default()
}),
..Default::default()
};
let mut binary_name = DEFAULT_BINARY_NAME.to_string();
let mut data: Vec<u8> = Vec::new();
std::io::stdin()
.read_to_end(&mut data)
.map_err(io_error!(e, "read stdin"))?;
if !data.is_empty() {
let opts =
Any::parse_from_bytes(&data).and_then(|any| Options::parse_from_bytes(&any.value))?;
if !opts.binary_name().is_empty() {
binary_name = opts.binary_name().to_string();
}
}
let binary_path = which(binary_name).unwrap();
let output = StdCommand::new(binary_path)
.arg("features")
.output()
.unwrap();
let features: Features = serde_json::from_str(&String::from_utf8_lossy(&output.stdout))?;
let features_any = Any {
type_url: "types.containerd.io/opencontainers/runtime-spec/1/features/Features".to_string(),
value: serde_json::to_vec(&features)?,
..Default::default()
};
info.features = MessageField::some(features_any);
Ok(info)
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
async fn bootstrap<T>(runtime_id: &str, opts: Option<Config>) -> Result<()>
where
T: Shim + Send + Sync + 'static,
{
let os_args: Vec<_> = env::args_os().collect();
let flags = args::parse(&os_args[1..])?;
let ttrpc_address = env::var(TTRPC_ADDRESS)?;
let mut config = opts.unwrap_or_default();
let signals = setup_signals_tokio(&config);
if !config.no_sub_reaper {
reap::set_subreaper()?;
}
let mut shim = T::new(runtime_id, &flags, &mut config).await;
match flags.action.as_str() {
"start" => {
let args = StartOpts {
id: flags.id,
publish_binary: flags.publish_binary,
address: flags.address,
ttrpc_address,
namespace: flags.namespace,
debug: flags.debug,
};
let address = shim.start_shim(args).await?;
let mut stdout = tokio::io::stdout();
stdout
.write_all(address.as_bytes())
.await
.map_err(io_error!(e, "write stdout"))?;
stdout.flush().await.map_err(io_error!(e, "flush stdout"))?;
Ok(())
}
"delete" => {
tokio::spawn(async move {
handle_signals(signals).await;
});
let response = shim.delete_shim().await?;
let resp_bytes = response.write_to_bytes()?;
tokio::io::stdout()
.write_all(resp_bytes.as_slice())
.await
.map_err(io_error!(e, "failed to write response"))?;
Ok(())
}
_ => {
if flags.socket.is_empty() {
return Err(Error::InvalidArgument(String::from(
"Shim socket cannot be empty",
)));
}
if !config.no_setup_logger {
logger::init(
flags.debug,
&config.default_log_level,
&flags.namespace,
&flags.id,
)?;
}
let publisher = RemotePublisher::new(&ttrpc_address).await?;
let task = Box::new(shim.create_task_service(publisher).await)
as Box<dyn containerd_shim_protos::shim_async::Task + Send + Sync>;
let task_service = create_task(Arc::from(task));
let Some(mut server) = create_server_with_retry(&flags).await? else {
signal_server_started();
return Ok(());
};
server = server.register_service(task_service);
server.start().await?;
signal_server_started();
info!("Shim successfully started, waiting for exit signal...");
tokio::spawn(async move {
handle_signals(signals).await;
});
shim.wait().await;
info!("Shutting down shim instance");
server.shutdown().await.unwrap_or_default();
if let Ok(address) = read_file_to_str("address").await {
remove_socket_silently(&address).await;
}
Ok(())
}
}
}
pub struct ExitSignal {
notifier: Notify,
exited: AtomicBool,
}
impl Default for ExitSignal {
fn default() -> Self {
ExitSignal {
notifier: Notify::new(),
exited: AtomicBool::new(false),
}
}
}
impl ExitSignal {
    /// Marks the signal as fired and wakes every task currently in
    /// [`ExitSignal::wait`].
    pub fn signal(&self) {
        // Store the flag before waking, so woken waiters re-check and see `true`.
        self.exited.store(true, Ordering::SeqCst);
        self.notifier.notify_waiters();
    }

    /// Completes once [`ExitSignal::signal`] has been called; returns
    /// immediately if it already has been.
    pub async fn wait(&self) {
        loop {
            // The `Notified` future is created *before* the flag is checked.
            // Per tokio's `Notify::notified` docs, a future receives
            // `notify_waiters()` wakeups from the moment it is created.
            // A `signal()` racing between the check and the `.await` therefore
            // cannot be missed. Do not reorder these statements.
            let notified = self.notifier.notified();
            if self.exited.load(Ordering::SeqCst) {
                return;
            }
            notified.await;
        }
    }
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
pub async fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) -> Result<String> {
let cmd = env::current_exe().map_err(io_error!(e, ""))?;
let cwd = env::current_dir().map_err(io_error!(e, ""))?;
let address = socket_address(&opts.address, &opts.namespace, grouping);
let mut command = Command::new(cmd);
command
.current_dir(cwd)
.stdout(Stdio::piped())
.stdin(Stdio::null())
.stderr(Stdio::null())
.envs(vars)
.args([
"-namespace",
&opts.namespace,
"-id",
&opts.id,
"-address",
&opts.address,
"-socket",
&address,
]);
if opts.debug {
command.arg("-debug");
}
let mut child = command.spawn().map_err(io_error!(e, "spawn shim"))?;
#[cfg(target_os = "linux")]
crate::cgroup::set_cgroup_and_oom_score(child.id().unwrap())?;
let mut reader = child.stdout.take().unwrap();
tokio::io::copy(&mut reader, &mut tokio::io::stderr())
.await
.unwrap();
Ok(address)
}
/// Builds a ttrpc [`Server`] listening on the Unix socket at `flags.socket`.
///
/// # Errors
/// Fails if the listener cannot be bound or converted into a ttrpc listener.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))]
async fn create_server(flags: &args::Flags) -> Result<Server> {
    use containerd_shim_protos::ttrpc::r#async::transport::Listener;

    let unix_listener = start_listener(&flags.socket).await?;
    let ttrpc_listener =
        Listener::try_from(unix_listener).map_err(io_error!(e, "creating ttrpc listener"))?;
    Ok(Server::new().add_listener(ttrpc_listener))
}
async fn create_server_with_retry(flags: &args::Flags) -> Result<Option<Server>> {
let server = match create_server(flags).await {
Ok(server) => server,
Err(Error::IoError { err, .. }) if err.kind() == std::io::ErrorKind::AddrInUse => {
if let Ok(()) = wait_socket_working(&flags.socket, 5, 200).await {
write_str_to_file("address", &flags.socket).await?;
return Ok(None);
}
remove_socket(&flags.socket).await?;
create_server(flags).await?
}
Err(e) => return Err(e),
};
Ok(Some(server))
}
/// Replaces this process's stdout with its stderr.
///
/// `spawn` copies the child's stdout until it reaches EOF. Pointing fd 1 at
/// stderr drops this process's reference to the original stdout (presumably
/// that pipe), which lets the parent's copy finish.
///
/// # Panics
/// Panics if `dup2` fails.
fn signal_server_started() {
    use libc::{dup2, STDERR_FILENO, STDOUT_FILENO};

    // SAFETY: `dup2` operates on plain integer file descriptors and has no
    // memory-safety preconditions; both arguments are the standard descriptors.
    let rc = unsafe { dup2(STDERR_FILENO, STDOUT_FILENO) };
    if rc < 0 {
        panic!("Error closing pipe: {}", std::io::Error::last_os_error())
    }
}
/// Turns the Unix signal number `kind` into a stream.
///
/// The stream yields `kind` each time that signal is received.
///
/// # Errors
/// Propagates the error from registering the tokio signal handler.
#[cfg(unix)]
fn signal_stream(kind: i32) -> std::io::Result<BoxStream<'static, i32>> {
    use tokio::signal::unix::{signal, SignalKind};

    let signal_kind = SignalKind::from_raw(kind);
    let mut listener = signal(signal_kind)?;
    let stream = poll_fn(move |cx| {
        ready!(listener.poll_recv(cx));
        Poll::Ready(Some(signal_kind.as_raw_value()))
    });
    Ok(stream.boxed())
}
/// Windows counterpart of `signal_stream`.
///
/// Only `SIGINT` is supported, mapped to Ctrl-C. The stream yields `kind` on
/// each Ctrl-C event.
///
/// # Errors
/// Returns an error for any other signal number, or if the Ctrl-C handler
/// cannot be registered.
#[cfg(windows)]
fn signal_stream(kind: i32) -> std::io::Result<BoxStream<'static, i32>> {
    use tokio::signal::windows::ctrl_c;

    if kind == SIGINT {
        let mut listener = ctrl_c()?;
        Ok(poll_fn(move |cx| {
            ready!(listener.poll_recv(cx));
            Poll::Ready(Some(kind))
        })
        .boxed())
    } else {
        Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("Invalid signal {kind}"),
        ))
    }
}
/// All registered signal streams merged into one. Each item is the raw signal
/// number of a delivery.
type Signals = SelectAll<BoxStream<'static, i32>>;
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))]
fn setup_signals_tokio(config: &Config) -> Signals {
#[cfg(unix)]
let signals: &[i32] = if config.no_reaper {
&[SIGTERM, SIGINT, SIGPIPE]
} else {
&[SIGTERM, SIGINT, SIGPIPE, SIGCHLD]
};
#[cfg(windows)]
let signals: &[i32] = &[SIGINT];
let signals: Vec<_> = signals
.iter()
.copied()
.map(signal_stream)
.collect::<std::io::Result<_>>()
.expect("signal setup failed");
SelectAll::from_iter(signals)
}
/// Consumes `signals` until the stream ends, reacting to each delivery.
///
/// * `SIGPIPE` is ignored.
/// * `SIGTERM`/`SIGINT` are only logged at debug level.
/// * `SIGCHLD` reaps exited children; see `reap_children`.
/// * Anything else is logged (as a warning if it is not a valid signal).
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))]
async fn handle_signals(signals: Signals) {
    let mut signals = signals.fuse();
    while let Some(sig) = signals.next().await {
        match sig {
            SIGPIPE => {}
            SIGTERM | SIGINT => {
                debug!("received {}", sig);
            }
            SIGCHLD => reap_children().await,
            _ => {
                if let Ok(sig) = Signal::try_from(sig) {
                    debug!("received {}", sig);
                } else {
                    warn!("received invalid signal {}", sig);
                }
            }
        }
    }
}

/// Reaps every currently exited child with non-blocking `waitpid(-1)`.
///
/// Each exit is forwarded to the monitor. A child killed by a signal is
/// reported with exit code `128 + signal`.
async fn reap_children() {
    loop {
        match wait::waitpid(Some(Pid::from_raw(-1)), Some(WaitPidFlag::WNOHANG)) {
            Ok(WaitStatus::Exited(pid, status)) => {
                monitor_notify_by_pid(pid.as_raw(), status)
                    .await
                    .unwrap_or_else(|e| error!("failed to send exit event {}", e))
            }
            Ok(WaitStatus::Signaled(pid, term_sig, _)) => {
                debug!("child {} terminated({})", pid, term_sig);
                let exit_code = 128 + term_sig as i32;
                monitor_notify_by_pid(pid.as_raw(), exit_code)
                    .await
                    .unwrap_or_else(|e| error!("failed to send signal event {}", e))
            }
            // Nothing more to reap right now, or no children at all.
            Ok(WaitStatus::StillAlive) | Err(Errno::ECHILD) => break,
            // Interrupted by another signal: simply try again.
            Err(Errno::EINTR) => {}
            Err(e) => {
                // Any other error would most likely repeat on every retry and
                // spin this task forever; stop until the next SIGCHLD.
                warn!("error occurred in signal handler: {}", e);
                break;
            }
            // Stopped/continued/ptrace states: nothing to report, keep draining.
            Ok(_) => {}
        }
    }
}
/// Best-effort variant of `remove_socket`: any failure is logged as a warning
/// and otherwise ignored.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
async fn remove_socket_silently(address: &str) {
    if let Err(e) = remove_socket(address).await {
        warn!("failed to remove socket: {}", e);
    }
}
/// Deletes the Unix socket file behind `address`, if there is one.
///
/// `address` is turned into a filesystem path with `parse_sockaddr`. A missing
/// path, or a path that is not a socket, is left alone and reported as success.
///
/// # Errors
/// Returns an I/O error if the socket exists but cannot be removed.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
async fn remove_socket(address: &str) -> Result<()> {
    let path = Path::new(parse_sockaddr(address));
    // Async stat: a blocking `std` metadata call would stall the runtime thread.
    if let Ok(md) = tokio::fs::metadata(path).await {
        if md.file_type().is_socket() {
            tokio::fs::remove_file(path).await.map_err(io_error!(
                e,
                "failed to remove socket {}",
                address
            ))?;
        }
    }
    Ok(())
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "info"))]
async fn start_listener(address: &str) -> Result<UnixListener> {
let addr = address.to_string();
asyncify(move || -> Result<UnixListener> {
crate::start_listener(&addr).map_err(|e| Error::IoError {
context: format!("failed to start listener {}", addr),
err: e,
})
})
.await
}
/// Repeatedly tries to connect a ttrpc client to `address`.
///
/// Makes up to `count` attempts and sleeps `interval_in_ms` milliseconds after
/// each failed one.
///
/// # Errors
/// Returns a timeout error (built with `other!`) if no attempt succeeds.
#[cfg_attr(feature = "tracing", tracing::instrument(level = "info"))]
async fn wait_socket_working(address: &str, interval_in_ms: u64, count: u32) -> Result<()> {
    let pause = std::time::Duration::from_millis(interval_in_ms);
    for _ in 0..count {
        if Client::connect(address).await.is_ok() {
            return Ok(());
        }
        tokio::time::sleep(pause).await;
    }
    Err(other!("time out waiting for socket {}", address))
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::asynchronous::{start_listener, ExitSignal};

    /// A task parked in `wait` must complete once `signal` is called.
    #[tokio::test]
    async fn test_exit_signal() {
        let exit = Arc::new(ExitSignal::default());
        let waiter = {
            let exit = Arc::clone(&exit);
            tokio::spawn(async move {
                exit.wait().await;
            })
        };
        exit.signal();
        if let Err(err) = waiter.await {
            panic!("{:?}", err);
        }
    }

    /// Covers binding a fresh path, a path already in use, a path nested under
    /// a socket, and a regular file.
    #[tokio::test]
    async fn test_start_listener() {
        let tmpdir = tempfile::tempdir().unwrap();
        let base = tmpdir.path().to_str().unwrap().to_owned();

        // A fresh path (its ns1/id1 parents do not exist yet) binds once...
        let socket = format!("{}/ns1/id1/socket", base);
        let _listener = start_listener(&socket).await.unwrap();
        // ...but a second bind on the same path fails.
        let _err = start_listener(&socket)
            .await
            .expect_err("socket should already in use");
        // A path nested under an existing socket file is rejected.
        let nested = format!("{}/socket", socket);
        assert!(start_listener(&nested).await.is_err());

        // Binding over a regular file fails and leaves its contents untouched.
        let txt_file = format!("{}/demo.txt", base);
        tokio::fs::write(&txt_file, "test").await.unwrap();
        assert!(start_listener(&txt_file).await.is_err());
        let contents = tokio::fs::read_to_string(&txt_file).await.unwrap();
        assert_eq!(contents, "test");
    }
}