use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::signal;
use tokio::sync::watch;
/// Per-session cancellation state shared between the code running the current
/// operation and the background signal-handler task.
///
/// Each operation is tracked by its own `watch` channel whose value starts as
/// `false` and becomes `true` once the operation is cancelled.
#[derive(Debug)]
pub struct SessionCancellation {
    /// Sender half of the current operation's channel. It sits behind an
    /// `Arc<Mutex<..>>` so the signal-handler task can keep a handle to it while
    /// the owner swaps in a fresh sender for each new operation.
    current_op_tx: Arc<Mutex<watch::Sender<bool>>>,
    /// Receiver paired with the sender currently stored in `current_op_tx`.
    current_op_rx: watch::Receiver<bool>,
    /// Set once an interrupt has been observed. A further interrupt while this
    /// is set forces the process to exit.
    first_interrupt: Arc<AtomicBool>,
}
impl Default for SessionCancellation {
fn default() -> Self {
Self::new()
}
}
impl SessionCancellation {
    /// Creates a new session whose current operation is not cancelled and whose
    /// interrupt flag is cleared.
    pub fn new() -> Self {
        let (tx, rx) = watch::channel(false);
        Self {
            current_op_tx: Arc::new(Mutex::new(tx)),
            current_op_rx: rx,
            first_interrupt: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Returns a clone of the receiver for the current operation's channel.
    ///
    /// The receiver is tied to the channel that is current at call time. A later
    /// [`new_operation`](Self::new_operation) or [`reset`](Self::reset) installs
    /// a different channel, which this receiver will not observe.
    pub fn operation_receiver(&self) -> watch::Receiver<bool> {
        self.current_op_rx.clone()
    }

    /// Spawns a background task that reacts to termination signals.
    ///
    /// Behavior by platform:
    /// * Unix: listens for SIGINT and SIGTERM.
    /// * Windows: listens for Ctrl+C.
    ///
    /// Effects of each signal:
    /// * First interrupt: marks the current operation as cancelled.
    /// * Further interrupts (until [`reset`](Self::reset)): call
    ///   `crate::mcp::server::cleanup_servers()` and exit the process with
    ///   code 130.
    /// * SIGTERM: runs the same cleanup and exits with code 143 (128 + SIGTERM).
    ///
    /// Registration failures are logged. A failed SIGTERM registration does not
    /// disable SIGINT handling.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (a requirement of `tokio::spawn`).
    pub fn start_signal_handler(&self) -> tokio::task::JoinHandle<()> {
        let op_tx = Arc::clone(&self.current_op_tx);
        let first_interrupt = Arc::clone(&self.first_interrupt);
        tokio::spawn(async move {
            #[cfg(unix)]
            {
                use signal::unix::{signal, SignalKind};
                let mut sigint = match signal(SignalKind::interrupt()) {
                    Ok(sig) => sig,
                    Err(e) => {
                        crate::log_error!("Warning: Failed to register SIGINT handler: {}", e);
                        return;
                    }
                };
                // SIGTERM is optional: failing to register it must not also
                // throw away the working SIGINT handler.
                let mut sigterm = match signal(SignalKind::terminate()) {
                    Ok(sig) => Some(sig),
                    Err(e) => {
                        crate::log_error!("Warning: Failed to register SIGTERM handler: {}", e);
                        None
                    }
                };
                tokio::select! {
                    _ = async {
                        // `recv()` yields `None` once the stream can deliver no more
                        // signals (e.g. runtime shutdown). That is not a Ctrl+C, so
                        // stop listening instead of treating it as an interrupt.
                        while sigint.recv().await.is_some() {
                            if !handle_interrupt(&first_interrupt, &op_tx) {
                                break;
                            }
                        }
                    } => {},
                    received = async {
                        match sigterm.as_mut() {
                            Some(sig) => sig.recv().await,
                            // No SIGTERM handler: this branch must never win.
                            None => std::future::pending::<Option<()>>().await,
                        }
                    } => {
                        // `None` means the SIGTERM stream closed, not that a
                        // signal arrived; only a real signal terminates the process.
                        if received.is_some() {
                            println!("\n🛑 Termination signal received - exiting...");
                            std::io::Write::flush(&mut std::io::stdout()).unwrap_or(());
                            let _ = crate::mcp::server::cleanup_servers();
                            // 128 + SIGTERM(15), the conventional status for SIGTERM.
                            std::process::exit(143);
                        }
                    }
                }
            }
            #[cfg(windows)]
            {
                loop {
                    match signal::ctrl_c().await {
                        Ok(()) => {
                            if !handle_interrupt(&first_interrupt, &op_tx) {
                                break;
                            }
                        }
                        Err(e) => {
                            crate::log_error!("Warning: Failed to listen for Ctrl+C: {}", e);
                            break;
                        }
                    }
                }
            }
        })
    }

    /// Replaces the current operation channel with a fresh, un-cancelled one.
    ///
    /// The new sender is stored in the shared slot (so the signal handler
    /// cancels the new operation), and a receiver for it is both kept and
    /// returned.
    fn install_fresh_channel(&mut self) -> watch::Receiver<bool> {
        let (tx, rx) = watch::channel(false);
        *self.current_op_tx.lock().unwrap() = tx;
        self.current_op_rx = rx.clone();
        rx
    }

    /// Starts tracking a new operation and returns a receiver for it.
    ///
    /// Receivers handed out for the previous operation see their sender dropped.
    /// The interrupt flag is *not* cleared; use [`reset`](Self::reset) for that.
    pub fn new_operation(&mut self) -> watch::Receiver<bool> {
        self.install_fresh_channel()
    }

    /// Returns `true` if the current operation has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        *self.current_op_rx.borrow()
    }

    /// Waits until the current operation is cancelled.
    ///
    /// Also returns if the channel's sender has been dropped.
    pub async fn cancelled(&mut self) {
        while !*self.current_op_rx.borrow() {
            if self.current_op_rx.changed().await.is_err() {
                break;
            }
        }
    }

    /// Clears the interrupt flag and installs a fresh, un-cancelled operation
    /// channel.
    pub fn reset(&mut self) {
        self.first_interrupt.store(false, Ordering::SeqCst);
        self.install_fresh_channel();
    }

    /// Marks the current operation as cancelled. A failed send is ignored.
    pub fn shutdown(&self) {
        let _ = self.current_op_tx.lock().unwrap().send(true);
    }
}
/// Reacts to a single interrupt signal.
///
/// On a repeat interrupt (`first_interrupt` already set) it does the following
/// and never returns:
/// 1. Prints a notice and flushes stdout.
/// 2. Calls `crate::mcp::server::cleanup_servers()`, ignoring its result.
/// 3. Exits the process with status 130.
///
/// Otherwise it sets `first_interrupt`, sends `true` on the current operation's
/// channel (a failed send is ignored), logs a hint at debug level and returns.
///
/// Whenever it returns, the result is `true`; the signal loops keep listening
/// while the result is `true`.
fn handle_interrupt(
    first_interrupt: &Arc<AtomicBool>,
    op_tx: &Arc<Mutex<watch::Sender<bool>>>,
) -> bool {
    use std::io::Write as _;

    let already_interrupted = first_interrupt.load(Ordering::SeqCst);
    if already_interrupted {
        // A second interrupt before `reset` means the user wants out now.
        std::println!("\n\u{1f6d1} Forcing exit...");
        std::io::stdout().flush().unwrap_or(());
        let _ = crate::mcp::server::cleanup_servers();
        std::process::exit(130);
    }

    first_interrupt.store(true, Ordering::SeqCst);
    {
        // Keep the lock only for the send itself.
        let sender = op_tx.lock().unwrap();
        let _ = sender.send(true);
    }
    crate::log_debug!("Ctrl+C: Interrupting current operation...");
    crate::log_debug!("Press Ctrl+C again to force exit");
    std::io::stdout().flush().unwrap_or(());
    true
}