use std::env;
use std::fs;
use std::io::Write;
use std::os::unix::fs::FileTypeExt;
use std::os::unix::io::AsRawFd;
use std::path::Path;
use std::process::{self, Command, Stdio};
use std::sync::{Arc, Condvar, Mutex};
use command_fds::{CommandFdExt, FdMapping};
use libc::{c_int, pid_t, SIGCHLD, SIGINT, SIGPIPE, SIGTERM};
pub use log::{debug, error, info, warn};
use signal_hook::iterator::Signals;
use crate::protos::protobuf::Message;
use crate::protos::shim::shim_ttrpc::{create_task, Task};
use crate::protos::ttrpc::{Client, Server};
use util::{read_address, write_address};
use crate::api::DeleteResponse;
use crate::synchronous::publisher::RemotePublisher;
use crate::Error;
use crate::{args, logger, reap, Result, TTRPC_ADDRESS};
use crate::{parse_sockaddr, socket_address, start_listener, Config, StartOpts, SOCKET_FD};
pub mod monitor;
pub mod publisher;
pub mod util;
pub mod console;
/// A one-shot, thread-safe "exit requested" flag.
///
/// Any number of threads may block in [`ExitSignal::wait`] until some thread
/// calls [`ExitSignal::signal`]. Nothing ever clears the flag, so once it has
/// been signalled every later `wait` returns immediately.
// The bool must live in a Mutex (not an AtomicBool) because it is paired with
// a Condvar, which only works with a MutexGuard.
#[allow(clippy::mutex_atomic)]
#[derive(Default)]
pub struct ExitSignal(Mutex<bool>, Condvar);

#[allow(clippy::mutex_atomic)]
impl ExitSignal {
    /// Sets the flag and wakes every thread currently blocked in `wait`.
    pub fn signal(&self) {
        let mut fired = self.0.lock().unwrap();
        *fired = true;
        // Notify while still holding the lock, so no waiter can miss the update.
        self.1.notify_all();
    }

