use std::io::{self, Read, Write};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
#[cfg(feature = "async")]
use crate::command::CommandKind;
#[cfg(feature = "async")]
use tokio_util::sync::CancellationToken;
#[cfg(feature = "async")]
use tokio_util::task::TaskTracker;
use tracing::debug;
/// Runs `f` in the background for one command of a batch.
///
/// With the `thread-pool` feature the closure is queued on rayon's global pool.
#[cfg(feature = "thread-pool")]
fn spawn_batch(f: impl FnOnce() + Send + 'static) {
    rayon::spawn(f)
}
/// Runs `f` in the background for one command of a batch.
///
/// Without the `thread-pool` feature every call gets its own OS thread. The join handle is
/// discarded immediately, which detaches the thread.
#[cfg(not(feature = "thread-pool"))]
fn spawn_batch(f: impl FnOnce() + Send + 'static) {
    drop(thread::spawn(f));
}
use crossterm::{
cursor::{Hide, MoveTo, Show},
event::{self, DisableMouseCapture, EnableMouseCapture, Event, KeyEventKind},
execute,
terminal::{
self, Clear, ClearType, EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode,
enable_raw_mode,
},
};
use crate::command::Cmd;
use crate::key::{from_crossterm_key, is_sequence_prefix};
use crate::message::{
BatchMsg, BlurMsg, FocusMsg, InterruptMsg, Message, PrintLineMsg, QuitMsg,
RequestWindowSizeMsg, SequenceMsg, SetWindowTitleMsg, WindowSizeMsg,
};
use crate::mouse::from_crossterm_mouse;
use crate::screen::{ReleaseTerminalMsg, RestoreTerminalMsg};
use crate::{KeyMsg, KeyType};
/// Errors returned while setting up, running, or tearing down a [`Program`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// An I/O error from the terminal or the output writer.
///
/// `?` on an `io::Error` (for example `enable_raw_mode()?` or `execute!(..)?`) lands here
/// through the `#[from]` conversion.
#[error("terminal io error: {0}")]
Io(#[from] io::Error),
/// Switching raw mode failed; `action` is interpolated into the message.
///
/// NOTE(review): no code in this file's visible part constructs this variant; raw-mode
/// failures currently surface as `Io`. Confirm whether it is used elsewhere.
#[error("failed to {action} raw mode: {source}")]
RawModeFailure {
/// Verb describing the attempted operation, interpolated into the message.
action: &'static str,
/// The underlying I/O error.
#[source]
source: io::Error,
},
/// Entering or leaving the alternate screen failed; `action` is interpolated into the message.
///
/// NOTE(review): not constructed by any code visible here; alternate-screen failures
/// currently surface as `Io`.
#[error("failed to {action} alternate screen: {source}")]
AltScreenFailure {
/// Verb describing the attempted operation, interpolated into the message.
action: &'static str,
/// The underlying I/O error.
#[source]
source: io::Error,
},
/// Polling for terminal events failed.
///
/// NOTE(review): the event loops visible here propagate poll errors as `Io` instead.
#[error("failed to poll terminal events: {0}")]
EventPoll(io::Error),
/// Writing the rendered view failed.
///
/// NOTE(review): `render` visible here propagates write errors as `Io` instead.
#[error("failed to render view: {0}")]
Render(io::Error),
}
/// Convenience alias used throughout this module.
pub type Result<T> = std::result::Result<T, Error>;
/// Application state driven by a [`Program`].
///
/// The program calls `init` once before the first render. It calls `update` for every message
/// it does not handle itself (quit, interrupt, window title, print-line, terminal
/// release/restore, and similar runtime messages are intercepted first). It calls `view`
/// whenever it renders. The `Send + 'static` bound lets the program own the model on a
/// background thread (see `Program::start`).
pub trait Model: Send + 'static {
/// Returns an optional command to run at startup.
fn init(&self) -> Option<Cmd>;
/// Applies `msg` to the model and optionally returns a follow-up command.
fn update(&mut self, msg: Message) -> Option<Cmd>;
/// Renders the model as text; the program skips redrawing when this string equals the last
/// frame it drew.
fn view(&self) -> String;
}
/// Runtime configuration for a [`Program`], normally set through its builder methods.
#[derive(Debug, Clone)]
pub struct ProgramOptions {
/// Draw in the terminal's alternate screen. While set, `PrintLineMsg` output is suppressed.
pub alt_screen: bool,
/// Request mouse reporting. Currently issues the same crossterm `EnableMouseCapture` as
/// `mouse_all_motion`.
pub mouse_cell_motion: bool,
/// Request mouse reporting via crossterm `EnableMouseCapture`.
pub mouse_all_motion: bool,
/// Enable bracketed paste, so a paste arrives as a single `KeyMsg` with `paste: true`.
/// On by default.
pub bracketed_paste: bool,
/// Report terminal focus changes as `FocusMsg` / `BlurMsg`.
pub report_focus: bool,
/// Leave raw mode and crossterm event polling untouched. Messages then come only from the
/// custom input and the external receiver. Set automatically by `with_input` and
/// `with_output`.
pub custom_io: bool,
/// Target frames per second. It determines the poll timeout (or the sleep, with
/// `custom_io`) of each loop iteration.
///
/// `with_fps` clamps to 1..=120, but a value set directly is used as-is.
/// NOTE(review): `fps == 0` makes the frame duration infinite, and
/// `Duration::from_secs_f64` panics on that.
pub fps: u32,
/// NOTE(review): not read by any code visible in this file; presumably meant to disable
/// signal handling. Confirm.
pub without_signals: bool,
/// Skip installing the terminal-restoring panic hook while the program runs.
pub without_catch_panics: bool,
}
impl Default for ProgramOptions {
// Only bracketed paste is enabled by default, at 60 frames per second.
fn default() -> Self {
Self {
alt_screen: false,
mouse_cell_motion: false,
mouse_all_motion: false,
bracketed_paste: true,
report_focus: false,
custom_io: false,
fps: 60,
without_signals: false,
without_catch_panics: false,
}
}
}
pub struct ProgramHandle<M: Model> {
tx: Sender<Message>,
handle: Option<thread::JoinHandle<Result<M>>>,
}
impl<M: Model> ProgramHandle<M> {
pub fn send<T: Into<Message>>(&self, msg: T) -> bool {
self.tx.send(msg.into()).is_ok()
}
pub fn quit(&self) {
let _ = self.tx.send(Message::new(QuitMsg));
}
pub fn wait(mut self) -> Result<M> {
if let Some(handle) = self.handle.take() {
handle
.join()
.map_err(|_| Error::Io(io::Error::other("program thread panicked")))?
} else {
Err(Error::Io(io::Error::other("program already joined")))
}
}
pub fn is_running(&self) -> bool {
self.handle.as_ref().is_some_and(|h| !h.is_finished())
}
}
/// A terminal program built around a [`Model`].
///
/// Configure it with the builder methods, then `run` it on the current thread or `start` it on
/// a background thread.
pub struct Program<M: Model> {
/// The application state; handed back to the caller when the program ends.
model: M,
/// Terminal and runtime configuration.
options: ProgramOptions,
/// Optional extra message source, forwarded into the event loop by a helper thread or task.
external_rx: Option<Receiver<Message>>,
/// Custom input stream; when set, a parser thread or task turns its bytes into messages.
input: Option<Box<dyn Read + Send>>,
/// Custom output, used by `run` and `start` instead of stdout.
output: Option<Box<dyn Write + Send>>,
}
impl<M: Model> Program<M> {
pub fn new(model: M) -> Self {
Self {
model,
options: ProgramOptions::default(),
external_rx: None,
input: None,
output: None,
}
}
pub fn with_input_receiver(mut self, rx: Receiver<Message>) -> Self {
self.external_rx = Some(rx);
self
}
pub fn with_input<R: Read + Send + 'static>(mut self, input: R) -> Self {
self.input = Some(Box::new(input));
self.options.custom_io = true;
self
}
pub fn with_output<W: Write + Send + 'static>(mut self, output: W) -> Self {
self.output = Some(Box::new(output));
self.options.custom_io = true;
self
}
pub fn with_alt_screen(mut self) -> Self {
self.options.alt_screen = true;
self
}
pub fn with_mouse_cell_motion(mut self) -> Self {
self.options.mouse_cell_motion = true;
self
}
pub fn with_mouse_all_motion(mut self) -> Self {
self.options.mouse_all_motion = true;
self
}
pub fn with_fps(mut self, fps: u32) -> Self {
self.options.fps = fps.clamp(1, 120);
self
}
pub fn with_report_focus(mut self) -> Self {
self.options.report_focus = true;
self
}
pub fn without_bracketed_paste(mut self) -> Self {
self.options.bracketed_paste = false;
self
}
pub fn without_signal_handler(mut self) -> Self {
self.options.without_signals = true;
self
}
pub fn without_catch_panics(mut self) -> Self {
self.options.without_catch_panics = true;
self
}
pub fn with_custom_io(mut self) -> Self {
self.options.custom_io = true;
self
}
pub fn run_with_writer<W: Write + Send + 'static>(self, mut writer: W) -> Result<M> {
let options = self.options.clone();
if !options.custom_io {
enable_raw_mode()?;
}
if options.alt_screen {
execute!(writer, EnterAlternateScreen)?;
}
execute!(writer, Hide)?;
if options.mouse_all_motion {
execute!(writer, EnableMouseCapture)?;
} else if options.mouse_cell_motion {
execute!(writer, EnableMouseCapture)?;
}
if options.report_focus {
execute!(writer, event::EnableFocusChange)?;
}
if options.bracketed_paste {
execute!(writer, event::EnableBracketedPaste)?;
}
let prev_hook = if !options.without_catch_panics {
let cleanup_opts = options.clone();
let prev = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
let mut stderr = io::stderr();
if cleanup_opts.bracketed_paste {
let _ = execute!(stderr, event::DisableBracketedPaste);
}
if cleanup_opts.report_focus {
let _ = execute!(stderr, event::DisableFocusChange);
}
if cleanup_opts.mouse_all_motion || cleanup_opts.mouse_cell_motion {
let _ = execute!(stderr, DisableMouseCapture);
}
let _ = execute!(stderr, Show);
if cleanup_opts.alt_screen {
let _ = execute!(stderr, LeaveAlternateScreen);
}
if !cleanup_opts.custom_io {
let _ = disable_raw_mode();
}
prev(info);
}));
true
} else {
false
};
let result = self.event_loop(&mut writer);
if prev_hook {
let _ = std::panic::take_hook();
}
if options.bracketed_paste {
let _ = execute!(writer, event::DisableBracketedPaste);
}
if options.report_focus {
let _ = execute!(writer, event::DisableFocusChange);
}
if options.mouse_all_motion || options.mouse_cell_motion {
let _ = execute!(writer, DisableMouseCapture);
}
let _ = execute!(writer, Show);
if options.alt_screen {
let _ = execute!(writer, LeaveAlternateScreen);
}
if !options.custom_io {
let _ = disable_raw_mode();
}
result
}
/// Runs the program on the current thread.
///
/// It renders to the output set with `with_output`, or to stdout when none was set.
///
/// # Errors
///
/// Propagates any error from [`Program::run_with_writer`].
pub fn run(mut self) -> Result<M> {
    match self.output.take() {
        Some(custom) => self.run_with_writer(custom),
        None => self.run_with_writer(io::stdout()),
    }
}
/// Starts the program on a new thread and returns a [`ProgramHandle`] for it.
///
/// The handle's sender becomes the program's external message source. This replaces any
/// receiver installed earlier with `with_input_receiver`. Output goes to the writer set with
/// `with_output`, or to stdout.
pub fn start(mut self) -> ProgramHandle<M> {
    let (sender, receiver) = mpsc::channel();
    self.external_rx = Some(receiver);
    let custom_output = self.output.take();
    let join_handle = thread::spawn(move || match custom_output {
        Some(out) => self.run_with_writer(out),
        None => self.run_with_writer(io::stdout()),
    });
    ProgramHandle {
        tx: sender,
        handle: Some(join_handle),
    }
}
/// Blocking event loop used by [`Program::run_with_writer`]; returns the final model.
///
/// Startup:
/// * spawns helper threads: a forwarder for the external receiver and a parser for custom
///   input;
/// * queues the initial window size, unless `custom_io`;
/// * runs the model's `init` command and renders once.
///
/// Each iteration polls crossterm for up to one frame (unless `custom_io`) and turns terminal
/// events into messages. It then drains every queued message: runtime messages are handled
/// here and the rest go to `Model::update`. The view is re-rendered when something changed.
/// `QuitMsg` or `InterruptMsg` ends the loop.
///
/// Shutdown:
/// * the external forwarder is joined promptly, because it polls a shutdown flag;
/// * the input parser gets only a short grace period, because it may be stuck in a blocking
///   `read` that never returns, and is detached after that instead of hanging the exit;
/// * command threads get up to five seconds in total.
///
/// # Errors
///
/// Returns an error if polling/reading terminal events, setting the window title, or rendering
/// fails. The external forwarder is told to stop on those early returns as well.
fn event_loop<W: Write>(mut self, writer: &mut W) -> Result<M> {
    /// Raises the shared shutdown flag when dropped, i.e. on every exit from this function.
    struct ShutdownOnDrop(Arc<AtomicBool>);
    impl Drop for ShutdownOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::Relaxed);
        }
    }

    let (tx, rx): (Sender<Message>, Receiver<Message>) = mpsc::channel();
    let mut external_forwarder_handle: Option<thread::JoinHandle<()>> = None;
    let mut input_parser_handle: Option<thread::JoinHandle<()>> = None;
    let external_shutdown = Arc::new(AtomicBool::new(false));
    let _shutdown_on_exit = ShutdownOnDrop(Arc::clone(&external_shutdown));
    let command_threads: Arc<Mutex<Vec<thread::JoinHandle<()>>>> =
        Arc::new(Mutex::new(Vec::new()));

    if let Some(ext_rx) = self.external_rx.take() {
        let tx_clone = tx.clone();
        let shutdown_clone = Arc::clone(&external_shutdown);
        debug!(target: "bubbletea::thread", "Spawning external forwarder thread");
        external_forwarder_handle = Some(thread::spawn(move || {
            // Poll with a timeout so the shutdown flag is noticed even when nothing arrives.
            const POLL_INTERVAL: Duration = Duration::from_millis(50);
            loop {
                if shutdown_clone.load(Ordering::Relaxed) {
                    break;
                }
                match ext_rx.recv_timeout(POLL_INTERVAL) {
                    Ok(msg) => {
                        if tx_clone.send(msg).is_err() {
                            debug!(target: "bubbletea::event", "external message dropped — receiver disconnected");
                            break;
                        }
                    }
                    Err(RecvTimeoutError::Timeout) => {}
                    Err(RecvTimeoutError::Disconnected) => break,
                }
            }
        }));
    }

    if let Some(mut input) = self.input.take() {
        let tx_clone = tx.clone();
        debug!(target: "bubbletea::thread", "Spawning input parser thread");
        input_parser_handle = Some(thread::spawn(move || {
            let mut parser = InputParser::new();
            let mut buf = [0u8; 256];
            loop {
                match input.read(&mut buf) {
                    Ok(0) => break,
                    Ok(n) => {
                        let can_have_more_data = true;
                        for msg in parser.push_bytes(&buf[..n], can_have_more_data) {
                            if tx_clone.send(msg).is_err() {
                                debug!(target: "bubbletea::input", "input message dropped — receiver disconnected");
                                return;
                            }
                        }
                    }
                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                        thread::yield_now();
                    }
                    Err(_) => break,
                }
            }
            // EOF or read error: emit whatever is still buffered (e.g. a lone ESC).
            for msg in parser.flush() {
                if tx_clone.send(msg).is_err() {
                    debug!(target: "bubbletea::input", "flush message dropped — receiver disconnected");
                    break;
                }
            }
        }));
    }

    if !self.options.custom_io
        && let Ok((width, height)) = terminal::size()
        && tx
            .send(Message::new(WindowSizeMsg { width, height }))
            .is_err()
    {
        debug!(target: "bubbletea::event", "initial window size dropped — receiver disconnected");
    }
    if let Some(cmd) = self.model.init() {
        self.handle_command(cmd, tx.clone(), Arc::clone(&command_threads));
    }
    let mut last_view = String::new();
    self.render(writer, &mut last_view)?;
    let frame_duration = Duration::from_secs_f64(1.0 / self.options.fps as f64);

    loop {
        // Terminal events are queued on the same channel as everything else.
        if !self.options.custom_io && event::poll(frame_duration)? {
            match event::read()? {
                Event::Key(key_event) => {
                    if key_event.kind != KeyEventKind::Press {
                        continue;
                    }
                    let key_msg = from_crossterm_key(key_event.code, key_event.modifiers);
                    if key_msg.key_type == crate::KeyType::CtrlC {
                        if tx.send(Message::new(InterruptMsg)).is_err() {
                            debug!(target: "bubbletea::event", "interrupt message dropped — receiver disconnected");
                        }
                    } else if tx.send(Message::new(key_msg)).is_err() {
                        debug!(target: "bubbletea::event", "key message dropped — receiver disconnected");
                    }
                }
                Event::Mouse(mouse_event) => {
                    let mouse_msg = from_crossterm_mouse(mouse_event);
                    if tx.send(Message::new(mouse_msg)).is_err() {
                        debug!(target: "bubbletea::event", "mouse message dropped — receiver disconnected");
                    }
                }
                Event::Resize(width, height) => {
                    if tx
                        .send(Message::new(WindowSizeMsg { width, height }))
                        .is_err()
                    {
                        debug!(target: "bubbletea::event", "resize message dropped — receiver disconnected");
                    }
                }
                Event::FocusGained => {
                    if tx.send(Message::new(FocusMsg)).is_err() {
                        debug!(target: "bubbletea::event", "focus message dropped — receiver disconnected");
                    }
                }
                Event::FocusLost => {
                    if tx.send(Message::new(BlurMsg)).is_err() {
                        debug!(target: "bubbletea::event", "blur message dropped — receiver disconnected");
                    }
                }
                Event::Paste(text) => {
                    let key_msg = KeyMsg {
                        key_type: crate::KeyType::Runes,
                        runes: text.chars().collect(),
                        alt: false,
                        paste: true,
                    };
                    if tx.send(Message::new(key_msg)).is_err() {
                        debug!(target: "bubbletea::event", "paste message dropped — receiver disconnected");
                    }
                }
            }
        }

        let mut needs_render = false;
        let mut should_quit = false;
        while let Ok(msg) = rx.try_recv() {
            if msg.is::<QuitMsg>() {
                should_quit = true;
                break;
            }
            if msg.is::<InterruptMsg>() {
                should_quit = true;
                break;
            }
            // Batches are expanded by `handle_command`; a raw BatchMsg here carries nothing.
            if msg.is::<BatchMsg>() {
                continue;
            }
            if let Some(title_msg) = msg.downcast_ref::<SetWindowTitleMsg>() {
                execute!(writer, terminal::SetTitle(&title_msg.0))?;
                continue;
            }
            if msg.is::<RequestWindowSizeMsg>() {
                if !self.options.custom_io
                    && let Ok((width, height)) = terminal::size()
                    && tx
                        .send(Message::new(WindowSizeMsg { width, height }))
                        .is_err()
                {
                    debug!(target: "bubbletea::event", "window size response dropped — receiver disconnected");
                }
                continue;
            }
            if let Some(print_msg) = msg.downcast_ref::<PrintLineMsg>() {
                if !self.options.alt_screen {
                    for line in print_msg.0.lines() {
                        let _ = writeln!(writer, "{}", line);
                    }
                    let _ = writer.flush();
                    // Force a full redraw below the printed lines.
                    last_view.clear();
                    needs_render = true;
                }
                continue;
            }
            if msg.is::<ReleaseTerminalMsg>() {
                if !self.options.custom_io {
                    if self.options.bracketed_paste {
                        let _ = execute!(writer, event::DisableBracketedPaste);
                    }
                    if self.options.report_focus {
                        let _ = execute!(writer, event::DisableFocusChange);
                    }
                    if self.options.mouse_all_motion || self.options.mouse_cell_motion {
                        let _ = execute!(writer, DisableMouseCapture);
                    }
                    let _ = execute!(writer, Show);
                    if self.options.alt_screen {
                        let _ = execute!(writer, LeaveAlternateScreen);
                    }
                    let _ = disable_raw_mode();
                }
                continue;
            }
            if msg.is::<RestoreTerminalMsg>() {
                if !self.options.custom_io {
                    let _ = enable_raw_mode();
                    if self.options.alt_screen {
                        let _ = execute!(writer, EnterAlternateScreen);
                    }
                    let _ = execute!(writer, Hide);
                    if self.options.mouse_all_motion {
                        let _ = execute!(writer, EnableMouseCapture);
                    } else if self.options.mouse_cell_motion {
                        let _ = execute!(writer, EnableMouseCapture);
                    }
                    if self.options.report_focus {
                        let _ = execute!(writer, event::EnableFocusChange);
                    }
                    if self.options.bracketed_paste {
                        let _ = execute!(writer, event::EnableBracketedPaste);
                    }
                    last_view.clear();
                }
                needs_render = true;
                continue;
            }
            if let Some(cmd) = self.model.update(msg) {
                self.handle_command(cmd, tx.clone(), Arc::clone(&command_threads));
            }
            needs_render = true;
        }
        if should_quit {
            break;
        }
        if needs_render {
            self.render(writer, &mut last_view)?;
        }
        if self.options.custom_io {
            thread::sleep(frame_duration);
        }
    }

    external_shutdown.store(true, Ordering::Relaxed);
    drop(tx);
    // Nobody reads the queue any more. Dropping the receiver makes every helper's next `send`
    // fail, so the helpers wind down instead of filling a dead queue.
    drop(rx);
    debug!(target: "bubbletea::thread", "Sender dropped, waiting for threads to exit");

    if let Some(handle) = external_forwarder_handle {
        match handle.join() {
            Ok(()) => {
                debug!(target: "bubbletea::thread", "External forwarder thread joined successfully")
            }
            Err(e) => {
                tracing::warn!(target: "bubbletea::thread", "External forwarder thread panicked: {:?}", e)
            }
        }
    }

    if let Some(handle) = input_parser_handle {
        // The parser may be parked in a blocking `read` that never returns (stdin, an open
        // pipe), so an unconditional join could hang shutdown forever. Wait briefly, then
        // detach. The thread ends on its own at EOF, on a read error, or at its next `send`.
        const INPUT_THREAD_GRACE: Duration = Duration::from_millis(100);
        let grace_deadline = std::time::Instant::now() + INPUT_THREAD_GRACE;
        while !handle.is_finished() && std::time::Instant::now() < grace_deadline {
            thread::sleep(Duration::from_millis(5));
        }
        if handle.is_finished() {
            match handle.join() {
                Ok(()) => {
                    debug!(target: "bubbletea::thread", "Input parser thread joined successfully")
                }
                Err(e) => {
                    tracing::warn!(target: "bubbletea::thread", "Input parser thread panicked: {:?}", e)
                }
            }
        } else {
            debug!(target: "bubbletea::thread", "Input parser thread still blocked on read, detaching");
        }
    }

    // Command threads share one overall deadline; stragglers are detached.
    const COMMAND_THREAD_TIMEOUT: Duration = Duration::from_secs(5);
    let join_deadline = std::time::Instant::now() + COMMAND_THREAD_TIMEOUT;
    if let Ok(mut threads) = command_threads.lock() {
        let thread_count = threads.len();
        if thread_count > 0 {
            debug!(target: "bubbletea::thread", "Waiting for {} command thread(s) to complete", thread_count);
        }
        for handle in threads.drain(..) {
            if handle.is_finished() {
                let _ = handle.join();
            } else {
                let remaining =
                    join_deadline.saturating_duration_since(std::time::Instant::now());
                if remaining.is_zero() {
                    debug!(target: "bubbletea::thread", "Timeout waiting for command threads, abandoning remaining");
                    break;
                }
                let poll_interval = Duration::from_millis(10);
                let start = std::time::Instant::now();
                while !handle.is_finished() && start.elapsed() < remaining {
                    thread::sleep(poll_interval);
                }
                if handle.is_finished() {
                    match handle.join() {
                        Ok(()) => {
                            debug!(target: "bubbletea::thread", "Command thread joined successfully")
                        }
                        Err(e) => {
                            tracing::warn!(target: "bubbletea::thread", "Command thread panicked: {:?}", e)
                        }
                    }
                } else {
                    debug!(target: "bubbletea::thread", "Command thread did not finish in time, abandoning");
                }
            }
        }
    } else {
        tracing::warn!(target: "bubbletea::thread", "Failed to join command threads: mutex poisoned");
    }
    Ok(self.model)
}
/// Executes `cmd` on a new thread and routes the message it produces.
///
/// Routing rules, applied recursively to the results of nested commands:
/// * a [`BatchMsg`] starts each of its commands concurrently via `spawn_batch`;
/// * a [`SequenceMsg`] runs its commands one after another and stops once the event loop's
///   receiver is gone. A batch returned directly by a sequence step has its commands run to
///   completion (on scoped threads) before the next step starts, as in bubbletea;
/// * any other message is sent to the event loop through `tx`.
///
/// Nested results must be routed here because the event loop discards `BatchMsg`. Without
/// this, a batch returned from inside a batch or sequence would never run.
///
/// The new thread's handle is pushed onto `command_threads`, after pruning finished ones, so
/// shutdown can wait for it.
fn handle_command(
    &self,
    cmd: Cmd,
    tx: Sender<Message>,
    command_threads: Arc<Mutex<Vec<thread::JoinHandle<()>>>>,
) {
    /// Routes one command result; returns `false` once the event loop's receiver is gone.
    fn route(msg: Message, tx: &Sender<Message>) -> bool {
        if msg.is::<BatchMsg>() {
            if let Some(batch) = msg.downcast::<BatchMsg>() {
                for cmd in batch.0 {
                    let tx = tx.clone();
                    spawn_batch(move || {
                        if let Some(msg) = cmd.execute()
                            && !route(msg, &tx)
                        {
                            debug!(target: "bubbletea::command", "batch command result dropped — receiver disconnected");
                        }
                    });
                }
            }
            true
        } else if msg.is::<SequenceMsg>() {
            if let Some(seq) = msg.downcast::<SequenceMsg>() {
                for cmd in seq.0 {
                    let Some(msg) = cmd.execute() else {
                        continue;
                    };
                    if !route_sequence_step(msg, tx) {
                        debug!(target: "bubbletea::command", "sequence command result dropped — receiver disconnected");
                        return false;
                    }
                }
            }
            true
        } else {
            tx.send(msg).is_ok()
        }
    }

    /// Like `route`, but waits for a batch's commands so the enclosing sequence stays ordered.
    fn route_sequence_step(msg: Message, tx: &Sender<Message>) -> bool {
        if !msg.is::<BatchMsg>() {
            return route(msg, tx);
        }
        if let Some(batch) = msg.downcast::<BatchMsg>() {
            thread::scope(|scope| {
                for cmd in batch.0 {
                    let tx = tx.clone();
                    scope.spawn(move || {
                        if let Some(msg) = cmd.execute()
                            && !route(msg, &tx)
                        {
                            debug!(target: "bubbletea::command", "batch command result dropped — receiver disconnected");
                        }
                    });
                }
            });
        }
        true
    }

    let handle = thread::spawn(move || {
        if let Some(msg) = cmd.execute()
            && !route(msg, &tx)
        {
            debug!(target: "bubbletea::command", "command result dropped — receiver disconnected");
        }
    });
    if let Ok(mut threads) = command_threads.lock() {
        threads.retain(|h| !h.is_finished());
        threads.push(handle);
    } else {
        debug!(target: "bubbletea::thread", "Failed to track command thread: mutex poisoned");
    }
}
/// Redraws the screen with the model's current view.
///
/// Nothing is written when the view equals `last_view`. Otherwise the cursor is moved home,
/// the screen is cleared, the view is written and flushed, and `last_view` is updated.
///
/// # Errors
///
/// Returns an error if writing to or flushing `writer` fails.
fn render<W: Write>(&self, writer: &mut W, last_view: &mut String) -> Result<()> {
    let frame = self.model.view();
    if frame != *last_view {
        execute!(writer, MoveTo(0, 0), Clear(ClearType::All))?;
        write!(writer, "{}", frame)?;
        writer.flush()?;
        *last_view = frame;
    }
    Ok(())
}
}
#[cfg(feature = "async")]
impl<M: Model> Program<M> {
/// Async counterpart of `run`.
///
/// It renders to the output set with `with_output`, or to stdout.
///
/// # Errors
///
/// Propagates any error from `run_async_with_writer`.
pub async fn run_async(mut self) -> Result<M> {
    match self.output.take() {
        Some(custom) => self.run_async_with_writer(custom).await,
        None => self.run_async_with_writer(io::stdout()).await,
    }
}
pub async fn run_async_with_writer<W: Write + Send + 'static>(
self,
mut writer: W,
) -> Result<M> {
let options = self.options.clone();
if !options.custom_io {
enable_raw_mode()?;
}
if options.alt_screen {
execute!(writer, EnterAlternateScreen)?;
}
execute!(writer, Hide)?;
if options.mouse_all_motion {
execute!(writer, EnableMouseCapture)?;
} else if options.mouse_cell_motion {
execute!(writer, EnableMouseCapture)?;
}
if options.report_focus {
execute!(writer, event::EnableFocusChange)?;
}
if options.bracketed_paste {
execute!(writer, event::EnableBracketedPaste)?;
}
let prev_hook = if !options.without_catch_panics {
let cleanup_opts = options.clone();
let prev = std::panic::take_hook();
std::panic::set_hook(Box::new(move |info| {
let mut stderr = io::stderr();
if cleanup_opts.bracketed_paste {
let _ = execute!(stderr, event::DisableBracketedPaste);
}
if cleanup_opts.report_focus {
let _ = execute!(stderr, event::DisableFocusChange);
}
if cleanup_opts.mouse_all_motion || cleanup_opts.mouse_cell_motion {
let _ = execute!(stderr, DisableMouseCapture);
}
let _ = execute!(stderr, Show);
if cleanup_opts.alt_screen {
let _ = execute!(stderr, LeaveAlternateScreen);
}
if !cleanup_opts.custom_io {
let _ = disable_raw_mode();
}
prev(info);
}));
true
} else {
false
};
let result = self.event_loop_async(&mut writer).await;
if prev_hook {
let _ = std::panic::take_hook();
}
if options.bracketed_paste {
let _ = execute!(writer, event::DisableBracketedPaste);
}
if options.report_focus {
let _ = execute!(writer, event::DisableFocusChange);
}
if options.mouse_all_motion || options.mouse_cell_motion {
let _ = execute!(writer, DisableMouseCapture);
}
let _ = execute!(writer, Show);
if options.alt_screen {
let _ = execute!(writer, LeaveAlternateScreen);
}
if !options.custom_io {
let _ = disable_raw_mode();
}
result
}
async fn event_loop_async<W: Write>(mut self, stdout: &mut W) -> Result<M> {
let (tx, mut rx) = tokio::sync::mpsc::channel::<Message>(256);
let cancel_token = CancellationToken::new();
let task_tracker = TaskTracker::new();
if let Some(ext_rx) = self.external_rx.take() {
let tx_clone = tx.clone();
let cancel_clone = cancel_token.clone();
task_tracker.spawn_blocking(move || {
let timeout = Duration::from_millis(100);
loop {
if cancel_clone.is_cancelled() {
break;
}
match ext_rx.recv_timeout(timeout) {
Ok(msg) => {
if tx_clone.blocking_send(msg).is_err() {
debug!(target: "bubbletea::event", "async external message dropped — receiver disconnected");
break;
}
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
break;
}
}
}
});
}
if let Some(mut input) = self.input.take() {
let tx_clone = tx.clone();
let cancel_clone = cancel_token.clone();
task_tracker.spawn_blocking(move || {
let mut parser = InputParser::new();
let mut buf = [0u8; 256];
loop {
if cancel_clone.is_cancelled() {
break;
}
match input.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let can_have_more_data = true;
for msg in parser.push_bytes(&buf[..n], can_have_more_data) {
if tx_clone.blocking_send(msg).is_err() {
return;
}
}
}
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
std::thread::yield_now();
}
Err(_) => break,
}
}
for msg in parser.flush() {
if tx_clone.blocking_send(msg).is_err() {
break;
}
}
});
}
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel::<Event>(100);
let event_cancel = cancel_token.clone();
if !self.options.custom_io {
task_tracker.spawn_blocking(move || {
loop {
if event_cancel.is_cancelled() {
break;
}
match event::poll(Duration::from_millis(100)) {
Ok(true) => {
if let Ok(evt) = event::read()
&& event_tx.blocking_send(evt).is_err()
{
break;
}
}
Ok(false) => {} Err(_) => {
break;
} }
}
});
}
if !self.options.custom_io {
let (width, height) = terminal::size()?;
if tx
.send(Message::new(WindowSizeMsg { width, height }))
.await
.is_err()
{
debug!(target: "bubbletea::event", "async initial window size dropped — receiver disconnected");
}
}
if let Some(cmd) = self.model.init() {
Self::handle_command_tracked(
cmd.into(),
tx.clone(),
&task_tracker,
cancel_token.clone(),
);
}
let mut last_view = String::new();
self.render(stdout, &mut last_view)?;
let frame_duration = Duration::from_secs_f64(1.0 / self.options.fps as f64);
let mut frame_interval = tokio::time::interval(frame_duration);
loop {
tokio::select! {
Some(event) = event_rx.recv(), if !self.options.custom_io => {
match event {
Event::Key(key_event) => {
if key_event.kind != KeyEventKind::Press {
continue;
}
let key_msg = from_crossterm_key(key_event.code, key_event.modifiers);
if key_msg.key_type == crate::KeyType::CtrlC {
if tx.send(Message::new(InterruptMsg)).await.is_err() {
debug!(target: "bubbletea::event", "async interrupt message dropped — receiver disconnected");
}
} else if tx.send(Message::new(key_msg)).await.is_err() {
debug!(target: "bubbletea::event", "async key message dropped — receiver disconnected");
}
}
Event::Mouse(mouse_event) => {
let mouse_msg = from_crossterm_mouse(mouse_event);
if tx.send(Message::new(mouse_msg)).await.is_err() {
debug!(target: "bubbletea::event", "async mouse message dropped — receiver disconnected");
}
}
Event::Resize(width, height) => {
if tx.send(Message::new(WindowSizeMsg { width, height })).await.is_err() {
debug!(target: "bubbletea::event", "async resize message dropped — receiver disconnected");
}
}
Event::FocusGained => {
if tx.send(Message::new(FocusMsg)).await.is_err() {
debug!(target: "bubbletea::event", "async focus message dropped — receiver disconnected");
}
}
Event::FocusLost => {
if tx.send(Message::new(BlurMsg)).await.is_err() {
debug!(target: "bubbletea::event", "async blur message dropped — receiver disconnected");
}
}
Event::Paste(text) => {
let key_msg = KeyMsg {
key_type: crate::KeyType::Runes,
runes: text.chars().collect(),
alt: false,
paste: true,
};
if tx.send(Message::new(key_msg)).await.is_err() {
debug!(target: "bubbletea::event", "async paste message dropped — receiver disconnected");
}
}
}
}
Some(msg) = rx.recv() => {
if msg.is::<QuitMsg>() {
Self::graceful_shutdown(&cancel_token, &task_tracker).await;
return Ok(self.model);
}
if msg.is::<InterruptMsg>() {
Self::graceful_shutdown(&cancel_token, &task_tracker).await;
return Ok(self.model);
}
if msg.is::<BatchMsg>() {
continue;
}
if let Some(title_msg) = msg.downcast_ref::<SetWindowTitleMsg>() {
execute!(stdout, terminal::SetTitle(&title_msg.0))?;
continue;
}
if msg.is::<RequestWindowSizeMsg>() {
if !self.options.custom_io {
let (width, height) = terminal::size()?;
if tx.send(Message::new(WindowSizeMsg { width, height })).await.is_err() {
debug!(target: "bubbletea::event", "async window size response dropped — receiver disconnected");
}
}
continue;
}
if let Some(print_msg) = msg.downcast_ref::<PrintLineMsg>() {
if !self.options.alt_screen {
for line in print_msg.0.lines() {
let _ = writeln!(stdout, "{}", line);
}
let _ = stdout.flush();
last_view.clear();
}
self.render(stdout, &mut last_view)?;
continue;
}
if msg.is::<ReleaseTerminalMsg>() {
if !self.options.custom_io {
if self.options.bracketed_paste {
let _ = execute!(stdout, event::DisableBracketedPaste);
}
if self.options.report_focus {
let _ = execute!(stdout, event::DisableFocusChange);
}
if self.options.mouse_all_motion || self.options.mouse_cell_motion {
let _ = execute!(stdout, DisableMouseCapture);
}
let _ = execute!(stdout, Show);
if self.options.alt_screen {
let _ = execute!(stdout, LeaveAlternateScreen);
}
let _ = disable_raw_mode();
}
continue;
}
if msg.is::<RestoreTerminalMsg>() {
if !self.options.custom_io {
let _ = enable_raw_mode();
if self.options.alt_screen {
let _ = execute!(stdout, EnterAlternateScreen);
}
let _ = execute!(stdout, Hide);
if self.options.mouse_all_motion {
let _ = execute!(stdout, EnableMouseCapture);
} else if self.options.mouse_cell_motion {
let _ = execute!(stdout, EnableMouseCapture);
}
if self.options.report_focus {
let _ = execute!(stdout, event::EnableFocusChange);
}
if self.options.bracketed_paste {
let _ = execute!(stdout, event::EnableBracketedPaste);
}
last_view.clear();
}
self.render(stdout, &mut last_view)?;
continue;
}
if let Some(cmd) = self.model.update(msg) {
Self::handle_command_tracked(
cmd.into(),
tx.clone(),
&task_tracker,
cancel_token.clone(),
);
}
self.render(stdout, &mut last_view)?;
}
_ = frame_interval.tick() => {
}
}
}
}
/// Signals every tracked task to stop, then waits up to five seconds for them to finish.
///
/// Tasks still running after the deadline are left behind; the timeout is not reported.
async fn graceful_shutdown(cancel_token: &CancellationToken, task_tracker: &TaskTracker) {
    const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
    cancel_token.cancel();
    // `wait` only completes once the tracker is closed and empty.
    task_tracker.close();
    let drained = task_tracker.wait();
    let _ = tokio::time::timeout(SHUTDOWN_TIMEOUT, drained).await;
}
/// Runs `cmd` as a task on `tracker` and routes its result, honouring `cancel_token`.
///
/// * A `BatchMsg` result spawns one tracked task per batched command. Each task races the
///   command against cancellation and sends its result to `tx`.
/// * A `SequenceMsg` result runs its commands one after another inside this task. It stops
///   early if a send fails or the token is cancelled.
/// * Any other result is sent to `tx`.
///
/// NOTE(review): a `BatchMsg` or `SequenceMsg` produced by a batch member or a sequence step is
/// sent to `tx` as-is. The async event loop skips `BatchMsg`, so nested batches appear to be
/// dropped. Confirm whether that is intended.
fn handle_command_tracked(
cmd: CommandKind,
tx: tokio::sync::mpsc::Sender<Message>,
tracker: &TaskTracker,
cancel_token: CancellationToken,
) {
// Clones moved into the outer task, used to spawn and cancel batch members.
let batch_tracker = tracker.clone();
let batch_cancel = cancel_token.clone();
tracker.spawn(async move {
// Cancellation abandons the command (and any sequence still in progress).
tokio::select! {
result = cmd.execute() => {
if let Some(msg) = result {
if msg.is::<BatchMsg>() {
if let Some(batch) = msg.downcast::<BatchMsg>() {
for cmd in batch.0 {
let tx_clone = tx.clone();
let cancel = batch_cancel.clone();
// Each batch member is its own tracked, cancellable task.
batch_tracker.spawn(async move {
tokio::select! {
result = async {
let cmd_kind: CommandKind = cmd.into();
cmd_kind.execute().await
} => {
if let Some(msg) = result {
if tx_clone.send(msg).await.is_err() {
debug!(target: "bubbletea::command", "async batch command result dropped — receiver disconnected");
}
}
}
_ = cancel.cancelled() => {
debug!(target: "bubbletea::command", "async batch command cancelled during shutdown");
}
}
});
}
}
} else if msg.is::<SequenceMsg>() {
if let Some(seq) = msg.downcast::<SequenceMsg>() {
// Run strictly in order; stop once the receiver is gone.
for cmd in seq.0 {
let cmd_kind: CommandKind = cmd.into();
if let Some(msg) = cmd_kind.execute().await {
if tx.send(msg).await.is_err() {
debug!(target: "bubbletea::command", "async sequence command result dropped — receiver disconnected");
break;
}
}
}
}
} else if tx.send(msg).await.is_err() {
debug!(target: "bubbletea::command", "async command result dropped — receiver disconnected");
}
}
}
_ = cancel_token.cancelled() => {
}
}
});
}
#[allow(dead_code)]
fn handle_command_async(&self, cmd: CommandKind, tx: tokio::sync::mpsc::Sender<Message>) {
tokio::spawn(async move {
if let Some(msg) = cmd.execute().await {
if msg.is::<BatchMsg>() {
if let Some(batch) = msg.downcast::<BatchMsg>() {
for cmd in batch.0 {
let tx_clone = tx.clone();
tokio::spawn(async move {
let cmd_kind: CommandKind = cmd.into();
if let Some(msg) = cmd_kind.execute().await {
if tx_clone.send(msg).await.is_err() {
debug!(target: "bubbletea::command", "legacy async batch command result dropped — receiver disconnected");
}
}
});
}
}
} else if msg.is::<SequenceMsg>() {
if let Some(seq) = msg.downcast::<SequenceMsg>() {
for cmd in seq.0 {
let cmd_kind: CommandKind = cmd.into();
if let Some(msg) = cmd_kind.execute().await {
if tx.send(msg).await.is_err() {
debug!(target: "bubbletea::command", "legacy async sequence command result dropped — receiver disconnected");
break;
}
}
}
}
} else if tx.send(msg).await.is_err() {
debug!(target: "bubbletea::command", "legacy async command result dropped — receiver disconnected");
}
}
});
}
}
/// Incremental decoder that turns raw input bytes into messages.
///
/// Bytes that end mid-sequence stay buffered until more input arrives or `flush` is called.
struct InputParser {
    /// Bytes received but not yet consumed by a complete parse.
    buffer: Vec<u8>,
}
impl InputParser {
    /// Creates a parser with an empty buffer.
    fn new() -> Self {
        Self { buffer: Vec::new() }
    }

