use std::collections::BTreeMap;
use std::future::Future;
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::llm::{Message, MessageContent, Role};
use crate::session::History;
/// Number of trailing progress blocks reported when the caller does not request a count.
const DEFAULT_RECENT_BLOCKS: usize = 10;
/// Default number of finished tasks retained for inspection before the oldest are pruned.
const DEFAULT_FINISHED_TASKS_CAP: usize = 64;

/// Settings controlling how background-task progress is rendered and how long finished
/// tasks are retained.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackgroundProgressConfig {
    /// How many trailing blocks to report when no explicit count is requested.
    pub default_recent_blocks: usize,
    /// Maximum number of characters kept from free-form block bodies. `0` suppresses those
    /// bodies entirely; non-free-form blocks (such as tool names) are unaffected.
    pub block_text_limit: usize,
    /// Maximum number of finished tasks kept; the earliest-finished are dropped first.
    pub finished_tasks_cap: usize,
}

impl Default for BackgroundProgressConfig {
    fn default() -> Self {
        BackgroundProgressConfig {
            finished_tasks_cap: DEFAULT_FINISHED_TASKS_CAP,
            default_recent_blocks: DEFAULT_RECENT_BLOCKS,
            // Free-form bodies are hidden unless a caller opts into a positive limit.
            block_text_limit: 0,
        }
    }
}

impl BackgroundProgressConfig {
    /// Picks the number of recent blocks to report: the explicit request when given,
    /// otherwise the configured default, and never fewer than one.
    fn resolve_recent(&self, requested: Option<usize>) -> usize {
        let wanted = match requested {
            Some(count) => count,
            None => self.default_recent_blocks,
        };
        std::cmp::max(wanted, 1)
    }
}
/// Final result of a background task, queued until drained via
/// `BackgroundTasks::drain_completed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackgroundOutcome {
    /// Identifier assigned at spawn time (of the form `bg-<n>`).
    pub task_id: String,
    /// Human-readable label supplied by the spawner.
    pub label: String,
    /// What the task's future resolved to.
    pub result: BackgroundResult,
}

/// Text produced by a background task, tagged with success or failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackgroundResult {
    /// The task succeeded; the string is its output.
    Completed(String),
    /// The task failed; the string describes the failure.
    Failed(String),
}

impl BackgroundResult {
    /// True for [`BackgroundResult::Failed`].
    fn is_error(&self) -> bool {
        match self {
            BackgroundResult::Failed(_) => true,
            BackgroundResult::Completed(_) => false,
        }
    }

    /// The payload text, regardless of variant.
    fn text(&self) -> &str {
        match self {
            BackgroundResult::Completed(body) => body,
            BackgroundResult::Failed(body) => body,
        }
    }
}
/// Lifecycle state of a background task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task's future has not resolved yet.
    Running,
    /// Resolved to `BackgroundResult::Completed` without its token being canceled.
    Completed,
    /// Resolved to `BackgroundResult::Failed` without its token being canceled.
    Failed,
    /// Its cancellation token had been canceled by the time it resolved.
    Canceled,
}

impl TaskStatus {
    /// Stable lowercase name of the status.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Running => "running",
            Self::Completed => "completed",
            Self::Failed => "failed",
            Self::Canceled => "canceled",
        }
    }

    /// True once the task has stopped, for whatever reason.
    fn is_terminal(&self) -> bool {
        *self != TaskStatus::Running
    }
}
/// Category of one content block in a task's conversation history.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockKind {
    /// Text from a message whose role is not the assistant.
    User,
    /// Text from an assistant message.
    AssistantText,
    /// Thinking content.
    Thought,
    /// A tool invocation (rendered as the tool name).
    ToolUse,
    /// Output returned by a tool.
    ToolResult,
    /// Anything else, e.g. images or provider activity.
    Other,
}