    /// Blocks the calling thread until [`ExitSignal::signal`] has been called.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) {
        let guard = self.0.lock().unwrap();
        // `wait_while` re-checks the predicate after every wakeup, which also
        // covers spurious wakeups.
        let _fired = self.1.wait_while(guard, |fired| !*fired).unwrap();
    }
}
/// Behaviour a concrete shim implementation plugs into [`run`].
///
/// `bootstrap` constructs the shim with [`Shim::new`] and then calls exactly one
/// of the action hooks, depending on the parsed `action` flag.
pub trait Shim {
/// Task service type registered with the ttrpc server when the shim serves requests.
type T: Task + Send + Sync;
/// Creates the shim for `runtime_id`, container `id` and `namespace`.
///
/// `config` may be modified here. However, `bootstrap` has already read
/// `no_reaper` and `no_sub_reaper` before this call, so only settings read
/// afterwards (such as `no_setup_logger`) are affected by changes made here.
fn new(runtime_id: &str, id: &str, namespace: &str, config: &mut Config) -> Self;
/// Handles the `start` action. The returned string is printed verbatim to stdout
/// by `bootstrap`.
fn start_shim(&mut self, opts: StartOpts) -> Result<String>;
/// Handles the `delete` action. The returned response is protobuf-encoded and
/// written to stdout by `bootstrap`.
fn delete_shim(&mut self) -> Result<DeleteResponse>;
/// Blocks while the shim is serving. Once it returns, `bootstrap` shuts the ttrpc
/// server down and removes the socket file.
fn wait(&mut self);
/// Builds the task service to serve. `publisher` was created from the address in
/// the `TTRPC_ADDRESS` environment variable.
fn create_task_service(&self, publisher: RemotePublisher) -> Self::T;
}
/// Entry point for a shim binary built on the [`Shim`] trait.
///
/// Runs the bootstrap sequence for `T`. On failure it prints
/// `"<runtime_id>: <error>"` to stderr and terminates the process with exit
/// status 1. On success it simply returns.
pub fn run<T>(runtime_id: &str, opts: Option<Config>)
where
    T: Shim + Send + Sync + 'static,
{
    let outcome = bootstrap::<T>(runtime_id, opts);
    if let Err(err) = outcome {
        eprintln!("{}: {:?}", runtime_id, err);
        process::exit(1);
    }
}
/// Parses the command line and dispatches to the requested shim action.
///
/// * `start`: calls [`Shim::start_shim`] and prints the returned address to stdout.
/// * `delete`: calls [`Shim::delete_shim`] and writes the protobuf-encoded
///   [`DeleteResponse`] to stdout.
/// * anything else: serves the ttrpc task service on the inherited `SOCKET_FD`
///   until [`Shim::wait`] returns. It then shuts the server down and removes the
///   socket file.
///
/// # Errors
/// Returns an error if any of these fail:
/// * argument parsing;
/// * reading the `TTRPC_ADDRESS` environment variable;
/// * setting the subreaper;
/// * any step of the selected action.
fn bootstrap<T>(runtime_id: &str, opts: Option<Config>) -> Result<()>
where
    T: Shim + Send + Sync + 'static,
{
    let os_args: Vec<_> = env::args_os().collect();
    // Skip argv[0] (the program name). Do not panic if the process was exec'd
    // with an empty argv; parse an empty argument list instead.
    let flags = args::parse(os_args.get(1..).unwrap_or_default())?;
    let ttrpc_address = env::var(TTRPC_ADDRESS)?;
    let mut config = opts.unwrap_or_else(Config::default);

    // Signals and the subreaper are set up before the shim is constructed, so
    // changes `T::new` makes to `no_reaper`/`no_sub_reaper` have no effect here.
    let signals = setup_signals(&config);
    if !config.no_sub_reaper {
        reap::set_subreaper()?;
    }

    let mut shim = T::new(runtime_id, &flags.id, &flags.namespace, &mut config);
    match flags.action.as_str() {
        "start" => {
            let args = StartOpts {
                id: flags.id,
                publish_binary: flags.publish_binary,
                address: flags.address,
                ttrpc_address,
                namespace: flags.namespace,
                debug: flags.debug,
            };
            let address = shim.start_shim(args)?;
            // The address is this command's output and has no trailing newline,
            // so it sits in stdout's line buffer. Flush explicitly instead of
            // relying on process-exit cleanup.
            let stdout = std::io::stdout();
            let mut out = stdout.lock();
            write!(out, "{}", address).map_err(io_error!(e, "write stdout"))?;
            out.flush().map_err(io_error!(e, "flush stdout"))?;
            Ok(())
        }
        "delete" => {
            std::thread::spawn(move || handle_signals(signals));
            let response = shim.delete_shim()?;
            let stdout = std::io::stdout();
            let mut locked = stdout.lock();
            response.write_to_writer(&mut locked)?;
            Ok(())
        }
        _ => {
            if !config.no_setup_logger {
                logger::init(flags.debug)?;
            }
            let publisher = publisher::RemotePublisher::new(&ttrpc_address)?;
            let task = shim.create_task_service(publisher);
            let task_service = create_task(Arc::new(Box::new(task)));
            let mut server = Server::new().register_service(task_service);
            // The listening socket was passed in by `spawn` as SOCKET_FD.
            server = server.add_listener(SOCKET_FD)?;
            server.start()?;
            info!("Shim successfully started, waiting for exit signal...");
            std::thread::spawn(move || handle_signals(signals));
            shim.wait();
            info!("Shutting down shim instance");
            server.shutdown();
            let address = read_address()?;
            remove_socket_silently(&address);
            Ok(())
        }
    }
}
/// Registers the signals this shim listens for and returns the iterator over them.
///
/// SIGTERM, SIGINT and SIGPIPE are always registered. SIGCHLD is added unless
/// `config.no_reaper` is set.
///
/// # Panics
/// Panics if registering any signal fails.
fn setup_signals(config: &Config) -> Signals {
    let always = [SIGTERM, SIGINT, SIGPIPE];
    let signals = Signals::new(&always).expect("new signal failed");
    if config.no_reaper {
        return signals;
    }
    signals.add_signal(SIGCHLD).expect("add signal failed");
    signals
}
fn handle_signals(mut signals: Signals) {
loop {
for sig in signals.wait() {
match sig {
SIGTERM | SIGINT => {
debug!("received {}", sig);
return;
}
SIGCHLD => loop {
unsafe {
let pid: pid_t = -1;
let mut status: c_int = 0;
let options: c_int = libc::WNOHANG;
let res_pid = libc::waitpid(pid, &mut status, options);
let status = libc::WEXITSTATUS(status);
if res_pid <= 0 {
break;
} else {
monitor::monitor_notify_by_pid(res_pid, status).unwrap_or_else(|e| {
error!("failed to send exit event {}", e);
});
}
}
},
_ => {
debug!("received {}", sig);
}
}
}
}
}
/// Polls `address` until a ttrpc connection succeeds.
///
/// It makes at most `count` connection attempts and sleeps `interval_in_ms`
/// milliseconds after each failed one. Returns `Ok(())` on the first successful
/// connect; the probe connection is dropped immediately. Returns an error if
/// every attempt fails.
fn wait_socket_working(address: &str, interval_in_ms: u64, count: u32) -> Result<()> {
    let pause = std::time::Duration::from_millis(interval_in_ms);
    for _attempt in 0..count {
        if Client::connect(address).is_ok() {
            return Ok(());
        }
        std::thread::sleep(pause);
    }
    Err(other!("time out waiting for socket {}", address))
}
/// Best-effort variant of `remove_socket`: logs a warning instead of returning
/// the error.
fn remove_socket_silently(address: &str) {
    if let Err(e) = remove_socket(address) {
        warn!("failed to remove file {} {:?}", address, e);
    }
}
/// Removes the filesystem path behind `address`, but only if it is a Unix socket.
///
/// A missing path, an unreadable path, or a path that is not a socket is left
/// alone and reported as `Ok(())`. Only a failed `remove_file` on an actual
/// socket is an error.
fn remove_socket(address: &str) -> Result<()> {
    let path = parse_sockaddr(address);
    let is_socket = Path::new(path)
        .metadata()
        .map_or(false, |md| md.file_type().is_socket());
    if is_socket {
        fs::remove_file(path).map_err(io_error!(e, "remove socket"))?;
    }
    Ok(())
}
/// Spawns a new copy of the current executable as the long-running shim server.
///
/// A listening socket is created for the address derived from `opts.address`,
/// `opts.namespace` and `grouping`. It is handed to the child as file
/// descriptor `SOCKET_FD`.
///
/// If that address is already in use, the function polls it for about 1s
/// (200 attempts, 5ms apart). If a server answers, it is reused: the address is
/// recorded and `(0, address)` is returned without spawning anything. Otherwise
/// the leftover socket file is removed and listening is retried.
///
/// `vars` are added to the child's environment. On success returns the child's
/// pid and the socket address.
///
/// # Errors
/// Fails if:
/// * the current executable or directory cannot be determined;
/// * the socket cannot be created;
/// * the descriptor cannot be mapped;
/// * the child cannot be spawned.
pub fn spawn(opts: StartOpts, grouping: &str, vars: Vec<(&str, &str)>) -> Result<(u32, String)> {
    let cmd = env::current_exe().map_err(io_error!(e, "get current executable"))?;
    let cwd = env::current_dir().map_err(io_error!(e, "get current directory"))?;
    let address = socket_address(&opts.address, &opts.namespace, grouping);

    let listener = match start_listener(&address) {
        Ok(l) => l,
        Err(e) if e.kind() != std::io::ErrorKind::AddrInUse => {
            return Err(Error::IoError {
                context: format!("start listener on {}", address),
                err: e,
            });
        }
        Err(_) => {
            // Another shim may already be serving this address; reuse it if it answers.
            if wait_socket_working(&address, 5, 200).is_ok() {
                write_address(&address)?;
                return Ok((0, address));
            }
            // Nobody answered. Remove the leftover socket file and listen again.
            remove_socket(&address)?;
            start_listener(&address).map_err(io_error!(e, "restart listener"))?
        }
    };

    let mut command = Command::new(cmd);
    command
        .current_dir(cwd)
        .stdout(Stdio::null())
        .stdin(Stdio::null())
        .stderr(Stdio::null())
        .fd_mappings(vec![FdMapping {
            parent_fd: listener.as_raw_fd(),
            child_fd: SOCKET_FD,
        }])?
        .args(&[
            "-namespace",
            &opts.namespace,
            "-id",
            &opts.id,
            "-address",
            &opts.address,
        ]);
    if opts.debug {
        command.arg("-debug");
    }
    command.envs(vars);

    command
        .spawn()
        .map_err(io_error!(e, "spawn shim"))
        .map(|child| {
            // On success the listener is deliberately leaked rather than dropped.
            // Its descriptor has been mapped into the child as SOCKET_FD.
            std::mem::forget(listener);
            (child.id(), address)
        })
}
#[cfg(test)]
mod tests {
    use std::thread;

    use super::*;

    /// A waiter blocked in `wait` is woken by `signal` from another thread.
    #[test]
    fn exit_signal() {
        let signal = Arc::new(ExitSignal::default());
        let cloned = Arc::clone(&signal);
        let handle = thread::spawn(move || {
            cloned.signal();
        });
        signal.wait();
        if let Err(err) = handle.join() {
            panic!("{:?}", err);
        }
    }

    /// Once signalled, `wait` returns immediately, and keeps doing so for
    /// every later caller (the flag is never reset).
    #[test]
    fn exit_signal_already_fired() {
        let signal = ExitSignal::default();
        signal.signal();
        signal.wait();
        signal.wait();
    }
}