use std::{
env, io,
net::{IpAddr, SocketAddr},
str::FromStr,
time::{Duration, Instant},
};
use futures::{Stream, StreamExt};
use nix::unistd::getpid;
use tokio::{
io::{copy, split, AsyncRead, AsyncWrite, AsyncWriteExt},
net::{TcpStream, UnixStream},
};
use ecdysis::tokio_ecdysis::{SignalKind, StopOnShutdown, TokioEcdysisBuilder};
/// Upper bound on how long a stopping process waits for its echo servers to drain
/// before `main` force-exits it with status 1.
const EXIT_AFTER_UPGRADE_DURATION: Duration = Duration::from_secs(60);
/// Marker trait combining [`AsyncRead`] and [`AsyncWrite`], so `echo_server` can be
/// generic over the connection types it serves with a single bound.
pub trait AsyncReadWrite: AsyncWrite + AsyncRead {}

// Connection types produced by the Unix-socket and TCP listeners in this example.
impl AsyncReadWrite for UnixStream {}
impl AsyncReadWrite for TcpStream {}
async fn echo_server<S, C>(mut sock_stream: S)
where
S: Stream<Item = io::Result<C>> + Unpin + Sized,
C: AsyncReadWrite + Unpin + Send + 'static,
{
let reload_count = env::var("ECDYSIS_RELOADS").unwrap().parse::<u32>().unwrap();
let wg = waitgroup::WaitGroup::new();
while let Some(Ok(mut client)) = sock_stream.next().await {
let w = wg.worker();
let client_fut = async move {
client.write_all(&reload_count.to_be_bytes()).await.unwrap();
let (mut client_r, mut client_w) = split(client);
let _n_bytes = copy(&mut client_r, &mut client_w).await;
drop(w); };
tokio::spawn(client_fut);
}
wg.wait().await;
}
/// Entry point of the ecdysis echo-server example.
///
/// Start-up order:
/// 1. Initialise logging (default filter `info`) and refresh `LISTEN_PID`.
/// 2. Bump the `ECDYSIS_RELOADS` counter. It is 0 for the original parent and
///    greater than 0 for an upgraded child.
/// 3. Register the upgrade triggers (SIGHUP, `/tmp/ecdysis_upgrade.sock`), the stop
///    triggers (SIGUSR1, `/tmp/ecdysis_exit.sock`), the partial-stop triggers
///    (SIGUSR2, `/tmp/ecdysis_partial_exit.sock`) and the `./pidfile`.
/// 4. Start the TCP echo server. With the `systemd_sockets` feature, also start one
///    on the systemd-provided unix socket.
/// 5. Wait for ecdysis to report why this process should stop, then join the
///    servers. If draining takes longer than `EXIT_AFTER_UPGRADE_DURATION`, a
///    watchdog task force-exits with status 1.
#[tokio::main(flavor = "current_thread")]
async fn main() {
    env_logger::init_from_env(
        env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"),
    );
    set_listen_pid();

    let reload_count = bump_reload_count();
    let pid = getpid();
    if reload_count == 0 {
        log::info!("echo server parent started (PID: {pid})");
    } else {
        log::info!("echo server child started (PID: {pid}; reload count: {reload_count})");
    }

    let mut ecdysis_builder = TokioEcdysisBuilder::new(SignalKind::hangup()).unwrap();
    ecdysis_builder
        .stop_on_signal(SignalKind::user_defined1())
        .unwrap();
    ecdysis_builder
        .partial_stop_on_signal(SignalKind::user_defined2())
        .unwrap();
    ecdysis_builder
        .upgrade_on_socket("/tmp/ecdysis_upgrade.sock")
        .unwrap();
    ecdysis_builder
        .stop_on_socket("/tmp/ecdysis_exit.sock")
        .unwrap();
    ecdysis_builder
        .partial_stop_on_socket("/tmp/ecdysis_partial_exit.sock")
        .unwrap();
    ecdysis_builder.set_pid_file("./pidfile");

    #[cfg(feature = "systemd_notify")]
    if let Err(err) = ecdysis_builder.enable_systemd_notifications() {
        log::info!("Failed to enable systemd notifications: {err}");
    }
    #[cfg(feature = "systemd_sockets")]
    if let Err(err) = ecdysis_builder.read_systemd_sockets() {
        log::error!("Failed to read systemd sockets: {err:?}");
    }

    let addr = listen_addr();
    log::info!("Address is: {:?}", addr);
    // Computed up front so the socket-setup closure does not depend on the exact
    // type of the address it is handed.
    let ipv6_only = addr.is_ipv6();
    let stream = ecdysis_builder
        .build_listen_tcp(StopOnShutdown::Yes, addr, |b, addr| {
            // Propagate setup failures through the closure's `Result` instead of
            // panicking inside it.
            if ipv6_only {
                b.set_only_v6(true)?;
            }
            b.set_reuse_address(true)?;
            b.bind(&addr.into())?;
            b.listen(128)?;
            Ok(b.into())
        })
        .expect("failed to set up the TCP echo listener");
    let server_handle = tokio::spawn(echo_server(stream));

    #[cfg(feature = "systemd_sockets")]
    let systemd_server_handle = {
        let sd_unix_stream = ecdysis_builder
            .systemd_listen_unix(
                StopOnShutdown::Yes,
                "ecdysis_test_unix".to_string(),
                "/tmp/ecdysis_int_test.sock".to_string(),
            )
            .await
            .unwrap();
        tokio::spawn(echo_server(sd_unix_stream))
    };

    // `_tokio_ecdysis` is a named binding (not `_`), so it stays alive until the end
    // of `main`.
    let (_tokio_ecdysis, ecdysis_fut) = ecdysis_builder.ready().unwrap();
    let exit = ecdysis_fut.await;
    log::info!("Shutdown because: {:?}", exit);
    let exit_start_time = Instant::now();

    // Watchdog: if the servers have not drained within the grace period, exit anyway.
    tokio::spawn(async {
        tokio::time::sleep(EXIT_AFTER_UPGRADE_DURATION).await;
        log::info!(
            "Force-exiting {:?} after upgrade",
            EXIT_AFTER_UPGRADE_DURATION
        );
        std::process::exit(1)
    });

    // NOTE(review): the purpose of this fixed 1s pause before joining the servers is
    // not evident from this file. Kept as-is.
    tokio::time::sleep(Duration::from_secs(1)).await;
    server_handle.await.unwrap();
    #[cfg(feature = "systemd_sockets")]
    systemd_server_handle.await.unwrap();
    log::info!(
        "Graceful exit {:?} after ecdysis stop (reason: {exit:?})",
        exit_start_time.elapsed()
    );
}

