use anyhow::Error;
use async_trait::async_trait;
use clap::{Parser, Subcommand};
use shellflip::lifecycle::*;
use shellflip::{RestartConfig, ShutdownCoordinator, ShutdownHandle, ShutdownSignal};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::{pin, select};
// Command-line arguments. Struct-level comments here are `//`, not `///`, on purpose:
// clap's derive turns doc comments into help text, and the program description shown by
// `--help` should keep coming from `#[command(about)]`.
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
    // Optional subcommand; when absent the process runs the echo server.
    #[command(subcommand)]
    command: Option<Commands>,
    /// Path of the socket used to coordinate restarts between instances
    #[arg(short, long, default_value = "/tmp/restarter.sock")]
    socket: String,
}
// Subcommands accepted on the command line. The `///` comment on each variant becomes
// that subcommand's description in `--help`.
#[derive(Subcommand)]
enum Commands {
    /// Ask the running instance to restart, handing over to a new process
    Restart,
}
/// State owned by one instance of the server and passed on to its successor when it
/// restarts.
struct AppData {
    /// Number of restarts that led to this process: 0 for a freshly started instance,
    /// otherwise derived from the value sent by the predecessor.
    restart_generation: u32,
}
#[async_trait]
impl LifecycleHandler for AppData {
    /// Hands this process's state to the replacement process during a restart.
    ///
    /// Writes the current restart generation to `write_pipe` as a `u32`; the new process
    /// reads it back and adds one.
    ///
    /// # Errors
    ///
    /// Returns an error without writing anything once `MAX_RESTARTS` restarts have already
    /// happened. This is intended to make the restart fail; the exact effect depends on how
    /// shellflip handles a handler error. Also propagates any error from writing to the pipe.
    async fn send_to_new_process(&mut self, mut write_pipe: PipeWriter) -> std::io::Result<()> {
        // Generation N is the process created by the N-th restart. Once generation 4 is
        // running, four restarts have happened, which is the cap the log message below
        // announces. The previous `> 4` check let a fifth restart through.
        const MAX_RESTARTS: u32 = 4;
        if self.restart_generation >= MAX_RESTARTS {
            log::info!("Four restarts is more than anybody needs, surely?");
            // The tongue-in-cheek message is deliberate; returning Err is what refuses the
            // handover.
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "The operation completed successfully",
            ));
        }
        write_pipe.write_u32(self.restart_generation).await?;
        Ok(())
    }
}
/// Entry point of the restarter example.
///
/// With the `restart` subcommand, asks the instance coordinating on `--socket` to restart,
/// logs the outcome and exits. Without a subcommand, starts an echo server on an ephemeral
/// `127.0.0.1` port. It serves connections until the restart task completes, then waits on
/// the shutdown coordinator for in-flight connections before returning.
///
/// # Errors
///
/// Returns an error if:
/// - the handover data from a previous process cannot be read,
/// - the restart request fails,
/// - the restart task cannot be created, or
/// - the listener cannot be bound or queried for its address.
#[tokio::main]
async fn main() -> Result<(), Error> {
    env_logger::init();
    let args = Args::parse();

    let mut app_data = AppData {
        restart_generation: 0,
    };

    // When spawned by an older instance, it sends us its generation over the handover pipe;
    // we are the next one. `saturating_add` keeps a corrupt or maximal value from panicking
    // (debug) or wrapping to 0 (release), which would silently reset the restart cap.
    if let Some(mut handover_pipe) = receive_from_old_process() {
        app_data.restart_generation = handover_pipe.read_u32().await?.saturating_add(1);
    }

    // Copy the generation out before `app_data` is moved into the restart config.
    let restart_generation = app_data.restart_generation;
    let restart_conf = RestartConfig {
        enabled: true,
        coordination_socket_path: args.socket.into(),
        lifecycle_handler: Box::new(app_data),
        ..Default::default()
    };

    // Exhaustive match (rather than `if let`) so a new subcommand forces a decision here.
    match args.command {
        Some(Commands::Restart) => {
            return match restart_conf.request_restart().await {
                Ok(id) => {
                    log::info!("Restart succeeded, child pid is {}", id);
                    Ok(())
                }
                Err(e) => {
                    log::error!("Restart failed: {}", e);
                    Err(e)
                }
            };
        }
        None => {}
    }

    let restart_task = restart_conf.try_into_restart_task()?;
    pin!(restart_task);
    let shutdown_coordinator = ShutdownCoordinator::new();

    // Port 0: let the OS choose a free port; the chosen address is printed below.
    let listener = TcpListener::bind("127.0.0.1:0").await?;
    println!(
        "Instance no. {} listening on {}",
        restart_generation,
        listener.local_addr()?
    );

    loop {
        select! {
            res = listener.accept() => {
                match res {
                    Ok((sock, addr)) => {
                        log::info!("Received connection from {}", addr);
                        tokio::spawn(echo(sock, shutdown_coordinator.handle()));
                    }
                    Err(e) => {
                        log::warn!("Accept error: {}", e);
                    }
                }
            }
            res = &mut restart_task => {
                match res {
                    Ok(_) => {
                        log::info!("Restart successful, waiting for tasks to complete");
                    }
                    Err(e) => {
                        log::error!("Restart task failed: {}", e);
                    }
                }
                // Either way this instance is done accepting; let open connections drain.
                shutdown_coordinator.shutdown().await;
                log::info!("Exiting...");
                return Ok(());
            }
        }
    }
}
async fn echo(mut sock: TcpStream, shutdown_handle: Arc<ShutdownHandle>) {
let mut shutdown_signal = ShutdownSignal::from(&*shutdown_handle);
let mut buf = [0u8; 1024];
let out = format!("Hello, this is process {}\n", std::process::id());
let _ = sock.write_all(out.as_bytes()).await;
loop {
select! {
r = sock.read(&mut buf) => {
match r {
Ok(0) => return,
Ok(n) => {
if let Err(e) = sock.write_all(&buf[..n]).await {
log::error!("write failed: {}", e);
return;
}
}
Err(e) => {
log::error!("read failed: {}", e);
return;
}
}
}
_ = shutdown_signal.on_shutdown() => {
log::info!("shutdown requested but client {} is still active", sock.peer_addr().unwrap());
}
}
}
}