    /// Upper bound on buffered bytes; exceeding it discards the backlog.
    const MAX_BUFFER: usize = 1024 * 1024;

    /// Appends `bytes` and returns every message that can be parsed so far.
    ///
    /// With `can_have_more_data`, incomplete sequences are kept for later. Without it, they are
    /// resolved now by the individual parsers.
    fn push_bytes(&mut self, bytes: &[u8], can_have_more_data: bool) -> Vec<Message> {
        if !bytes.is_empty() {
            if self.buffer.len() + bytes.len() > Self::MAX_BUFFER {
                debug!(target: "bubbletea::input", "Input buffer exceeded 1MB limit, draining");
                self.buffer.clear();
            }
            self.buffer.extend_from_slice(bytes);
        }
        let mut parsed_messages = Vec::new();
        while !self.buffer.is_empty() {
            let ParseOutcome::Parsed(consumed, parsed) =
                parse_one_message(&self.buffer, can_have_more_data)
            else {
                break;
            };
            self.buffer.drain(..consumed);
            parsed_messages.extend(parsed);
        }
        parsed_messages
    }

    /// Parses whatever is buffered, treating the input as finished.
    fn flush(&mut self) -> Vec<Message> {
        self.push_bytes(&[], false)
    }
}
/// Result of trying to parse one message from the front of the input buffer.
enum ParseOutcome {
/// The buffer holds an incomplete sequence; wait for more bytes.
NeedMore,
/// `usize` bytes were consumed, producing an optional message.
Parsed(usize, Option<Message>),
}
/// Parses the first message at the front of `buf`.
///
/// Specialised escape-sequence parsers are tried in priority order: mouse, focus, bracketed
/// paste, then known key sequences. Plain runes and control bytes are the fallback. An empty
/// buffer needs more data.
fn parse_one_message(buf: &[u8], can_have_more_data: bool) -> ParseOutcome {
    if buf.is_empty() {
        return ParseOutcome::NeedMore;
    }
    parse_mouse_event(buf, can_have_more_data)
        .or_else(|| parse_focus_event(buf, can_have_more_data))
        .or_else(|| parse_bracketed_paste(buf, can_have_more_data))
        .or_else(|| parse_key_sequence(buf, can_have_more_data))
        .unwrap_or_else(|| parse_runes_or_control(buf, can_have_more_data))
}
/// Parses an X10 (`ESC [ M b x y`) or SGR (`ESC [ < params M|m`) mouse report at the front
/// of `buf`.
///
/// Returns `None` when `buf` does not start like a mouse report. An incomplete report waits
/// for more data when more can arrive; otherwise its ESC is consumed as a replacement
/// character.
///
/// An SGR report may contain only digits and `;` before its final `M`/`m`. Any other byte
/// means the input is not a mouse report, so its ESC is consumed as a replacement character
/// immediately. Previously the whole buffer was searched for an `M`/`m`. A stray `ESC [ <`
/// could then stall all further input until the user happened to type an `m`, or swallow
/// unrelated keystrokes into the bogus sequence.
fn parse_mouse_event(buf: &[u8], can_have_more_data: bool) -> Option<ParseOutcome> {
    let incomplete = || {
        if can_have_more_data {
            ParseOutcome::NeedMore
        } else {
            ParseOutcome::Parsed(1, Some(replacement_message()))
        }
    };

    if buf.starts_with(b"\x1b[M") {
        // X10 reports are always exactly six bytes.
        if buf.len() < 6 {
            return Some(incomplete());
        }
        let seq = &buf[..6];
        return Some(match crate::mouse::parse_mouse_event_sequence(seq) {
            Ok(msg) => ParseOutcome::Parsed(6, Some(Message::new(msg))),
            Err(_) => ParseOutcome::Parsed(1, Some(replacement_message())),
        });
    }

    if let Some(params) = buf.strip_prefix(b"\x1b[<") {
        let first_non_param = params
            .iter()
            .position(|b| !(b.is_ascii_digit() || *b == b';'));
        return Some(match first_non_param {
            Some(offset) if matches!(params[offset], b'M' | b'm') => {
                // Prefix (3 bytes) + parameters + final byte.
                let seq = &buf[..3 + offset + 1];
                match crate::mouse::parse_mouse_event_sequence(seq) {
                    Ok(msg) => ParseOutcome::Parsed(seq.len(), Some(Message::new(msg))),
                    Err(_) => ParseOutcome::Parsed(1, Some(replacement_message())),
                }
            }
            // A byte that cannot appear in an SGR report: this is not a mouse event.
            Some(_) => ParseOutcome::Parsed(1, Some(replacement_message())),
            // Only parameters so far: the final byte may still be on its way.
            None => incomplete(),
        });
    }

    None
}
/// Parses a focus-in (`ESC [ I`) or focus-out (`ESC [ O`) report at the front of `buf`.
///
/// A bare `ESC [` asks for more data when more can arrive. Anything else is not a focus event
/// and yields `None`.
fn parse_focus_event(buf: &[u8], can_have_more_data: bool) -> Option<ParseOutcome> {
    let rest = buf.strip_prefix(b"\x1b[")?;
    match rest.first() {
        None if can_have_more_data => Some(ParseOutcome::NeedMore),
        Some(b'I') => Some(ParseOutcome::Parsed(3, Some(Message::new(FocusMsg)))),
        Some(b'O') => Some(ParseOutcome::Parsed(3, Some(Message::new(BlurMsg)))),
        _ => None,
    }
}
/// Parses a bracketed paste (`ESC [200~ … ESC [201~`) at the front of `buf`.
///
/// The pasted bytes are decoded lossily as UTF-8 into a paste `KeyMsg`. Without an end marker,
/// the function waits for more data if more can arrive. Otherwise everything after the start
/// marker is taken as the paste. Returns `None` when `buf` does not start with the start
/// marker.
fn parse_bracketed_paste(buf: &[u8], can_have_more_data: bool) -> Option<ParseOutcome> {
    const BP_START: &[u8] = b"\x1b[200~";
    const BP_END: &[u8] = b"\x1b[201~";

    let body = buf.strip_prefix(BP_START)?;
    let paste = |content: &[u8]| {
        let runes: Vec<char> = String::from_utf8_lossy(content).chars().collect();
        message_from_key(KeyMsg::from_runes(runes).with_paste())
    };

    match body.windows(BP_END.len()).position(|w| w == BP_END) {
        Some(end) => {
            let consumed = BP_START.len() + end + BP_END.len();
            Some(ParseOutcome::Parsed(consumed, Some(paste(&body[..end]))))
        }
        None if can_have_more_data => Some(ParseOutcome::NeedMore),
        None => Some(ParseOutcome::Parsed(buf.len(), Some(paste(body)))),
    }
}
/// Parses a terminal key escape sequence at the front of `buf`.
///
/// Tries, in order:
/// 1. a complete sequence recognised by `crate::key::parse_sequence_prefix`;
/// 2. `NeedMore`, when more input may arrive and `buf` is a prefix of a known
///    sequence (per `is_sequence_prefix`);
/// 3. ESC followed by a recognised sequence, reported as that key with Alt set
///    and consuming the extra ESC byte.
///
/// Returns `None` when none of these apply.
fn parse_key_sequence(buf: &[u8], can_have_more_data: bool) -> Option<ParseOutcome> {
    if let Some((key, consumed)) = crate::key::parse_sequence_prefix(buf) {
        return Some(ParseOutcome::Parsed(consumed, Some(message_from_key(key))));
    }
    if can_have_more_data && is_sequence_prefix(buf) {
        return Some(ParseOutcome::NeedMore);
    }
    // ESC + <sequence> is the same key with the Alt modifier.
    let rest = buf.strip_prefix(b"\x1b")?;
    let (key, consumed) = crate::key::parse_sequence_prefix(rest)?;
    let key = if key.alt { key } else { key.with_alt() };
    Some(ParseOutcome::Parsed(consumed + 1, Some(message_from_key(key))))
}
fn parse_runes_or_control(buf: &[u8], can_have_more_data: bool) -> ParseOutcome {
let mut alt = false;
let mut idx = 0;
if buf[0] == 0x1b {
if buf.len() == 1 {
return if can_have_more_data {
ParseOutcome::NeedMore
} else {
ParseOutcome::Parsed(1, Some(message_from_key(KeyMsg::from_type(KeyType::Esc))))
};
}
alt = true;
idx = 1;
}
if idx >= buf.len() {
return ParseOutcome::NeedMore;
}
if let Some(key_type) = control_key_type(buf[idx]) {
let mut key = KeyMsg::from_type(key_type);
if alt {
key = key.with_alt();
}
return ParseOutcome::Parsed(idx + 1, Some(message_from_key(key)));
}
let mut runes = Vec::new();
let mut i = idx;
while i < buf.len() {
let b = buf[i];
if is_control_or_space(b) {
break;
}
let (ch, width, valid) = match decode_char(&buf[i..], can_have_more_data) {
DecodeOutcome::NeedMore => return ParseOutcome::NeedMore,
DecodeOutcome::Decoded(ch, width, valid) => (ch, width, valid),
};
if !valid {
runes.push(std::char::REPLACEMENT_CHARACTER);
i += 1;
} else {
runes.push(ch);
i += width;
}
if alt {
break;
}
}
if !runes.is_empty() {
let mut key = KeyMsg::from_runes(runes);
if alt {
key = key.with_alt();
}
return ParseOutcome::Parsed(i, Some(message_from_key(key)));
}
ParseOutcome::Parsed(1, Some(replacement_message()))
}
/// Maps a C0 control byte (0x00–0x1F), space (0x20) or DEL (0x7F) to its named key.
///
/// Returns `None` for every other byte, i.e. printable ASCII and all bytes >= 0x80.
fn control_key_type(byte: u8) -> Option<KeyType> {
    let key_type = match byte {
        0x00 => KeyType::Null,
        0x01 => KeyType::CtrlA,
        0x02 => KeyType::CtrlB,
        0x03 => KeyType::CtrlC,
        0x04 => KeyType::CtrlD,
        0x05 => KeyType::CtrlE,
        0x06 => KeyType::CtrlF,
        0x07 => KeyType::CtrlG,
        0x08 => KeyType::CtrlH,
        b'\t' => KeyType::Tab,
        b'\n' => KeyType::CtrlJ,
        0x0B => KeyType::CtrlK,
        0x0C => KeyType::CtrlL,
        b'\r' => KeyType::Enter,
        0x0E => KeyType::CtrlN,
        0x0F => KeyType::CtrlO,
        0x10 => KeyType::CtrlP,
        0x11 => KeyType::CtrlQ,
        0x12 => KeyType::CtrlR,
        0x13 => KeyType::CtrlS,
        0x14 => KeyType::CtrlT,
        0x15 => KeyType::CtrlU,
        0x16 => KeyType::CtrlV,
        0x17 => KeyType::CtrlW,
        0x18 => KeyType::CtrlX,
        0x19 => KeyType::CtrlY,
        0x1A => KeyType::CtrlZ,
        b'\x1b' => KeyType::Esc,
        0x1C => KeyType::CtrlBackslash,
        0x1D => KeyType::CtrlCloseBracket,
        0x1E => KeyType::CtrlCaret,
        0x1F => KeyType::CtrlUnderscore,
        b' ' => KeyType::Space,
        0x7F => KeyType::Backspace,
        _ => return None,
    };
    Some(key_type)
}
/// Reports whether `byte` is a C0 control (0x00–0x1F), a space (0x20) or DEL (0x7F).
///
/// These are the bytes that end a run of printable runes in the input parser.
fn is_control_or_space(byte: u8) -> bool {
    matches!(byte, 0x00..=0x20 | 0x7F)
}
/// Result of decoding one character from the front of a byte buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecodeOutcome {
    /// The bytes seen so far are a valid but truncated UTF-8 prefix (or there
    /// are no bytes at all), so more input is needed.
    NeedMore,
    /// `(character, bytes consumed, was_valid)`. Invalid input yields U+FFFD
    /// with a width of 1, so the caller resynchronises on the next byte.
    Decoded(char, usize, bool),
}