impl BlockKind {
    /// Stable lowercase name of the kind.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::User => "user",
            Self::AssistantText => "assistant",
            Self::Thought => "thought",
            Self::ToolUse => "tool_use",
            Self::ToolResult => "tool_result",
            Self::Other => "other",
        }
    }

    /// Whether blocks of this kind carry arbitrary text that is subject to truncation.
    fn is_free_form_body(&self) -> bool {
        // Exhaustive on purpose: a new variant must be classified explicitly.
        match self {
            BlockKind::User
            | BlockKind::AssistantText
            | BlockKind::Thought
            | BlockKind::ToolResult => true,
            BlockKind::ToolUse | BlockKind::Other => false,
        }
    }
}
/// One content block from a task's conversation history, rendered for progress reporting.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProgressBlock {
    /// Which kind of content this block was rendered from.
    pub kind: BlockKind,
    /// Rendered text. Tool calls carry the tool name and other non-text content a short
    /// placeholder or description. Free-form bodies are truncated to the configured
    /// `block_text_limit` characters, or left empty when that limit is 0.
    pub text: String,
}
/// Point-in-time view of a background task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskSnapshot {
    /// Task id as returned by `BackgroundTasks::spawn`.
    pub task_id: String,
    /// Label supplied at spawn time.
    pub label: String,
    /// Lifecycle state when the snapshot was taken.
    pub status: TaskStatus,
    /// Total number of content blocks in the attached history. It is 0 when no history is
    /// attached, and always 0 in `BackgroundTasks::list` summaries.
    pub block_count: usize,
    /// The most recent blocks, oldest first; always empty in `list` summaries.
    pub recent: Vec<ProgressBlock>,
}
/// Shortens `text` to at most `limit` characters (Unicode scalar values).
///
/// When characters are dropped, a marker states how many. A `limit` of 0 yields an empty
/// string rather than a marker.
fn truncate_body(text: &str, limit: usize) -> String {
    if limit == 0 {
        return String::new();
    }
    // Byte offset of the first character beyond the limit, present only if text is longer.
    match text.char_indices().nth(limit) {
        None => text.to_owned(),
        Some((cut, _)) => {
            let (head, tail) = text.split_at(cut);
            format!("{head} …(+{} more chars)", tail.chars().count())
        }
    }
}
/// Flattens a tool result body into plain text.
///
/// Text bodies are returned verbatim and JSON bodies are serialized. Multi-part content is
/// joined with newlines, with each image part rendered as `<image>`.
fn tool_result_text(body: &crate::llm::ToolResultBody) -> String {
    use crate::llm::{ToolResultBody, ToolResultContent};
    match body {
        ToolResultBody::Text { text } => text.clone(),
        ToolResultBody::Json { value } => value.to_string(),
        ToolResultBody::Content { blocks } => {
            // Borrow each part instead of cloning it; `join` builds the single output string.
            let parts: Vec<&str> = blocks
                .iter()
                .map(|part| match part {
                    ToolResultContent::Text { text } => text.as_str(),
                    ToolResultContent::Image { .. } => "<image>",
                })
                .collect();
            parts.join("\n")
        }
    }
}
/// Renders one message content item as a [`ProgressBlock`].
///
/// Text is attributed to the assistant when `role` is `Role::Assistant` and to the user
/// otherwise. Tool calls keep only the tool name, and non-text content becomes a short
/// placeholder or description. Free-form bodies are passed through `truncate_body` with
/// `limit`.
fn block_of_content(content: &MessageContent, role: Role, limit: usize) -> ProgressBlock {
    let (kind, raw): (BlockKind, String) = match content {
        MessageContent::Text { text } if role == Role::Assistant => {
            (BlockKind::AssistantText, text.clone())
        }
        MessageContent::Text { text } => (BlockKind::User, text.clone()),
        MessageContent::Thinking { text, .. } => (BlockKind::Thought, text.clone()),
        MessageContent::ToolUse { name, .. } => (BlockKind::ToolUse, name.clone()),
        MessageContent::ToolResult { output, .. } => {
            (BlockKind::ToolResult, tool_result_text(output))
        }
        MessageContent::Image { .. } => (BlockKind::Other, "<image>".to_string()),
        MessageContent::ProviderActivity { kind, .. } => {
            (BlockKind::Other, format!("provider activity: {kind:?}"))
        }
    };
    if kind.is_free_form_body() {
        ProgressBlock {
            kind,
            text: truncate_body(&raw, limit),
        }
    } else {
        ProgressBlock { kind, text: raw }
    }
}
/// Renders the last `n` content blocks of `messages`, oldest first, and returns them
/// together with the total number of content blocks in the history.
///
/// Only the retained tail is rendered. Earlier blocks are counted but never converted, so
/// peeking at a long history does not clone, serialize, or truncate every body in it.
fn recent_blocks_of(messages: &[Message], n: usize, limit: usize) -> (usize, Vec<ProgressBlock>) {
    let total: usize = messages.iter().map(|m| m.content.iter().count()).sum();
    let skip = total.saturating_sub(n);
    // Pair each content item with its message's role lazily; `skip` runs before the
    // (expensive) conversion so skipped blocks are never rendered.
    let recent = messages
        .iter()
        .flat_map(|m| m.content.iter().map(move |c| (c, m.role)))
        .skip(skip)
        .map(|(c, role)| block_of_content(c, role, limit))
        .collect();
    (total, recent)
}
/// Bookkeeping for one spawned background task.
struct TaskEntry {
    /// Label supplied at spawn time.
    label: String,
    /// Current lifecycle state; `Running` until the task's future resolves.
    status: TaskStatus,
    /// Per-task token (a child of the session token); canceled by `cancel_task`.
    cancel: CancellationToken,
    /// History attached via `TaskHandle::attach_history`; read by `peek`.
    history: Option<Arc<dyn History>>,
    /// Join handle of the spawned tokio task; cleared when the task finishes.
    handle: Option<JoinHandle<()>>,
    /// Completion order stamped by `BackgroundInner::finish`; `None` while running.
    /// Used to prune the earliest-finished entries first.
    finished_seq: Option<u64>,
}
/// Mutable state shared by `BackgroundTasks`, its spawned tasks, and `TaskHandle`s.
struct BackgroundInner {
    /// Counter used to mint `bg-<n>` task ids.
    next_id: u64,
    /// Counter stamped onto tasks as they finish, giving a total completion order.
    next_finished_seq: u64,
    /// Running tasks plus up to `finished_tasks_cap` finished ones, keyed by id.
    tasks: BTreeMap<String, TaskEntry>,
    /// Outcomes not yet handed out by `drain_completed`.
    completed: Vec<BackgroundOutcome>,
    /// Maximum number of finished entries retained in `tasks`.
    finished_tasks_cap: usize,
}
impl BackgroundInner {
    /// Marks task `id` as stopped with `status`.
    ///
    /// Stamps the task's completion order, drops its join handle, and then prunes finished
    /// entries beyond the cap.
    fn finish(&mut self, id: &str, status: TaskStatus) {
        // The sequence number is consumed even for an unknown id; only relative order matters.
        let seq = self.next_finished_seq;
        self.next_finished_seq = seq + 1;
        if let Some(entry) = self.tasks.get_mut(id) {
            entry.finished_seq = Some(seq);
            entry.handle = None;
            entry.status = status;
        }
        self.prune_finished();
    }