/// Advances the `ECDYSIS_RELOADS` counter and returns the new value.
///
/// If the variable is absent, it is initialised to 0. Otherwise it is incremented.
/// The new value is written back so that re-executed children inherit it.
///
/// # Panics
/// Panics if `ECDYSIS_RELOADS` is present but does not hold a `u32`.
///
/// NOTE: `env::set_var` is only sound while no other thread touches the environment.
/// This runs early in `main`, on a current-thread runtime.
fn bump_reload_count() -> u32 {
    let reload_count = match env::var("ECDYSIS_RELOADS") {
        Err(_) => 0,
        Ok(s) => {
            s.parse::<u32>()
                .expect("ECDYSIS_RELOADS must hold a u32 reload count")
                + 1
        }
    };
    env::set_var("ECDYSIS_RELOADS", reload_count.to_string());
    reload_count
}

/// Returns the address the TCP echo server listens on: `0.0.0.0:22222`.
///
/// NOTE(review): an earlier version tried `IpAddr::from_str("[::1]")` first.
/// `IpAddr` does not accept bracketed literals, so that parse always failed and the
/// server always fell back to `0.0.0.0`. The effective address is kept here. If an
/// IPv6-loopback listener was intended, use `"::1"` instead; `main` then also sets
/// `IPV6_V6ONLY`.
fn listen_addr() -> SocketAddr {
    let ip_addr = IpAddr::from_str("0.0.0.0").expect("literal is a valid IPv4 address");
    SocketAddr::new(ip_addr, 22222)
}
/// Points `LISTEN_PID` at the current process when systemd-style fd names are present.
///
/// Does nothing unless `LISTEN_FDNAMES` is set to a non-empty value. In that case,
/// `LISTEN_PID` is overwritten with this process's PID and the update is logged.
///
/// NOTE(review): presumably this lets inherited listening sockets be claimed by this
/// (possibly re-executed) process, whose PID differs from the one the variables were
/// written for. Confirm against ecdysis / sd_listen_fds docs.
fn set_listen_pid() {
    let fd_names_present =
        matches!(env::var("LISTEN_FDNAMES"), Ok(names) if !names.is_empty());
    if fd_names_present {
        env::set_var("LISTEN_PID", std::process::id().to_string());
        log::info!("LISTEN_PID updated");
    }
}