use crate::log_debug;
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::watch;
/// Usage figures (cost and context size) rendered next to the spinner.
///
/// Each counter is an `Arc<AtomicU64>`, so clones of an `AnimationState` share the same
/// counters and can be read or updated from any thread without locking. All accesses use
/// `Ordering::Relaxed`: the counters are independent display values, not synchronisation
/// points between threads.
#[derive(Clone, Debug)]
pub struct AnimationState {
    /// Cost in fixed point, scaled by 10 000 (i.e. four decimal places).
    pub cost: Arc<AtomicU64>,
    /// Current context size in tokens.
    pub context_tokens: Arc<AtomicU64>,
    /// Context-size threshold the usage percentage is computed against; 0 means none set.
    pub max_threshold: Arc<AtomicU64>,
}

impl AnimationState {
    /// Fixed-point scale for `cost`: the stored integer is `round(cost * COST_SCALE)`.
    const COST_SCALE: f64 = 10_000.0;

    /// Creates a state with every counter at zero.
    pub fn new() -> Self {
        Self {
            cost: Arc::new(AtomicU64::new(0)),
            context_tokens: Arc::new(AtomicU64::new(0)),
            max_threshold: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Stores `cost`, rounded to four decimal places.
    ///
    /// Rounds rather than truncates: decimal values are rarely exact in binary, and a
    /// product like `cost * 10_000.0` can land a hair below the intended integer (compare
    /// `0.29 * 100.0 == 28.999999999999996`), which a bare `as u64` would truncate down by
    /// one unit. Float-to-int `as` casts saturate, so negative or NaN costs store 0.
    pub fn update_cost(&self, cost: f64) {
        let scaled = (cost * Self::COST_SCALE).round() as u64;
        self.cost.store(scaled, Ordering::Relaxed);
    }

    /// Returns the stored cost (four-decimal precision).
    pub fn get_cost(&self) -> f64 {
        self.cost.load(Ordering::Relaxed) as f64 / Self::COST_SCALE
    }

    /// Stores the current context size in tokens.
    pub fn update_context_tokens(&self, tokens: u64) {
        self.context_tokens.store(tokens, Ordering::Relaxed);
    }

    /// Returns the current context size in tokens.
    pub fn get_context_tokens(&self) -> u64 {
        self.context_tokens.load(Ordering::Relaxed)
    }

    /// Stores the context-size threshold (0 = no threshold).
    pub fn update_max_threshold(&self, threshold: usize) {
        // usize -> u64 is lossless on every target whose usize is at most 64 bits wide.
        self.max_threshold
            .store(threshold as u64, Ordering::Relaxed);
    }

    /// Returns the context-size threshold.
    ///
    /// Saturates at `usize::MAX` on targets where `usize` is narrower than 64 bits instead
    /// of silently truncating the stored value.
    pub fn get_max_threshold(&self) -> usize {
        self.max_threshold
            .load(Ordering::Relaxed)
            .min(usize::MAX as u64) as usize
    }
}

impl Default for AnimationState {
    /// Same as [`AnimationState::new`]: all counters zero.
    fn default() -> Self {
        Self::new()
    }
}
/// Owns the terminal spinner shown while work is in progress, plus the usage figures
/// rendered in its message and the hooks used to cancel or suspend it.
///
/// The lock-protected fields are `Arc`-wrapped so a spawned cancellation-watcher task can
/// hold its own handle to them (the `spinner` slot is cloned into that task).
pub struct AnimationManager {
/// The running spinner, or `None` when no animation is active.
spinner: Arc<Mutex<Option<ProgressBar>>>,
/// Cost / context counters rendered in the spinner message.
state: AnimationState,
/// Session cancellation channel; a spinner's watcher tears it down once this reports `true`.
cancel_rx: Arc<Mutex<Option<watch::Receiver<bool>>>>,
/// Handle of the watcher task spawned for the current spinner, if any.
cancel_watcher: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
/// Set by `suspend` (ahead of a user prompt); start/phase calls return early while set.
suspended: Arc<AtomicBool>,
}
impl AnimationManager {
pub fn new() -> Self {
Self {
spinner: Arc::new(Mutex::new(None)),
state: AnimationState::new(),
cancel_rx: Arc::new(Mutex::new(None)),
cancel_watcher: Arc::new(Mutex::new(None)),
suspended: Arc::new(AtomicBool::new(false)),
}
}
pub fn get_state(&self) -> AnimationState {
self.state.clone()
}
pub fn set_cancel_receiver(&self, rx: watch::Receiver<bool>) {
*self.cancel_rx.lock().unwrap() = Some(rx);
}
pub fn clear_cancel_receiver(&self) {
*self.cancel_rx.lock().unwrap() = None;
}
pub async fn suspend(&self) {
self.suspended.store(true, Ordering::SeqCst);
self.stop_current().await;
log_debug!("Animation suspended — user prompt imminent");
}
pub fn resume(&self) {
self.suspended.store(false, Ordering::SeqCst);
log_debug!("Animation resumed");
}
pub fn is_suspended(&self) -> bool {
self.suspended.load(Ordering::SeqCst)
}
pub fn with_suspended_spinner<F, R>(&self, f: F) -> R
where
F: FnOnce() -> R,
{
let spinner_guard = self.spinner.lock().unwrap();
if let Some(ref pb) = *spinner_guard {
pb.suspend(f)
} else {
drop(spinner_guard);
f()
}
}
pub async fn start_animation(&self, mode: &crate::session::output::OutputMode) {
if self.is_suspended() {
log_debug!("start_animation: manager suspended — skipping");
return;
}
if !mode.should_show_animations() {
return;
}
self.ensure_started_internal();
}
pub async fn start_with_params(&self, cost: f64, context_tokens: u64, max_threshold: usize) {
if self.is_suspended() {
log_debug!("start_with_params: manager suspended — skipping");
return;
}
let output_mode = crate::config::with_thread_config(|config| config.output_mode())
.unwrap_or(crate::session::output::OutputMode::NonInteractive);
if !output_mode.should_show_animations() {
if output_mode.is_terminal_mode() {
if cost > 0.0 {
println!(
" ── cost: ${:.5} ────────────────────────────────────────",
cost
);
} else if max_threshold > 0 {
let percentage =
(context_tokens as f64 / max_threshold as f64 * 100.0).min(100.0);
println!(
" ── context: {:.1}% ────────────────────────────────────────",
percentage
);
}
}
return;
}
self.state.update_cost(cost);
self.state.update_context_tokens(context_tokens);
self.state.update_max_threshold(max_threshold);
self.ensure_started_internal();
}
pub fn update_state(&self, cost: f64, context_tokens: u64, max_threshold: usize) {
self.state.update_cost(cost);
self.state.update_context_tokens(context_tokens);
self.state.update_max_threshold(max_threshold);
let guard = self.spinner.lock().unwrap();
if let Some(ref pb) = *guard {
let cost_bits = self.state.cost.load(Ordering::Relaxed);
let ctx = self.state.context_tokens.load(Ordering::Relaxed);
let thresh = self.state.max_threshold.load(Ordering::Relaxed);
pb.set_message(build_base_message(cost_bits, ctx, thresh));
}
}
pub async fn set_phase(&self, phase: &str) {
if self.is_suspended() {
return;
}
let output_mode = crate::config::with_thread_config(|c| c.output_mode())
.unwrap_or(crate::session::output::OutputMode::NonInteractive);
if !output_mode.should_show_animations() {
return;
}
self.ensure_started_internal();
let guard = self.spinner.lock().unwrap();
if let Some(ref pb) = *guard {
let cost_bits = self.state.cost.load(Ordering::Relaxed);
let ctx = self.state.context_tokens.load(Ordering::Relaxed);
let thresh = self.state.max_threshold.load(Ordering::Relaxed);
pb.set_message(build_phase_message(cost_bits, ctx, thresh, phase));
}
}
pub fn clear_phase(&self) {
let guard = self.spinner.lock().unwrap();
if let Some(ref pb) = *guard {
let cost_bits = self.state.cost.load(Ordering::Relaxed);
let ctx = self.state.context_tokens.load(Ordering::Relaxed);
let thresh = self.state.max_threshold.load(Ordering::Relaxed);
pb.set_message(build_base_message(cost_bits, ctx, thresh));
}
}
fn ensure_started_internal(&self) {
let mut guard = self.spinner.lock().unwrap();
if let Some(ref pb) = *guard {
let cost_bits = self.state.cost.load(Ordering::Relaxed);
let ctx = self.state.context_tokens.load(Ordering::Relaxed);
let thresh = self.state.max_threshold.load(Ordering::Relaxed);
pb.set_message(build_base_message(cost_bits, ctx, thresh));
return;
}
let pb = ProgressBar::new_spinner();
pb.set_style(
ProgressStyle::default_spinner()
.with_key(
"elapsed_custom",
|ps: &ProgressState, w: &mut dyn std::fmt::Write| {
let elapsed = ps.elapsed();
if elapsed.as_secs() > 0 {
let _ = write!(
w,
"({} • Ctrl+C to interrupt)",
crate::session::chat::animation::format_elapsed_time(elapsed)
);
} else {
let _ = write!(w, "(Ctrl+C to interrupt)");
}
},
)
.template(" {spinner:.cyan} {msg:.cyan} {elapsed_custom:.cyan.dim}")
.unwrap()
.tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧"),
);
let cost_bits = self.state.cost.load(Ordering::Relaxed);
let ctx = self.state.context_tokens.load(Ordering::Relaxed);
let thresh = self.state.max_threshold.load(Ordering::Relaxed);
pb.set_message(build_base_message(cost_bits, ctx, thresh));
pb.enable_steady_tick(Duration::from_millis(100));
*guard = Some(pb.clone());
drop(guard);
self.spawn_cancel_watcher(pb);
}
fn spawn_cancel_watcher(&self, pb: ProgressBar) {
let cancel_rx = self.cancel_rx.lock().unwrap().clone();
let Some(mut rx) = cancel_rx else {
return;
};
let spinner_ref = self.spinner.clone();
let cancel_watcher = self.cancel_watcher.clone();
if let Some(prior) = cancel_watcher.lock().unwrap().take() {
prior.abort();
}
let handle = tokio::spawn(async move {
loop {
if *rx.borrow() {
break;
}
if rx.changed().await.is_err() {
return;
}
}
let pb_for_block = pb.clone();
let _ = tokio::task::spawn_blocking(move || {
pb_for_block.finish_and_clear();
pb_for_block.disable_steady_tick();
})
.await;
*spinner_ref.lock().unwrap() = None;
log_debug!("Animation cancelled via session cancellation channel");
});
*cancel_watcher.lock().unwrap() = Some(handle);
}
pub async fn stop_current(&self) {
let pb = self.spinner.lock().unwrap().take();
if let Some(handle) = self.cancel_watcher.lock().unwrap().take() {
handle.abort();
}
self.clear_cancel_receiver();
let Some(pb) = pb else {
return;
};
let join = tokio::task::spawn_blocking(move || {
pb.finish_and_clear();
pb.disable_steady_tick();
});
match tokio::time::timeout(Duration::from_millis(500), join).await {
Ok(_) => {}
Err(_) => {
log_debug!("stop_current: disable_steady_tick timed out — leaving detached");
}
}
}
pub fn is_running(&self) -> bool {
self.spinner.lock().unwrap().is_some()
}
}
impl Default for AnimationManager {
fn default() -> Self {
Self::new()
}
}
/// Builds the idle spinner message: the usage prefix (if any) followed by `Working …`.
///
/// This is exactly [`build_phase_message`] with the fixed label `"Working …"`, so the two
/// messages cannot drift apart.
fn build_base_message(cost_bits: u64, ctx: u64, thresh: u64) -> String {
    build_phase_message(cost_bits, ctx, thresh, "Working …")
}

/// Builds a spinner message for `phase`, prefixed with cost / context usage.
///
/// `cost_bits` is the cost scaled by 10 000; `ctx` is the current context size and
/// `thresh` the threshold it is measured against (0 = no threshold). Prefix forms:
///
/// - cost and threshold: `[$1.50|25.0%] phase`
/// - cost only:          `[$1.50|∞] phase`
/// - threshold only:     `[25.0%] phase`
/// - neither:            `phase`
///
/// The percentage is capped at 100%.
fn build_phase_message(cost_bits: u64, ctx: u64, thresh: u64, phase: &str) -> String {
    let cost = cost_bits as f64 / 10000.0;
    // Lazily evaluated so the division only happens when `thresh > 0`.
    let pct = || (ctx as f64 / thresh as f64 * 100.0).min(100.0);
    match (cost > 0.0, thresh > 0) {
        (true, true) => format!("[${:.2}|{:.1}%] {}", cost, pct(), phase),
        (true, false) => format!("[${:.2}|∞] {}", cost, phase),
        (false, true) => format!("[{:.1}%] {}", pct(), phase),
        (false, false) => phase.to_string(),
    }
}
/// Process-wide animation manager, created lazily by [`get_animation_manager`].
pub static GLOBAL_ANIMATION_MANAGER: std::sync::OnceLock<AnimationManager> =
    std::sync::OnceLock::new();

/// Returns the process-wide [`AnimationManager`], creating an idle one on the first call.
pub fn get_animation_manager() -> &'static AnimationManager {
    GLOBAL_ANIMATION_MANAGER.get_or_init(AnimationManager::default)
}