    /// Removes the earliest-finished entries until at most `finished_tasks_cap` remain.
    /// Running tasks are never removed.
    fn prune_finished(&mut self) {
        let cap = self.finished_tasks_cap;
        let mut by_age: Vec<(u64, String)> = Vec::new();
        for (id, entry) in &self.tasks {
            if let Some(seq) = entry.finished_seq {
                by_age.push((seq, id.clone()));
            }
        }
        if by_age.len() <= cap {
            return;
        }
        let excess = by_age.len() - cap;
        // Sequence numbers are unique, so an unstable sort yields a deterministic order.
        by_age.sort_unstable_by_key(|&(seq, _)| seq);
        for (_, id) in by_age.drain(..excess) {
            self.tasks.remove(&id);
        }
    }
}
#[derive(Clone)]
pub struct BackgroundTasks {
cancel: CancellationToken,
completed_notify: Arc<Notify>,
progress_config: BackgroundProgressConfig,
inner: Arc<Mutex<BackgroundInner>>,
}
impl BackgroundTasks {
#[must_use]
pub fn new(
session_cancel: CancellationToken,
progress_config: BackgroundProgressConfig,
) -> Self {
Self {
cancel: session_cancel,
completed_notify: Arc::new(Notify::new()),
progress_config,
inner: Arc::new(Mutex::new(BackgroundInner {
next_id: 0,
next_finished_seq: 0,
tasks: BTreeMap::new(),
completed: Vec::new(),
finished_tasks_cap: progress_config.finished_tasks_cap,
})),
}
}
pub async fn wait_for_completion(&self) {
self.completed_notify.notified().await;
}
#[must_use]
pub fn has_completed(&self) -> bool {
!self
.inner
.lock()
.expect("BackgroundTasks mutex poisoned")
.completed
.is_empty()
}
pub fn spawn<F, Fut>(&self, label: String, make_fut: F) -> String
where
F: FnOnce(CancellationToken, TaskHandle) -> Fut,
Fut: Future<Output = BackgroundResult> + Send + 'static,
{
let mut inner = self.inner.lock().expect("BackgroundTasks mutex poisoned");
let id = format!("bg-{}", inner.next_id);
inner.next_id += 1;
let task_cancel = self.cancel.child_token();
let handle = TaskHandle {
inner: self.inner.clone(),
task_id: id.clone(),
};
let cancel_for_task = task_cancel.clone();
let fut = make_fut(task_cancel.clone(), handle);
let inner_arc = self.inner.clone();
let notify = self.completed_notify.clone();
let id_for_task = id.clone();
let label_for_task = label.clone();
let join = tokio::spawn(async move {
let result = fut.await;
let status = if cancel_for_task.is_cancelled() {
TaskStatus::Canceled
} else if result.is_error() {
TaskStatus::Failed
} else {
TaskStatus::Completed
};
if let Ok(mut inner) = inner_arc.lock() {
inner.finish(&id_for_task, status);
inner.completed.push(BackgroundOutcome {
task_id: id_for_task,
label: label_for_task,
result,
});
}
notify.notify_one();
});
inner.tasks.insert(
id.clone(),
TaskEntry {
label,
status: TaskStatus::Running,
cancel: task_cancel,
history: None,
handle: Some(join),
finished_seq: None,
},
);
id
}
pub fn drain_completed(&self) -> Vec<BackgroundOutcome> {
let mut inner = self.inner.lock().expect("BackgroundTasks mutex poisoned");
std::mem::take(&mut inner.completed)
}
#[must_use]
pub fn running_count(&self) -> usize {
self.inner
.lock()
.expect("BackgroundTasks mutex poisoned")
.tasks
.values()
.filter(|e| e.status == TaskStatus::Running)
.count()
}
#[must_use]
pub fn list(&self) -> Vec<TaskSnapshot> {
let inner = self.inner.lock().expect("BackgroundTasks mutex poisoned");
inner
.tasks
.iter()
.map(|(id, e)| TaskSnapshot {
task_id: id.clone(),
label: e.label.clone(),
status: e.status,
block_count: 0,
recent: Vec::new(),
})
.collect()
}
#[must_use]
pub fn peek(&self, id: &str, recent_blocks: Option<usize>) -> Option<TaskSnapshot> {
let n = self.progress_config.resolve_recent(recent_blocks);
let limit = self.progress_config.block_text_limit;
let (label, status, history) = {
let inner = self.inner.lock().expect("BackgroundTasks mutex poisoned");
let entry = inner.tasks.get(id)?;
(entry.label.clone(), entry.status, entry.history.clone())
};
let (block_count, recent) = match history {
Some(h) => recent_blocks_of(&h.snapshot(), n, limit),
None => (0, Vec::new()),
};
Some(TaskSnapshot {
task_id: id.to_string(),
label,
status,
block_count,
recent,
})
}
pub fn cancel_task(&self, id: &str) -> Option<bool> {
let inner = self.inner.lock().expect("BackgroundTasks mutex poisoned");
let entry = inner.tasks.get(id)?;
if entry.status.is_terminal() {
return Some(false);
}
entry.cancel.cancel();
Some(true)
}
pub fn cancel_all(&self) {
self.cancel.cancel();
}
}
/// Capability handed to a background task's `make_fut` closure for reporting progress.
#[derive(Clone)]
pub struct TaskHandle {
    /// State shared with the owning `BackgroundTasks`.
    inner: Arc<Mutex<BackgroundInner>>,
    /// Id of the task this handle belongs to.
    task_id: String,
}
impl TaskHandle {
    /// Attaches the history that `BackgroundTasks::peek` reads progress from, replacing
    /// any history attached earlier.
    ///
    /// Best-effort: does nothing if the mutex is poisoned or no entry exists for this task
    /// id (for example, because it was pruned after finishing).
    pub fn attach_history(&self, history: Arc<dyn History>) {
        let Ok(mut state) = self.inner.lock() else {
            return;
        };
        if let Some(entry) = state.tasks.get_mut(&self.task_id) {
            entry.history = Some(history);
        }
    }
}
#[must_use]
pub fn format_background_outcome(outcome: &BackgroundOutcome) -> String {
let status = if outcome.result.is_error() {
"failed"
} else {
"completed"
};
format!(
"⟨background task {} ({}) {}⟩\n{}",
outcome.task_id,
outcome.label,
status,
outcome.result.text()
)
}
// Unit tests for this module are kept in a separate `tests` module file.
#[cfg(test)]
mod tests;