use std::io::Write;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use color_eyre::eyre::Result;
use crossterm::event;
use tokio::net::UnixStream;
use tokio::sync::mpsc;
use tokio::time::interval;
use super::protocol::{self, MainMessage, ShimMessage};
use super::{list_sessions, socket_path};
/// Why the shim's relay loop stopped; decides the message printed on exit.
#[derive(Debug)]
enum RelayExit {
    /// The session sent `MainMessage::Quit`.
    Quit,
    /// The session sent `MainMessage::Detached`.
    Detached,
    /// Forwarding a `ShimMessage` to the session socket failed; carries the
    /// error text.
    WriteError(String),
    /// The local input channel closed (every sender was dropped).
    InputClosed,
    /// The downstream channel closed, i.e. the socket reader task stopped.
    ConnectionLost,
}
/// Attaches this terminal to a running repartee session and relays I/O until
/// the session quits, detaches us, or the connection goes away.
///
/// `target_pid` selects the session; without it, the single available session
/// is used. Session resolution prints a message and terminates the process
/// when no suitable session exists. When `show_splash` is set, the splash
/// animation runs before connecting.
///
/// After connecting, the terminal environment is sent to the session, the
/// local terminal enters raw mode, and input, resize and output messages are
/// relayed. On exit the terminal is restored (raw mode off, main screen, mouse
/// capture and bracketed paste off, cursor shown). Then the reason is printed
/// to stderr.
///
/// # Errors
/// Returns an error if the splash fails, the socket cannot be connected, raw
/// mode cannot be enabled, or the initial terminal-environment message cannot
/// be written. In the last case raw mode is switched off again before
/// returning.
pub async fn run_shim(target_pid: Option<u32>, show_splash: bool) -> Result<()> {
    let (pid, sock_path) = resolve_session(target_pid);
    if show_splash {
        run_splash(Some(&sock_path)).await?;
    }
    let stream = UnixStream::connect(&sock_path)
        .await
        .map_err(|e| color_eyre::eyre::eyre!("Failed to connect to session PID {pid}: {e}"))?;
    let (read_half, mut write_half) = tokio::io::split(stream);
    let read_half = tokio::io::BufReader::new(read_half);
    let term_env = protocol::TerminalEnv::capture();
    crossterm::terminal::enable_raw_mode()?;
    // The handshake is the only fallible step between entering raw mode and
    // the cleanup below; leave raw mode on failure so the user's shell is not
    // left in an unusable state.
    if let Err(e) = protocol::write_message(&mut write_half, &term_env).await {
        let _ = crossterm::terminal::disable_raw_mode();
        return Err(e.into());
    }

    let (input_tx, mut input_rx) = mpsc::channel::<ShimMessage>(1024);
    let input_stop = Arc::new(AtomicBool::new(false));
    spawn_input_reader(input_tx.clone(), Arc::clone(&input_stop));
    spawn_sigwinch_handler(input_tx);

    // Pump session messages from the socket into a channel so the relay loop
    // can `select!` over them alongside local input.
    let (downstream_tx, mut downstream_rx) = mpsc::channel::<MainMessage>(1024);
    tokio::spawn(async move {
        let mut reader = read_half;
        loop {
            match protocol::read_message::<_, MainMessage>(&mut reader).await {
                Ok(msg) => {
                    if downstream_tx.send(msg).await.is_err() {
                        break;
                    }
                }
                Err(e) => {
                    tracing::debug!("shim downstream read error: {e}");
                    break;
                }
            }
        }
    });

    tracing::info!("shim relay loop starting");
    let exit_reason = run_relay_loop(&mut input_rx, &mut write_half, &mut downstream_rx).await;
    tracing::info!("shim relay loop exited");

    // Ask the input thread to stop, then hand the terminal back to the user.
    input_stop.store(true, Ordering::Relaxed);
    let _ = crossterm::terminal::disable_raw_mode();
    let mut stdout = std::io::stdout();
    let _ = crossterm::execute!(
        stdout,
        crossterm::terminal::LeaveAlternateScreen,
        crossterm::event::DisableMouseCapture,
        crossterm::event::DisableBracketedPaste,
        crossterm::cursor::Show
    );

    match exit_reason {
        RelayExit::Quit => eprintln!("Session ended."),
        RelayExit::Detached => eprintln!("Detached from repartee (PID {pid})."),
        RelayExit::WriteError(e) => {
            eprintln!("Disconnected from repartee (PID {pid}): write error: {e}");
        }
        RelayExit::InputClosed => eprintln!("Disconnected from repartee (PID {pid}): input closed"),
        RelayExit::ConnectionLost => {
            eprintln!("Disconnected from repartee (PID {pid}): connection lost");
        }
    }
    Ok(())
}
/// Chooses the session to attach to and returns its PID and socket path.
///
/// With `Some(pid)`, that PID's socket path must already exist. With `None`,
/// `list_sessions()` must report exactly one session. In every other case a
/// message is printed to stderr and the process exits with status 1, so this
/// function only returns on success.
fn resolve_session(target_pid: Option<u32>) -> (u32, std::path::PathBuf) {
    match target_pid {
        Some(pid) => {
            let path = socket_path(pid);
            if !path.exists() {
                eprintln!("No session found for PID {pid}");
                std::process::exit(1);
            }
            (pid, path)
        }
        None => {
            let sessions = list_sessions();
            if sessions.is_empty() {
                eprintln!("No detached sessions found.");
                std::process::exit(1);
            }
            if sessions.len() > 1 {
                eprintln!("Multiple sessions found. Specify a PID:");
                for (pid, _) in &sessions {
                    eprintln!(" repartee a {pid}");
                }
                std::process::exit(1);
            }
            sessions
                .into_iter()
                .next()
                .expect("exactly one session remains after the checks above")
        }
    }
}
pub async fn run_splash(sock_path: Option<&std::path::Path>) -> Result<()> {
use crossterm::event::{EnableBracketedPaste, EnableMouseCapture};
use crossterm::terminal::{EnterAlternateScreen, enable_raw_mode};
use ratatui::prelude::*;
const LINE_DELAY_MS: u64 = 50;
const HOLD_MS: u64 = 2500;
enable_raw_mode()?;
let mut stdout = std::io::stdout();
crossterm::execute!(
stdout,
EnterAlternateScreen,
EnableMouseCapture,
EnableBracketedPaste
)?;
let backend = CrosstermBackend::new(stdout);
let mut terminal = Terminal::new(backend)?;
let total_lines = include_str!("../../logo.txt").lines().count();
let mut visible = 0;
let mut line_tick = interval(Duration::from_millis(LINE_DELAY_MS));
let mut dismissed = false;
while visible < total_lines && !dismissed {
terminal.draw(|frame| crate::ui::splash::render(frame, visible))?;
tokio::select! {
_ = line_tick.tick() => {
visible += 1;
}
ev = tokio::task::spawn_blocking(|| {
if event::poll(std::time::Duration::from_millis(1)).unwrap_or(false) {
event::read().ok()
} else {
None
}
}) => {
if let Ok(Some(crossterm::event::Event::Key(_))) = ev {
dismissed = true;
}
}
}
}
if !dismissed {
terminal.draw(|frame| crate::ui::splash::render(frame, total_lines))?;
let hold_start = Instant::now();
while hold_start.elapsed() < Duration::from_millis(HOLD_MS) && !dismissed {
let remaining = Duration::from_millis(HOLD_MS).saturating_sub(hold_start.elapsed());
if remaining.is_zero() {
break;
}
if sock_path.is_some_and(std::path::Path::exists)
&& hold_start.elapsed() >= Duration::from_millis(500)
{
break;
}
if let Ok(Some(crossterm::event::Event::Key(_))) =
tokio::task::spawn_blocking(move || {
if event::poll(remaining.min(Duration::from_millis(100))).unwrap_or(false) {
event::read().ok()
} else {
None
}
})
.await
{
dismissed = true;
}
}
}
let _ = crossterm::terminal::disable_raw_mode();
let mut stdout = std::io::stdout();
let _ = crossterm::execute!(
stdout,
crossterm::terminal::LeaveAlternateScreen,
crossterm::event::DisableMouseCapture,
crossterm::event::DisableBracketedPaste,
crossterm::cursor::Show
);
Ok(())
}
fn spawn_input_reader(tx: mpsc::Sender<ShimMessage>, stop: Arc<AtomicBool>) {
std::thread::spawn(move || {
while !stop.load(Ordering::Relaxed) {
if event::poll(std::time::Duration::from_millis(50)).unwrap_or(false) {
match event::read() {
Ok(ev) => {
let msg = if is_detach_key(&ev) {
ShimMessage::Detach
} else {
ShimMessage::TermEvent(ev)
};
if tx.blocking_send(msg).is_err() {
break;
}
}
Err(_) => break,
}
}
}
});
}
/// Returns `true` for the detach chord: a key event whose code is `\` or `z`
/// with the Control modifier held (other modifiers may also be present).
const fn is_detach_key(ev: &crossterm::event::Event) -> bool {
    use crossterm::event::{Event, KeyCode, KeyModifiers};
    match ev {
        Event::Key(key) => {
            let is_detach_char = matches!(key.code, KeyCode::Char('\\' | 'z'));
            is_detach_char && key.modifiers.contains(KeyModifiers::CONTROL)
        }
        _ => false,
    }
}
/// Spawns a task that sends `ShimMessage::Resize` on `tx` after every
/// SIGWINCH (terminal window change).
///
/// The new size comes from `crossterm::terminal::size()`, falling back to
/// 80x24 if it cannot be read. The task gives up silently if the signal
/// handler cannot be installed, and stops once `tx`'s receiver is dropped.
fn spawn_sigwinch_handler(tx: mpsc::Sender<ShimMessage>) {
    use tokio::signal::unix::{SignalKind, signal};
    tokio::spawn(async move {
        let mut window_changes = match signal(SignalKind::window_change()) {
            Ok(stream) => stream,
            Err(_) => return,
        };
        while let Some(()) = window_changes.recv().await {
            let (cols, rows) = crossterm::terminal::size().unwrap_or((80, 24));
            let resize = ShimMessage::Resize { cols, rows };
            if tx.send(resize).await.is_err() {
                break;
            }
        }
    });
}
/// Shuttles messages between the local terminal and the session until one
/// side ends the conversation, and reports why.
///
/// Messages from `input_rx` are serialized onto `write_half`. A write failure
/// yields `RelayExit::WriteError`; a closed input channel yields
/// `RelayExit::InputClosed`. From `downstream_rx`:
/// - `Output` bytes are written to stdout and flushed, best-effort (I/O
///   errors are ignored).
/// - `Detached` ends the loop with `RelayExit::Detached`.
/// - `Quit` ends it with `RelayExit::Quit`.
/// - A closed channel ends it with `RelayExit::ConnectionLost`.
async fn run_relay_loop<W>(
    input_rx: &mut mpsc::Receiver<ShimMessage>,
    write_half: &mut W,
    downstream_rx: &mut mpsc::Receiver<MainMessage>,
) -> RelayExit
where
    W: tokio::io::AsyncWriteExt + Unpin + Send,
{
    loop {
        tokio::select! {
            upstream = input_rx.recv() => {
                let Some(outgoing) = upstream else {
                    return RelayExit::InputClosed;
                };
                if let Err(err) = protocol::write_message(write_half, &outgoing).await {
                    return RelayExit::WriteError(err.to_string());
                }
            }
            downstream = downstream_rx.recv() => {
                match downstream {
                    Some(MainMessage::Output(bytes)) => {
                        let mut out = std::io::stdout().lock();
                        // Best-effort: a failed terminal write must not end
                        // the session, so both results are ignored.
                        let _ = out.write_all(&bytes);
                        let _ = out.flush();
                    }
                    Some(MainMessage::Detached) => return RelayExit::Detached,
                    Some(MainMessage::Quit) => return RelayExit::Quit,
                    None => return RelayExit::ConnectionLost,
                }
            }
        }
    }
}