/// Decodes the first UTF-8 character at the start of `input`.
///
/// Returns `NeedMore` only when the available bytes are a valid prefix of a
/// longer character and `can_have_more_data` is true. Bytes that can never
/// start a valid sequence are reported as invalid immediately, even if more
/// data may follow. Examples are a lead byte followed by a non-continuation
/// byte, `0xC0`/`0xC1`, and `0xF5..=0xFF`. A truncated sequence with no more
/// data coming is also reported as invalid.
fn decode_char(input: &[u8], can_have_more_data: bool) -> DecodeOutcome {
    const INVALID: DecodeOutcome =
        DecodeOutcome::Decoded(std::char::REPLACEMENT_CHARACTER, 1, false);

    // A UTF-8 encoded scalar value is at most four bytes long.
    let head = &input[..input.len().min(4)];
    let valid = match std::str::from_utf8(head) {
        Ok(s) => s,
        // The first character is complete; whatever follows it is decoded later.
        Err(err) if err.valid_up_to() > 0 => std::str::from_utf8(&head[..err.valid_up_to()])
            .expect("bytes before valid_up_to() are valid UTF-8"),
        Err(err) => {
            return match err.error_len() {
                // `None` means "input ended mid-sequence": wait if more may arrive.
                None if can_have_more_data => DecodeOutcome::NeedMore,
                _ => INVALID,
            };
        }
    };
    match valid.chars().next() {
        Some(ch) => DecodeOutcome::Decoded(ch, ch.len_utf8(), true),
        // Empty input: nothing to decode yet.
        None => DecodeOutcome::NeedMore,
    }
}
/// Wraps a decoded key in a [`Message`].
///
/// Ctrl+C is special-cased: it is delivered as an [`InterruptMsg`] instead of a
/// key message. Only the key type is inspected, so other fields on the key
/// (such as `alt`) do not affect this.
fn message_from_key(key: KeyMsg) -> Message {
    let is_interrupt = key.key_type == KeyType::CtrlC;
    if !is_interrupt {
        return Message::new(key);
    }
    Message::new(InterruptMsg)
}
/// Builds a key message carrying U+FFFD, emitted by the parsers when input
/// bytes cannot be decoded.
fn replacement_message() -> Message {
    let ch = std::char::REPLACEMENT_CHARACTER;
    Message::new(KeyMsg::from_char(ch))
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the library imports tokio_util only under the `async` feature,
    // but these tests use tokio/tokio_util unconditionally. Confirm both are
    // available to test builds (e.g. as dev-dependencies) or gate the tokio tests.
    use tokio_util::sync::CancellationToken;
    use tokio_util::task::TaskTracker;

    /// Minimal model that adds every `i32` message it receives to `count`.
    struct TestModel {
        count: i32,
    }

    impl Model for TestModel {
        fn init(&self) -> Option<Cmd> {
            None
        }

        fn update(&mut self, msg: Message) -> Option<Cmd> {
            if let Some(n) = msg.downcast::<i32>() {
                self.count += n;
            }
            None
        }

        fn view(&self) -> String {
            format!("Count: {}", self.count)
        }
    }

    /// Repeatedly evaluates `done`, sleeping briefly between checks, until it
    /// returns `true` or `timeout` elapses. Returns whether it ever returned `true`.
    ///
    /// Used instead of fixed sleeps so tests finish as soon as background work is
    /// observed while still tolerating slow machines.
    fn wait_until(timeout: Duration, mut done: impl FnMut() -> bool) -> bool {
        let deadline = std::time::Instant::now() + timeout;
        loop {
            if done() {
                return true;
            }
            if std::time::Instant::now() >= deadline {
                return false;
            }
            thread::sleep(Duration::from_millis(5));
        }
    }

    #[test]
    fn test_program_options_default() {
        let opts = ProgramOptions::default();
        assert!(!opts.alt_screen);
        assert!(!opts.mouse_cell_motion);
        assert!(opts.bracketed_paste);
        assert_eq!(opts.fps, 60);
    }

    #[test]
    fn test_program_builder() {
        let model = TestModel { count: 0 };
        let program = Program::new(model)
            .with_alt_screen()
            .with_mouse_cell_motion()
            .with_fps(30);
        assert!(program.options.alt_screen);
        assert!(program.options.mouse_cell_motion);
        assert_eq!(program.options.fps, 30);
    }

    #[test]
    fn test_program_fps_max() {
        let model = TestModel { count: 0 };
        let program = Program::new(model).with_fps(200);
        // Requests above the maximum are expected to clamp to 120.
        assert_eq!(program.options.fps, 120);
    }

    #[test]
    fn test_program_fps_min() {
        let model = TestModel { count: 0 };
        let program = Program::new(model).with_fps(0);
        // Zero is expected to clamp up to the minimum of 1.
        assert_eq!(program.options.fps, 1);
    }

    #[test]
    fn test_parse_bracketed_paste_basic() {
        let input = b"\x1b[200~hello world\x1b[201~";
        let result = parse_bracketed_paste(input, false);
        assert!(result.is_some());
        if let Some(ParseOutcome::Parsed(len, Some(msg))) = result {
            assert_eq!(len, input.len());
            let key = msg.downcast_ref::<KeyMsg>().unwrap();
            assert!(key.paste, "Key should have paste flag set");
            assert_eq!(
                key.runes,
                vec!['h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd']
            );
        } else {
            panic!("Expected Parsed outcome");
        }
    }

    #[test]
    fn test_parse_bracketed_paste_empty() {
        let input = b"\x1b[200~\x1b[201~";
        let result = parse_bracketed_paste(input, false);
        assert!(result.is_some());
        if let Some(ParseOutcome::Parsed(len, Some(msg))) = result {
            assert_eq!(len, input.len());
            let key = msg.downcast_ref::<KeyMsg>().unwrap();
            assert!(key.paste);
            assert!(key.runes.is_empty());
        } else {
            panic!("Expected Parsed outcome");
        }
    }

    #[test]
    fn test_parse_bracketed_paste_multiline() {
        let input = b"\x1b[200~line1\nline2\nline3\x1b[201~";
        let result = parse_bracketed_paste(input, false);
        assert!(result.is_some());
        if let Some(ParseOutcome::Parsed(len, Some(msg))) = result {
            assert_eq!(len, input.len());
            let key = msg.downcast_ref::<KeyMsg>().unwrap();
            assert!(key.paste);
            let text: String = key.runes.iter().collect();
            assert_eq!(text, "line1\nline2\nline3");
        } else {
            panic!("Expected Parsed outcome");
        }
    }

    #[test]
    fn test_parse_bracketed_paste_unicode() {
        let input = "\x1b[200~hello 世界 🌍\x1b[201~".as_bytes();
        let result = parse_bracketed_paste(input, false);
        assert!(result.is_some());
        if let Some(ParseOutcome::Parsed(_, Some(msg))) = result {
            let key = msg.downcast_ref::<KeyMsg>().unwrap();
            assert!(key.paste);
            let text: String = key.runes.iter().collect();
            assert_eq!(text, "hello 世界 🌍");
        } else {
            panic!("Expected Parsed outcome");
        }
    }

    #[test]
    fn test_parse_bracketed_paste_incomplete() {
        let input = b"\x1b[200~hello";
        let result = parse_bracketed_paste(input, true);
        assert!(matches!(result, Some(ParseOutcome::NeedMore)));
    }

    #[test]
    fn test_parse_bracketed_paste_incomplete_no_more_data() {
        let input = b"\x1b[200~hello";
        let result = parse_bracketed_paste(input, false);
        assert!(result.is_some());
        if let Some(ParseOutcome::Parsed(len, Some(msg))) = result {
            assert_eq!(len, input.len());
            let key = msg.downcast_ref::<KeyMsg>().unwrap();
            assert!(key.paste);
            let text: String = key.runes.iter().collect();
            assert_eq!(text, "hello");
        } else {
            panic!("Expected Parsed outcome");
        }
    }

    #[test]
    fn test_parse_bracketed_paste_not_bracketed() {
        let input = b"hello";
        let result = parse_bracketed_paste(input, false);
        assert!(result.is_none(), "Non-paste input should return None");
    }

    #[test]
    fn test_parse_bracketed_paste_large() {
        let content = "a".repeat(10000);
        let input = format!("\x1b[200~{}\x1b[201~", content);
        let result = parse_bracketed_paste(input.as_bytes(), false);
        assert!(result.is_some());
        if let Some(ParseOutcome::Parsed(len, Some(msg))) = result {
            assert_eq!(len, input.len());
            let key = msg.downcast_ref::<KeyMsg>().unwrap();
            assert!(key.paste);
            assert_eq!(key.runes.len(), 10000);
        } else {
            panic!("Expected Parsed outcome");
        }
    }

    #[test]
    fn spawn_batch_executes_closure() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let executed = Arc::new(AtomicBool::new(false));
        let clone = executed.clone();
        spawn_batch(move || {
            clone.store(true, Ordering::SeqCst);
        });
        // Poll instead of a fixed 200 ms sleep: faster when the closure runs
        // promptly, and not flaky on a loaded machine.
        assert!(
            wait_until(Duration::from_secs(5), || executed.load(Ordering::SeqCst)),
            "spawn_batch should execute the closure"
        );
    }

    #[test]
    fn spawn_batch_handles_many_concurrent_tasks() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let counter = Arc::new(AtomicUsize::new(0));
        let task_count = 200;
        for _ in 0..task_count {
            let c = counter.clone();
            spawn_batch(move || {
                c.fetch_add(1, Ordering::SeqCst);
            });
        }
        wait_until(Duration::from_secs(5), || {
            counter.load(Ordering::SeqCst) >= task_count
        });
        assert_eq!(
            counter.load(Ordering::SeqCst),
            task_count,
            "All {task_count} tasks should complete"
        );
    }

    #[test]
    fn handle_command_batch_executes_all_subcommands() {
        let model = TestModel { count: 0 };
        let program = Program::new(model);
        let (tx, rx) = mpsc::channel();
        let command_threads = Arc::new(Mutex::new(Vec::new()));
        let cmds: Vec<Option<Cmd>> = (0..50)
            .map(|i| Some(Cmd::new(move || Message::new(i))))
            .collect();
        let batch_cmd = crate::batch(cmds).unwrap();
        program.handle_command(batch_cmd, tx, Arc::clone(&command_threads));
        let mut results = Vec::new();
        let deadline = std::time::Instant::now() + Duration::from_secs(5);
        while results.len() < 50 && std::time::Instant::now() < deadline {
            if let Ok(msg) = rx.recv_timeout(Duration::from_millis(100)) {
                results.push(msg.downcast::<i32>().unwrap());
            }
        }
        assert_eq!(
            results.len(),
            50,
            "All 50 batch sub-commands should produce results"
        );
        results.sort();
        let expected: Vec<i32> = (0..50).collect();
        assert_eq!(
            results, expected,
            "Each sub-command value should be received exactly once"
        );
    }

    #[cfg(feature = "thread-pool")]
    #[test]
    fn handle_command_batch_bounded_parallelism() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let model = TestModel { count: 0 };
        let program = Program::new(model);
        let (tx, rx) = mpsc::channel();
        let command_threads = Arc::new(Mutex::new(Vec::new()));
        let active = Arc::new(AtomicUsize::new(0));
        let max_active = Arc::new(AtomicUsize::new(0));
        let task_count: usize = 100;
        let cmds: Vec<Option<Cmd>> = (0..task_count)
            .map(|_| {
                let a = active.clone();
                let m = max_active.clone();
                Some(Cmd::new(move || {
                    let current = a.fetch_add(1, Ordering::SeqCst) + 1;
                    m.fetch_max(current, Ordering::SeqCst);
                    thread::sleep(Duration::from_millis(5));
                    a.fetch_sub(1, Ordering::SeqCst);
                    Message::new(1i32)
                }))
            })
            .collect();
        let batch_cmd = crate::batch(cmds).unwrap();
        program.handle_command(batch_cmd, tx, Arc::clone(&command_threads));
        let mut count = 0usize;
        let deadline = std::time::Instant::now() + Duration::from_secs(15);
        while count < task_count && std::time::Instant::now() < deadline {
            if let Ok(_msg) = rx.recv_timeout(Duration::from_millis(100)) {
                count += 1;
            }
        }
        assert_eq!(count, task_count, "All batch commands should complete");
        let observed_max = max_active.load(Ordering::SeqCst);
        let num_cpus = thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        assert!(
            observed_max <= num_cpus + 2,
            "Expected bounded parallelism near {num_cpus}, but observed {observed_max}. \
             Without thread-pool feature, this would be near {task_count}."
        );
    }

    #[test]
    fn handle_command_large_batch_no_panic() {
        let model = TestModel { count: 0 };
        let program = Program::new(model);
        let (tx, rx) = mpsc::channel();
        let command_threads = Arc::new(Mutex::new(Vec::new()));
        let cmds: Vec<Option<Cmd>> = (0..500)
            .map(|i| Some(Cmd::new(move || Message::new(i))))
            .collect();
        let batch_cmd = crate::batch(cmds).unwrap();
        program.handle_command(batch_cmd, tx, Arc::clone(&command_threads));
        let mut count = 0usize;
        let deadline = std::time::Instant::now() + Duration::from_secs(10);
        while count < 500 && std::time::Instant::now() < deadline {
            if let Ok(_msg) = rx.recv_timeout(Duration::from_millis(50)) {
                count += 1;
            }
        }
        assert_eq!(count, 500, "Large batch should complete without panic");
    }

    #[test]
    fn test_thread_handles_captured() {
        let handle: Option<thread::JoinHandle<()>> = Some(thread::spawn(|| {}));
        assert!(handle.is_some(), "Handle should be captured");
        if let Some(h) = handle {
            h.join().expect("Thread should join successfully");
        }
    }

    #[test]
    fn test_threads_exit_on_channel_drop() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let thread_exited = Arc::new(AtomicBool::new(false));
        let thread_exited_clone = Arc::clone(&thread_exited);
        let (tx, rx) = mpsc::channel::<i32>();
        let handle = thread::spawn(move || {
            // recv() fails once every sender is dropped, ending the loop.
            while rx.recv().is_ok() {}
            thread_exited_clone.store(true, Ordering::SeqCst);
        });
        assert!(!thread_exited.load(Ordering::SeqCst));
        drop(tx);
        handle
            .join()
            .expect("Thread should join after channel drop");
        assert!(
            thread_exited.load(Ordering::SeqCst),
            "Thread should have exited after channel drop"
        );
    }

    #[test]
    fn test_shutdown_joins_all_threads() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let join_count = Arc::new(AtomicUsize::new(0));
        let mut handles: Vec<thread::JoinHandle<()>> = Vec::new();
        for i in 0..3 {
            let join_count_clone = Arc::clone(&join_count);
            handles.push(thread::spawn(move || {
                thread::sleep(Duration::from_millis(10 * (i as u64 + 1)));
                join_count_clone.fetch_add(1, Ordering::SeqCst);
            }));
        }
        for handle in handles {
            if let Err(e) = handle.join() {
                panic!("Thread panicked during join: {:?}", e);
            }
        }
        assert_eq!(
            join_count.load(Ordering::SeqCst),
            3,
            "All threads should have completed and been joined"
        );
    }

    #[test]
    fn test_thread_panic_handled_gracefully() {
        let handle = thread::spawn(|| {
            panic!("Intentional panic for testing");
        });
        let result = handle.join();
        let e = result.expect_err("Join should return Err when thread panics");
        let panic_info = format!("{e:?}");
        tracing::warn!(panic = %panic_info, "Thread panicked during join");
    }

    #[test]
    fn test_external_forwarder_pattern() {
        let (external_tx, external_rx) = mpsc::channel::<Message>();
        let (event_tx, event_rx) = mpsc::channel::<Message>();
        let tx_clone = event_tx.clone();
        let handle = thread::spawn(move || {
            while let Ok(msg) = external_rx.recv() {
                if tx_clone.send(msg).is_err() {
                    break;
                }
            }
        });
        external_tx.send(Message::new(1i32)).unwrap();
        external_tx.send(Message::new(2i32)).unwrap();
        external_tx.send(Message::new(3i32)).unwrap();
        drop(external_tx);
        let join_result = handle.join();
        assert!(
            join_result.is_ok(),
            "Forwarder thread should exit cleanly when sender is dropped"
        );
        let mut received = Vec::new();
        while let Ok(msg) = event_rx.try_recv() {
            if let Some(&n) = msg.downcast_ref::<i32>() {
                received.push(n);
            }
        }
        assert_eq!(received, vec![1, 2, 3], "All messages should be forwarded");
    }

    #[test]
    fn test_task_tracker_spawn_blocking_tracks_thread() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create runtime");
        let thread_completed = Arc::new(AtomicBool::new(false));
        let thread_completed_clone = Arc::clone(&thread_completed);
        rt.block_on(async {
            let task_tracker = TaskTracker::new();
            task_tracker.spawn_blocking(move || {
                thread::sleep(Duration::from_millis(50));
                thread_completed_clone.store(true, Ordering::SeqCst);
            });
            task_tracker.close();
            task_tracker.wait().await;
            assert!(
                thread_completed.load(Ordering::SeqCst),
                "spawn_blocking task should complete before wait() returns"
            );
        });
    }

    #[test]
    fn test_cancellation_token_stops_blocking_task() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create runtime");
        let task_exited = Arc::new(AtomicBool::new(false));
        let task_exited_clone = Arc::clone(&task_exited);
        rt.block_on(async {
            let cancel_token = CancellationToken::new();
            let task_tracker = TaskTracker::new();
            let cancel_clone = cancel_token.clone();
            task_tracker.spawn_blocking(move || {
                loop {
                    if cancel_clone.is_cancelled() {
                        task_exited_clone.store(true, Ordering::SeqCst);
                        break;
                    }
                    thread::sleep(Duration::from_millis(10));
                }
            });
            thread::sleep(Duration::from_millis(30));
            assert!(
                !task_exited.load(Ordering::SeqCst),
                "Task should still be running before cancellation"
            );
            cancel_token.cancel();
            task_tracker.close();
            task_tracker.wait().await;
            assert!(
                task_exited.load(Ordering::SeqCst),
                "Task should exit after cancellation"
            );
        });
    }

    #[test]
    fn test_graceful_shutdown_waits_for_all_blocking_tasks() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create runtime");
        let completed_count = Arc::new(AtomicUsize::new(0));
        rt.block_on(async {
            let cancel_token = CancellationToken::new();
            let task_tracker = TaskTracker::new();
            for i in 0..3 {
                let count_clone = Arc::clone(&completed_count);
                let cancel_clone = cancel_token.clone();
                task_tracker.spawn_blocking(move || {
                    loop {
                        if cancel_clone.is_cancelled() {
                            break;
                        }
                        thread::sleep(Duration::from_millis(10));
                    }
                    // Simulate per-task cleanup work of different lengths.
                    thread::sleep(Duration::from_millis(10 * (i as u64 + 1)));
                    count_clone.fetch_add(1, Ordering::SeqCst);
                });
            }
            thread::sleep(Duration::from_millis(30));
            assert_eq!(
                completed_count.load(Ordering::SeqCst),
                0,
                "No tasks should complete before shutdown"
            );
            cancel_token.cancel();
            task_tracker.close();
            let timeout_result: std::result::Result<(), tokio::time::error::Elapsed> =
                tokio::time::timeout(Duration::from_secs(2), task_tracker.wait()).await;
            assert!(
                timeout_result.is_ok(),
                "All tasks should complete within timeout"
            );
            assert_eq!(
                completed_count.load(Ordering::SeqCst),
                3,
                "All 3 tasks should complete during graceful shutdown"
            );
        });
    }

    #[test]
    fn test_spawn_blocking_vs_spawn_difference() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicBool, Ordering};
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create runtime");
        let untracked_done = Arc::new(AtomicBool::new(false));
        let untracked_done_clone = Arc::clone(&untracked_done);
        // A plain std thread is invisible to the tracker: wait() can return
        // before it finishes. Keep its handle so it does not outlive the test.
        let untracked_handle = rt.block_on(async {
            let task_tracker = TaskTracker::new();
            let handle = thread::spawn(move || {
                thread::sleep(Duration::from_millis(100));
                untracked_done_clone.store(true, Ordering::SeqCst);
            });
            task_tracker.close();
            task_tracker.wait().await;
            handle
        });
        let tracked_done = Arc::new(AtomicBool::new(false));
        let tracked_done_clone = Arc::clone(&tracked_done);
        rt.block_on(async {
            let task_tracker = TaskTracker::new();
            task_tracker.spawn_blocking(move || {
                thread::sleep(Duration::from_millis(50));
                tracked_done_clone.store(true, Ordering::SeqCst);
            });
            task_tracker.close();
            task_tracker.wait().await;
            assert!(
                tracked_done.load(Ordering::SeqCst),
                "spawn_blocking task should complete before wait() returns"
            );
        });
        untracked_handle
            .join()
            .expect("Untracked thread should join successfully");
        assert!(
            untracked_done.load(Ordering::SeqCst),
            "Untracked thread should have finished once joined"
        );
    }

    #[test]
    fn test_event_thread_pattern_with_poll_timeout() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create runtime");
        let poll_count = Arc::new(AtomicUsize::new(0));
        let poll_count_clone = Arc::clone(&poll_count);
        rt.block_on(async {
            let cancel_token = CancellationToken::new();
            let task_tracker = TaskTracker::new();
            let cancel_clone = cancel_token.clone();
            task_tracker.spawn_blocking(move || {
                loop {
                    if cancel_clone.is_cancelled() {
                        break;
                    }
                    thread::sleep(Duration::from_millis(25));
                    poll_count_clone.fetch_add(1, Ordering::SeqCst);
                }
            });
            thread::sleep(Duration::from_millis(100));
            let count_before_cancel = poll_count.load(Ordering::SeqCst);
            assert!(
                count_before_cancel >= 2,
                "Thread should have polled multiple times: {}",
                count_before_cancel
            );
            cancel_token.cancel();
            task_tracker.close();
            task_tracker.wait().await;
            let final_count = poll_count.load(Ordering::SeqCst);
            assert!(
                final_count >= count_before_cancel,
                "Poll count should not decrease"
            );
        });
    }
}