Skip to main content

agentkit_task_manager/
lib.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use agentkit_core::{
7    Item, MetadataMap, TaskId, ToolCallId, ToolResultPart, TurnCancellation, TurnId,
8};
9use agentkit_tools_core::{
10    ApprovalRequest, AuthRequest, OwnedToolContext, ToolError, ToolExecutionOutcome, ToolExecutor,
11    ToolRequest,
12};
13use async_trait::async_trait;
14use thiserror::Error;
15use tokio::sync::{Mutex, Notify, mpsc};
16use tokio::task::JoinHandle;
17
/// Whether a task currently counts against the turn that started it.
///
/// `Foreground` tasks are tracked per turn and report through
/// `TaskManager::wait_for_turn`; `Background` tasks are not counted against
/// the turn and report through loop updates or manual delivery instead.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TaskKind {
    /// Blocks its turn until it finishes (or is detached/cancelled).
    Foreground,
    /// Runs independently of its turn.
    Background,
}
23
/// What a finished background task asks of the agent loop.
///
/// With `RequestContinue`, a background task delivered `ToLoop` additionally
/// emits `TaskEvent::ContinueRequested` when it completes; `NotifyOnly` does not.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContinuePolicy {
    /// Only report the result.
    NotifyOnly,
    /// Report the result and ask the loop to continue.
    RequestContinue,
}
29
/// Where a background task's resolution goes once it finishes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryMode {
    /// Queue the resolution for `TaskManager::take_pending_loop_updates`.
    ToLoop,
    /// Park completed items for `TaskManagerHandle::drain_ready_items`.
    Manual,
}
35
36#[derive(Clone, Debug, PartialEq, Eq)]
37pub struct TaskSnapshot {
38    pub id: TaskId,
39    pub turn_id: TurnId,
40    pub call_id: ToolCallId,
41    pub tool_name: String,
42    pub kind: TaskKind,
43    pub metadata: MetadataMap,
44}
45
46#[derive(Clone, Debug, PartialEq)]
47pub enum TaskEvent {
48    Started(TaskSnapshot),
49    Detached(TaskSnapshot),
50    Completed(TaskSnapshot, ToolResultPart),
51    Cancelled(TaskSnapshot),
52    Failed(TaskSnapshot, ToolError),
53    ContinueRequested,
54}
55
56#[derive(Clone, Debug, PartialEq)]
57pub struct TaskApproval {
58    pub task_id: TaskId,
59    pub tool_request: ToolRequest,
60    pub approval: ApprovalRequest,
61}
62
63#[derive(Clone, Debug, PartialEq)]
64pub struct TaskAuth {
65    pub task_id: TaskId,
66    pub tool_request: ToolRequest,
67    pub auth: AuthRequest,
68}
69
70#[derive(Clone, Debug, PartialEq)]
71pub enum TaskResolution {
72    Item(Item),
73    Approval(TaskApproval),
74    Auth(TaskAuth),
75}
76
77#[derive(Clone, Debug, PartialEq)]
78pub enum TaskStartOutcome {
79    Ready(Box<TaskResolution>),
80    Pending { task_id: TaskId, kind: TaskKind },
81}
82
83#[derive(Clone, Debug, PartialEq)]
84pub enum TurnTaskUpdate {
85    Resolution(Box<TaskResolution>),
86    Detached(TaskSnapshot),
87}
88
89#[derive(Clone, Debug, Default, PartialEq)]
90pub struct PendingLoopUpdates {
91    pub resolutions: VecDeque<TaskResolution>,
92}
93
94#[derive(Clone, Debug)]
95pub struct TaskLaunchRequest {
96    pub task_id: Option<TaskId>,
97    pub request: ToolRequest,
98    pub approved_request: Option<ApprovalRequest>,
99}
100
101#[derive(Clone)]
102pub struct TaskStartContext {
103    pub executor: Arc<dyn ToolExecutor>,
104    pub tool_context: OwnedToolContext,
105}
106
107#[derive(Debug, Error, Clone, PartialEq, Eq)]
108pub enum TaskManagerError {
109    #[error("task not found: {0}")]
110    NotFound(TaskId),
111    #[error("task manager internal error: {0}")]
112    Internal(String),
113}
114
115pub trait TaskRoutingPolicy: Send + Sync {
116    fn route(&self, request: &ToolRequest) -> RoutingDecision;
117}
118
119impl<F> TaskRoutingPolicy for F
120where
121    F: Fn(&ToolRequest) -> RoutingDecision + Send + Sync,
122{
123    fn route(&self, request: &ToolRequest) -> RoutingDecision {
124        self(request)
125    }
126}
127
/// Scheduling choice produced by a [`TaskRoutingPolicy`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RoutingDecision {
    /// Counted against the turn until it finishes.
    Foreground,
    /// Never counted against the turn.
    Background,
    /// Starts in the foreground; `AsyncTaskManager` moves it to the background
    /// if it is still running after the given duration.
    ForegroundThenDetachAfter(Duration),
}
134
/// Routing policy used when none is configured: everything runs in the foreground.
struct DefaultRoutingPolicy;
136
137impl TaskRoutingPolicy for DefaultRoutingPolicy {
138    fn route(&self, _request: &ToolRequest) -> RoutingDecision {
139        RoutingDecision::Foreground
140    }
141}
142
143#[async_trait]
144pub trait TaskManager: Send + Sync {
145    async fn start_task(
146        &self,
147        request: TaskLaunchRequest,
148        ctx: TaskStartContext,
149    ) -> Result<TaskStartOutcome, TaskManagerError>;
150
151    async fn wait_for_turn(
152        &self,
153        turn_id: &TurnId,
154        cancellation: Option<TurnCancellation>,
155    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError>;
156
157    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError>;
158
159    async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError>;
160
161    fn handle(&self) -> TaskManagerHandle;
162}
163
164#[async_trait]
165trait TaskManagerControl: Send + Sync {
166    async fn next_event(&self) -> Option<TaskEvent>;
167    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError>;
168    async fn list_running(&self) -> Vec<TaskSnapshot>;
169    async fn list_completed(&self) -> Vec<TaskSnapshot>;
170    async fn drain_ready_items(&self) -> Vec<Item>;
171    async fn set_continue_policy(
172        &self,
173        task_id: TaskId,
174        policy: ContinuePolicy,
175    ) -> Result<(), TaskManagerError>;
176    async fn set_delivery_mode(
177        &self,
178        task_id: TaskId,
179        mode: DeliveryMode,
180    ) -> Result<(), TaskManagerError>;
181    async fn wait_for_idle(&self);
182}
183
184#[derive(Clone)]
185pub struct TaskManagerHandle {
186    inner: Arc<dyn TaskManagerControl>,
187}
188
189impl TaskManagerHandle {
190    pub async fn next_event(&self) -> Option<TaskEvent> {
191        self.inner.next_event().await
192    }
193
194    pub async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
195        self.inner.cancel(task_id).await
196    }
197
198    pub async fn list_running(&self) -> Vec<TaskSnapshot> {
199        self.inner.list_running().await
200    }
201
202    pub async fn list_completed(&self) -> Vec<TaskSnapshot> {
203        self.inner.list_completed().await
204    }
205
206    pub async fn drain_ready_items(&self) -> Vec<Item> {
207        self.inner.drain_ready_items().await
208    }
209
210    pub async fn set_continue_policy(
211        &self,
212        task_id: TaskId,
213        policy: ContinuePolicy,
214    ) -> Result<(), TaskManagerError> {
215        self.inner.set_continue_policy(task_id, policy).await
216    }
217
218    pub async fn set_delivery_mode(
219        &self,
220        task_id: TaskId,
221        mode: DeliveryMode,
222    ) -> Result<(), TaskManagerError> {
223        self.inner.set_delivery_mode(task_id, mode).await
224    }
225
226    /// Wait until all running tasks have completed.
227    pub async fn wait_for_idle(&self) {
228        self.inner.wait_for_idle().await
229    }
230}
231
232pub struct SimpleTaskManager {
233    state: Arc<HandleState>,
234}
235
236impl SimpleTaskManager {
237    pub fn new() -> Self {
238        Self {
239            state: Arc::new(HandleState::default()),
240        }
241    }
242}
243
244impl Default for SimpleTaskManager {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250#[async_trait]
251impl TaskManager for SimpleTaskManager {
252    async fn start_task(
253        &self,
254        request: TaskLaunchRequest,
255        ctx: TaskStartContext,
256    ) -> Result<TaskStartOutcome, TaskManagerError> {
257        let task_id = request
258            .task_id
259            .clone()
260            .unwrap_or_else(|| self.state.next_task_id());
261        let outcome = match request.approved_request.as_ref() {
262            Some(approved) => {
263                ctx.executor
264                    .execute_approved_owned(request.request.clone(), approved, ctx.tool_context)
265                    .await
266            }
267            None => {
268                ctx.executor
269                    .execute_owned(request.request.clone(), ctx.tool_context)
270                    .await
271            }
272        };
273        Ok(TaskStartOutcome::Ready(Box::new(
274            map_outcome_to_resolution(Some(task_id), request.request, outcome),
275        )))
276    }
277
278    async fn wait_for_turn(
279        &self,
280        _turn_id: &TurnId,
281        _cancellation: Option<TurnCancellation>,
282    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
283        Ok(None)
284    }
285
286    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
287        Ok(PendingLoopUpdates::default())
288    }
289
290    async fn on_turn_interrupted(&self, _turn_id: &TurnId) -> Result<(), TaskManagerError> {
291        Ok(())
292    }
293
294    fn handle(&self) -> TaskManagerHandle {
295        TaskManagerHandle {
296            inner: self.state.clone(),
297        }
298    }
299}
300
301#[derive(Default)]
302struct HandleState {
303    next_task_index: AtomicU64,
304    events_rx: Mutex<Option<mpsc::UnboundedReceiver<TaskEvent>>>,
305}
306
307impl HandleState {
308    fn next_task_id(&self) -> TaskId {
309        let next = self.next_task_index.fetch_add(1, Ordering::SeqCst) + 1;
310        TaskId::new(format!("task-{}", next))
311    }
312}
313
314#[async_trait]
315impl TaskManagerControl for HandleState {
316    async fn next_event(&self) -> Option<TaskEvent> {
317        let mut rx = self.events_rx.lock().await;
318        match rx.as_mut() {
319            Some(inner) => inner.recv().await,
320            None => None,
321        }
322    }
323
324    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
325        Err(TaskManagerError::NotFound(task_id))
326    }
327
328    async fn list_running(&self) -> Vec<TaskSnapshot> {
329        Vec::new()
330    }
331
332    async fn list_completed(&self) -> Vec<TaskSnapshot> {
333        Vec::new()
334    }
335
336    async fn drain_ready_items(&self) -> Vec<Item> {
337        Vec::new()
338    }
339
340    async fn set_continue_policy(
341        &self,
342        task_id: TaskId,
343        _policy: ContinuePolicy,
344    ) -> Result<(), TaskManagerError> {
345        Err(TaskManagerError::NotFound(task_id))
346    }
347
348    async fn set_delivery_mode(
349        &self,
350        task_id: TaskId,
351        _mode: DeliveryMode,
352    ) -> Result<(), TaskManagerError> {
353        Err(TaskManagerError::NotFound(task_id))
354    }
355
356    async fn wait_for_idle(&self) {}
357}
358
359pub struct AsyncTaskManager {
360    inner: Arc<AsyncInner>,
361    routing: Arc<dyn TaskRoutingPolicy>,
362}
363
364impl AsyncTaskManager {
365    pub fn new() -> Self {
366        let (event_tx, event_rx) = mpsc::unbounded_channel();
367        Self {
368            inner: Arc::new(AsyncInner {
369                state: Mutex::new(AsyncState::default()),
370                host_event_tx: event_tx,
371                host_event_rx: Mutex::new(event_rx),
372                notify: Notify::new(),
373            }),
374            routing: Arc::new(DefaultRoutingPolicy),
375        }
376    }
377
378    pub fn routing(mut self, policy: impl TaskRoutingPolicy + 'static) -> Self {
379        self.routing = Arc::new(policy);
380        self
381    }
382}
383
384impl Default for AsyncTaskManager {
385    fn default() -> Self {
386        Self::new()
387    }
388}
389
390#[derive(Default)]
391struct AsyncState {
392    next_task_index: u64,
393    tasks: BTreeMap<TaskId, TaskRecord>,
394    per_turn_running: BTreeMap<TurnId, usize>,
395    per_turn_updates: BTreeMap<TurnId, VecDeque<TurnTaskUpdate>>,
396    pending_loop_updates: VecDeque<TaskResolution>,
397    manual_ready_items: Vec<Item>,
398}
399
400struct TaskRecord {
401    snapshot: TaskSnapshot,
402    continue_policy: ContinuePolicy,
403    delivery_mode: DeliveryMode,
404    running: bool,
405    completed: bool,
406    join: Option<JoinHandle<()>>,
407}
408
409struct AsyncInner {
410    state: Mutex<AsyncState>,
411    host_event_tx: mpsc::UnboundedSender<TaskEvent>,
412    host_event_rx: Mutex<mpsc::UnboundedReceiver<TaskEvent>>,
413    notify: Notify,
414}
415
416impl AsyncInner {
417    async fn next_task_id(&self) -> TaskId {
418        let mut state = self.state.lock().await;
419        state.next_task_index += 1;
420        TaskId::new(format!("task-{}", state.next_task_index))
421    }
422}
423
424#[async_trait]
425impl TaskManager for AsyncTaskManager {
426    async fn start_task(
427        &self,
428        request: TaskLaunchRequest,
429        ctx: TaskStartContext,
430    ) -> Result<TaskStartOutcome, TaskManagerError> {
431        let route = self.routing.route(&request.request);
432        let task_id = match request.task_id.clone() {
433            Some(existing) => existing,
434            None => self.inner.next_task_id().await,
435        };
436        let initial_kind = match route {
437            RoutingDecision::Background => TaskKind::Background,
438            _ => TaskKind::Foreground,
439        };
440        let snapshot = TaskSnapshot {
441            id: task_id.clone(),
442            turn_id: request.request.turn_id.clone(),
443            call_id: request.request.call_id.clone(),
444            tool_name: request.request.tool_name.to_string(),
445            kind: initial_kind,
446            metadata: request.request.metadata.clone(),
447        };
448        let _ = self
449            .inner
450            .host_event_tx
451            .send(TaskEvent::Started(snapshot.clone()));
452
453        let mut state = self.inner.state.lock().await;
454        state.tasks.insert(
455            task_id.clone(),
456            TaskRecord {
457                snapshot: snapshot.clone(),
458                continue_policy: ContinuePolicy::NotifyOnly,
459                delivery_mode: DeliveryMode::ToLoop,
460                running: true,
461                completed: false,
462                join: None,
463            },
464        );
465        if initial_kind == TaskKind::Foreground {
466            *state
467                .per_turn_running
468                .entry(snapshot.turn_id.clone())
469                .or_default() += 1;
470        }
471        drop(state);
472
473        let event_tx = self.inner.host_event_tx.clone();
474        let inner = self.inner.clone();
475        let task_id_for_future = task_id.clone();
476        let turn_id = snapshot.turn_id.clone();
477        let approved = request.approved_request.clone();
478        let exec_request = request.request.clone();
479        let owned_ctx = ctx.tool_context.clone();
480        let executor = ctx.executor.clone();
481        let route_copy = route;
482        let join = tokio::spawn(async move {
483            if let RoutingDecision::ForegroundThenDetachAfter(duration) = route_copy {
484                let event_tx = event_tx.clone();
485                let inner = inner.clone();
486                let task_id = task_id_for_future.clone();
487                let turn_id = turn_id.clone();
488                tokio::spawn(async move {
489                    tokio::time::sleep(duration).await;
490                    let mut state = inner.state.lock().await;
491                    let snapshot = if let Some(record) = state.tasks.get_mut(&task_id)
492                        && record.running
493                        && record.snapshot.kind == TaskKind::Foreground
494                    {
495                        record.snapshot.kind = TaskKind::Background;
496                        Some(record.snapshot.clone())
497                    } else {
498                        None
499                    };
500                    if let Some(snapshot) = snapshot {
501                        if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
502                            *count = count.saturating_sub(1);
503                            if *count == 0 {
504                                state.per_turn_running.remove(&turn_id);
505                            }
506                        }
507                        state
508                            .per_turn_updates
509                            .entry(turn_id.clone())
510                            .or_default()
511                            .push_back(TurnTaskUpdate::Detached(snapshot.clone()));
512                        let _ = event_tx.send(TaskEvent::Detached(snapshot));
513                        inner.notify.notify_waiters();
514                    }
515                });
516            }
517
518            let outcome = match approved.as_ref() {
519                Some(approval) => {
520                    executor
521                        .execute_approved_owned(exec_request.clone(), approval, owned_ctx)
522                        .await
523                }
524                None => {
525                    executor
526                        .execute_owned(exec_request.clone(), owned_ctx)
527                        .await
528                }
529            };
530
531            let resolution =
532                map_outcome_to_resolution(Some(task_id_for_future.clone()), exec_request, outcome);
533            let completed_result = match &resolution {
534                TaskResolution::Item(item) => item.parts.iter().find_map(|part| match part {
535                    agentkit_core::Part::ToolResult(result) => Some(result.clone()),
536                    _ => None,
537                }),
538                TaskResolution::Approval(_) | TaskResolution::Auth(_) => None,
539            };
540
541            let (snapshot, should_request_continue) = {
542                let mut state = inner.state.lock().await;
543                let Some(record) = state.tasks.get_mut(&task_id_for_future) else {
544                    return;
545                };
546                record.running = false;
547                record.completed = true;
548                let snapshot = record.snapshot.clone();
549                let continue_policy = record.continue_policy;
550                let delivery_mode = record.delivery_mode;
551                let current_kind = snapshot.kind;
552
553                if current_kind == TaskKind::Foreground {
554                    if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
555                        *count = count.saturating_sub(1);
556                        if *count == 0 {
557                            state.per_turn_running.remove(&turn_id);
558                        }
559                    }
560                    state
561                        .per_turn_updates
562                        .entry(turn_id.clone())
563                        .or_default()
564                        .push_back(TurnTaskUpdate::Resolution(Box::new(resolution.clone())));
565                } else {
566                    match &resolution {
567                        TaskResolution::Item(_) if delivery_mode == DeliveryMode::ToLoop => {
568                            state.pending_loop_updates.push_back(resolution.clone());
569                        }
570                        TaskResolution::Approval(_) | TaskResolution::Auth(_)
571                            if delivery_mode == DeliveryMode::ToLoop =>
572                        {
573                            state.pending_loop_updates.push_back(resolution.clone());
574                        }
575                        TaskResolution::Item(item) => {
576                            state.manual_ready_items.push(item.clone());
577                        }
578                        TaskResolution::Approval(_) | TaskResolution::Auth(_) => {}
579                    }
580                }
581
582                (
583                    snapshot,
584                    current_kind == TaskKind::Background
585                        && delivery_mode == DeliveryMode::ToLoop
586                        && continue_policy == ContinuePolicy::RequestContinue,
587                )
588            };
589
590            if let Some(result) = completed_result {
591                let _ = event_tx.send(TaskEvent::Completed(snapshot.clone(), result));
592            }
593            if should_request_continue {
594                let _ = event_tx.send(TaskEvent::ContinueRequested);
595            }
596            inner.notify.notify_waiters();
597        });
598
599        let mut state = self.inner.state.lock().await;
600        if let Some(record) = state.tasks.get_mut(&task_id) {
601            record.join = Some(join);
602        }
603        Ok(TaskStartOutcome::Pending {
604            task_id,
605            kind: initial_kind,
606        })
607    }
608
609    async fn wait_for_turn(
610        &self,
611        turn_id: &TurnId,
612        cancellation: Option<TurnCancellation>,
613    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
614        loop {
615            {
616                let mut state = self.inner.state.lock().await;
617                if let Some(queue) = state.per_turn_updates.get_mut(turn_id)
618                    && let Some(update) = queue.pop_front()
619                {
620                    return Ok(Some(update));
621                }
622                if state
623                    .per_turn_running
624                    .get(turn_id)
625                    .copied()
626                    .unwrap_or_default()
627                    == 0
628                {
629                    return Ok(None);
630                }
631            }
632            if cancellation
633                .as_ref()
634                .is_some_and(TurnCancellation::is_cancelled)
635            {
636                return Ok(None);
637            }
638            if let Some(cancellation) = cancellation.as_ref() {
639                tokio::select! {
640                    _ = self.inner.notify.notified() => {}
641                    _ = cancellation.cancelled() => return Ok(None),
642                }
643            } else {
644                self.inner.notify.notified().await;
645            }
646        }
647    }
648
649    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
650        let mut state = self.inner.state.lock().await;
651        Ok(PendingLoopUpdates {
652            resolutions: std::mem::take(&mut state.pending_loop_updates),
653        })
654    }
655
656    async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError> {
657        let mut state = self.inner.state.lock().await;
658        let interrupted: Vec<TaskId> = state
659            .tasks
660            .iter()
661            .filter_map(|(id, record)| {
662                (record.snapshot.turn_id == *turn_id
663                    && record.snapshot.kind == TaskKind::Foreground
664                    && record.running)
665                    .then_some(id.clone())
666            })
667            .collect();
668        for task_id in interrupted {
669            if let Some(record) = state.tasks.get_mut(&task_id) {
670                record.running = false;
671                if let Some(join) = record.join.take() {
672                    join.abort();
673                }
674                let snapshot = record.snapshot.clone();
675                let _ = self
676                    .inner
677                    .host_event_tx
678                    .send(TaskEvent::Cancelled(snapshot));
679            }
680        }
681        state.per_turn_running.remove(turn_id);
682        self.inner.notify.notify_waiters();
683        Ok(())
684    }
685
686    fn handle(&self) -> TaskManagerHandle {
687        TaskManagerHandle {
688            inner: self.inner.clone(),
689        }
690    }
691}
692
693#[async_trait]
694impl TaskManagerControl for AsyncInner {
695    async fn next_event(&self) -> Option<TaskEvent> {
696        self.host_event_rx.lock().await.recv().await
697    }
698
699    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
700        let mut state = self.state.lock().await;
701        let record = state
702            .tasks
703            .get_mut(&task_id)
704            .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
705        if let Some(join) = record.join.take() {
706            join.abort();
707        }
708        record.running = false;
709        let snapshot = record.snapshot.clone();
710        if record.snapshot.kind == TaskKind::Foreground
711            && let Some(count) = state.per_turn_running.get_mut(&snapshot.turn_id)
712        {
713            *count = count.saturating_sub(1);
714            if *count == 0 {
715                state.per_turn_running.remove(&snapshot.turn_id);
716            }
717        }
718        let _ = self.host_event_tx.send(TaskEvent::Cancelled(snapshot));
719        self.notify.notify_waiters();
720        Ok(())
721    }
722
723    async fn list_running(&self) -> Vec<TaskSnapshot> {
724        let state = self.state.lock().await;
725        state
726            .tasks
727            .values()
728            .filter(|record| record.running)
729            .map(|record| record.snapshot.clone())
730            .collect()
731    }
732
733    async fn list_completed(&self) -> Vec<TaskSnapshot> {
734        let state = self.state.lock().await;
735        state
736            .tasks
737            .values()
738            .filter(|record| record.completed)
739            .map(|record| record.snapshot.clone())
740            .collect()
741    }
742
743    async fn drain_ready_items(&self) -> Vec<Item> {
744        let mut state = self.state.lock().await;
745        std::mem::take(&mut state.manual_ready_items)
746    }
747
748    async fn set_continue_policy(
749        &self,
750        task_id: TaskId,
751        policy: ContinuePolicy,
752    ) -> Result<(), TaskManagerError> {
753        let mut state = self.state.lock().await;
754        let record = state
755            .tasks
756            .get_mut(&task_id)
757            .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
758        record.continue_policy = policy;
759        Ok(())
760    }
761
762    async fn set_delivery_mode(
763        &self,
764        task_id: TaskId,
765        mode: DeliveryMode,
766    ) -> Result<(), TaskManagerError> {
767        let mut state = self.state.lock().await;
768        let record = state
769            .tasks
770            .get_mut(&task_id)
771            .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
772        record.delivery_mode = mode;
773        Ok(())
774    }
775
776    async fn wait_for_idle(&self) {
777        loop {
778            {
779                let state = self.state.lock().await;
780                if !state.tasks.values().any(|r| r.running) {
781                    return;
782                }
783            }
784            self.notify.notified().await;
785        }
786    }
787}
788
789fn map_outcome_to_resolution(
790    task_id: Option<TaskId>,
791    request: ToolRequest,
792    outcome: ToolExecutionOutcome,
793) -> TaskResolution {
794    match outcome {
795        ToolExecutionOutcome::Completed(result) => TaskResolution::Item(Item {
796            id: None,
797            kind: agentkit_core::ItemKind::Tool,
798            parts: vec![agentkit_core::Part::ToolResult(result.result)],
799            metadata: result.metadata,
800        }),
801        ToolExecutionOutcome::Interrupted(
802            agentkit_tools_core::ToolInterruption::ApprovalRequired(mut approval),
803        ) => {
804            let task_id = task_id.unwrap_or_default();
805            approval.task_id = Some(task_id.clone());
806            TaskResolution::Approval(TaskApproval {
807                task_id,
808                tool_request: request,
809                approval,
810            })
811        }
812        ToolExecutionOutcome::Interrupted(agentkit_tools_core::ToolInterruption::AuthRequired(
813            mut auth,
814        )) => {
815            let task_id = task_id.unwrap_or_default();
816            auth.task_id = Some(task_id.clone());
817            TaskResolution::Auth(TaskAuth {
818                task_id,
819                tool_request: request,
820                auth,
821            })
822        }
823        ToolExecutionOutcome::Failed(error) => TaskResolution::Item(Item {
824            id: None,
825            kind: agentkit_core::ItemKind::Tool,
826            parts: vec![agentkit_core::Part::ToolResult(ToolResultPart {
827                call_id: request.call_id,
828                output: agentkit_core::ToolOutput::Text(error.to_string()),
829                is_error: true,
830                metadata: request.metadata,
831            })],
832            metadata: MetadataMap::new(),
833        }),
834    }
835}
836
837#[cfg(test)]
838mod tests {
839    use std::collections::BTreeMap;
840    use std::sync::Arc as StdArc;
841    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
842
843    use agentkit_core::{
844        CancellationController, ItemKind, Part, SessionId, ToolOutput, TurnCancellation,
845    };
846    use agentkit_tools_core::{
847        ApprovalReason, PermissionChecker, PermissionDecision, ToolAnnotations, ToolInterruption,
848        ToolName, ToolResult, ToolSpec,
849    };
850    use serde_json::json;
851    use tokio::sync::Notify;
852    use tokio::time::{Duration, timeout};
853
854    use super::*;
855
856    struct AllowAllPermissions;
857
858    impl PermissionChecker for AllowAllPermissions {
859        fn evaluate(
860            &self,
861            _request: &dyn agentkit_tools_core::PermissionRequest,
862        ) -> PermissionDecision {
863            PermissionDecision::Allow
864        }
865    }
866
867    #[derive(Clone)]
868    enum TestBehavior {
869        Block {
870            entered: StdArc<AtomicBool>,
871            release: StdArc<Notify>,
872            output: &'static str,
873        },
874        Approval,
875    }
876
877    #[derive(Clone)]
878    struct TestExecutor {
879        behaviors: BTreeMap<String, TestBehavior>,
880    }
881
882    impl TestExecutor {
883        fn new(behaviors: impl IntoIterator<Item = (impl Into<String>, TestBehavior)>) -> Self {
884            Self {
885                behaviors: behaviors
886                    .into_iter()
887                    .map(|(name, behavior)| (name.into(), behavior))
888                    .collect(),
889            }
890        }
891    }
892
893    #[async_trait]
894    impl ToolExecutor for TestExecutor {
895        fn specs(&self) -> Vec<ToolSpec> {
896            self.behaviors
897                .keys()
898                .map(|name| ToolSpec {
899                    name: ToolName::new(name),
900                    description: format!("test tool {name}"),
901                    input_schema: json!({
902                        "type": "object",
903                        "properties": {},
904                        "additionalProperties": false
905                    }),
906                    annotations: ToolAnnotations::default(),
907                    metadata: MetadataMap::new(),
908                })
909                .collect()
910        }
911
912        async fn execute(
913            &self,
914            request: ToolRequest,
915            _ctx: &mut agentkit_tools_core::ToolContext<'_>,
916        ) -> ToolExecutionOutcome {
917            match self.behaviors.get(request.tool_name.0.as_str()) {
918                Some(TestBehavior::Block {
919                    entered,
920                    release,
921                    output,
922                }) => {
923                    entered.store(true, AtomicOrdering::SeqCst);
924                    release.notified().await;
925                    ToolExecutionOutcome::Completed(ToolResult {
926                        result: ToolResultPart {
927                            call_id: request.call_id,
928                            output: ToolOutput::Text((*output).into()),
929                            is_error: false,
930                            metadata: request.metadata,
931                        },
932                        duration: None,
933                        metadata: MetadataMap::new(),
934                    })
935                }
936                Some(TestBehavior::Approval) => ToolExecutionOutcome::Interrupted(
937                    ToolInterruption::ApprovalRequired(ApprovalRequest {
938                        task_id: None,
939                        call_id: Some(request.call_id.clone()),
940                        id: "approval:test".into(),
941                        request_kind: "tool.test".into(),
942                        reason: ApprovalReason::SensitivePath,
943                        summary: "requires approval".into(),
944                        metadata: MetadataMap::new(),
945                    }),
946                ),
947                None => ToolExecutionOutcome::Failed(ToolError::Unavailable(
948                    request.tool_name.0.clone(),
949                )),
950            }
951        }
952    }
953
954    struct NameRoutingPolicy {
955        routes: BTreeMap<String, RoutingDecision>,
956    }
957
958    impl NameRoutingPolicy {
959        fn new(routes: impl IntoIterator<Item = (impl Into<String>, RoutingDecision)>) -> Self {
960            Self {
961                routes: routes
962                    .into_iter()
963                    .map(|(name, decision)| (name.into(), decision))
964                    .collect(),
965            }
966        }
967    }
968
969    impl TaskRoutingPolicy for NameRoutingPolicy {
970        fn route(&self, request: &ToolRequest) -> RoutingDecision {
971            self.routes
972                .get(request.tool_name.0.as_str())
973                .copied()
974                .unwrap_or(RoutingDecision::Foreground)
975        }
976    }
977
978    fn make_request(tool_name: &str, turn_id: &str, call_id: &str) -> ToolRequest {
979        ToolRequest {
980            call_id: ToolCallId::new(call_id),
981            tool_name: ToolName::new(tool_name),
982            input: json!({}),
983            session_id: SessionId::new("session-1"),
984            turn_id: TurnId::new(turn_id),
985            metadata: MetadataMap::new(),
986        }
987    }
988
989    fn make_context(
990        executor: Arc<dyn ToolExecutor>,
991        turn_id: &TurnId,
992        cancellation: Option<TurnCancellation>,
993    ) -> TaskStartContext {
994        TaskStartContext {
995            executor,
996            tool_context: OwnedToolContext {
997                session_id: SessionId::new("session-1"),
998                turn_id: turn_id.clone(),
999                metadata: MetadataMap::new(),
1000                permissions: Arc::new(AllowAllPermissions),
1001                resources: Arc::new(()),
1002                cancellation,
1003            },
1004        }
1005    }
1006
1007    async fn next_event(handle: &TaskManagerHandle) -> TaskEvent {
1008        timeout(Duration::from_secs(1), handle.next_event())
1009            .await
1010            .expect("timed out waiting for task event")
1011            .expect("task event stream ended unexpectedly")
1012    }
1013
1014    async fn wait_until_entered(entered: &AtomicBool) {
1015        timeout(Duration::from_secs(1), async {
1016            while !entered.load(AtomicOrdering::SeqCst) {
1017                tokio::task::yield_now().await;
1018            }
1019        })
1020        .await
1021        .expect("task never entered execution");
1022    }
1023
1024    #[tokio::test]
1025    async fn simple_task_manager_executes_inline_and_assigns_task_ids() {
1026        let manager = SimpleTaskManager::new();
1027        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1028            "needs-approval",
1029            TestBehavior::Approval,
1030        )]));
1031        let request = make_request("needs-approval", "turn-1", "call-1");
1032
1033        let outcome = manager
1034            .start_task(
1035                TaskLaunchRequest {
1036                    task_id: None,
1037                    request: request.clone(),
1038                    approved_request: None,
1039                },
1040                make_context(executor, &request.turn_id, None),
1041            )
1042            .await
1043            .unwrap();
1044
1045        match outcome {
1046            TaskStartOutcome::Ready(resolution) => match *resolution {
1047                TaskResolution::Approval(task) => {
1048                    assert!(!task.task_id.0.is_empty());
1049                    assert_eq!(task.approval.task_id.as_ref(), Some(&task.task_id));
1050                    assert_eq!(task.tool_request.call_id, request.call_id);
1051                }
1052                other => panic!("unexpected task resolution: {other:?}"),
1053            },
1054            other => panic!("unexpected start outcome: {other:?}"),
1055        }
1056
1057        assert!(manager.handle().list_running().await.is_empty());
1058    }
1059
    /// Interrupting a turn cancels only its foreground work.
    ///
    /// Two blocking tools are launched on the same turn: one is routed to the
    /// foreground and one to the background. After `on_turn_interrupted`, the
    /// foreground task must be reported `Cancelled`. The background task must
    /// stay in the running set and still complete normally once released.
    #[tokio::test]
    async fn async_manager_interrupt_cancels_foreground_only() {
        // Each tool gets its own release gate and "entered" flag, so both
        // tasks can be parked inside `execute` and released independently.
        let fg_release = StdArc::new(Notify::new());
        let fg_entered = StdArc::new(AtomicBool::new(false));
        let bg_release = StdArc::new(Notify::new());
        let bg_entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([
            (
                "foreground",
                TestBehavior::Block {
                    entered: fg_entered.clone(),
                    release: fg_release.clone(),
                    output: "foreground-done",
                },
            ),
            (
                "background",
                TestBehavior::Block {
                    entered: bg_entered.clone(),
                    release: bg_release.clone(),
                    output: "background-done",
                },
            ),
        ]));
        // Route purely by tool name, so the two launches land in different
        // task kinds.
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([
            ("foreground", RoutingDecision::Foreground),
            ("background", RoutingDecision::Background),
        ]));
        let handle = manager.handle();
        let turn_id = TurnId::new("turn-1");

        let foreground = manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: make_request("foreground", "turn-1", "call-fg"),
                    approved_request: None,
                },
                make_context(executor.clone(), &turn_id, None),
            )
            .await
            .unwrap();
        let background = manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: make_request("background", "turn-1", "call-bg"),
                    approved_request: None,
                },
                make_context(executor.clone(), &turn_id, None),
            )
            .await
            .unwrap();

        // Both launches return `Pending` with the kind chosen by the policy.
        assert!(matches!(
            foreground,
            TaskStartOutcome::Pending {
                kind: TaskKind::Foreground,
                ..
            }
        ));
        let background_id = match background {
            TaskStartOutcome::Pending {
                task_id,
                kind: TaskKind::Background,
            } => task_id,
            other => panic!("unexpected background outcome: {other:?}"),
        };

        // Discard the two events emitted by the launches (their contents are
        // not asserted here). Then wait until both tools are actually blocked
        // inside `execute`, so the interrupt hits genuinely running work.
        let _ = next_event(&handle).await;
        let _ = next_event(&handle).await;
        wait_until_entered(fg_entered.as_ref()).await;
        wait_until_entered(bg_entered.as_ref()).await;

        manager.on_turn_interrupted(&turn_id).await.unwrap();

        // The very next event must be the foreground cancellation.
        match next_event(&handle).await {
            TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.tool_name, "foreground"),
            other => panic!("unexpected event after interrupt: {other:?}"),
        }

        // The background task survives the interrupt and is the only one
        // still running.
        let running = handle.list_running().await;
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].id, background_id);
        assert_eq!(running[0].tool_name, "background");

        // `fg_release` is never notified: the foreground task ends only
        // through cancellation. Releasing the background gate must produce
        // the background tool's output.
        bg_release.notify_waiters();
        match next_event(&handle).await {
            TaskEvent::Completed(snapshot, result) => {
                assert_eq!(snapshot.id, background_id);
                assert_eq!(result.output, ToolOutput::Text("background-done".into()));
            }
            other => panic!("unexpected completion event: {other:?}"),
        }
    }
1155
1156    #[tokio::test]
1157    async fn async_manager_can_cancel_background_tasks_by_id() {
1158        let release = StdArc::new(Notify::new());
1159        let entered = StdArc::new(AtomicBool::new(false));
1160        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1161            "background",
1162            TestBehavior::Block {
1163                entered: entered.clone(),
1164                release,
1165                output: "done",
1166            },
1167        )]));
1168        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1169            "background",
1170            RoutingDecision::Background,
1171        )]));
1172        let handle = manager.handle();
1173        let request = make_request("background", "turn-1", "call-1");
1174
1175        let task_id = match manager
1176            .start_task(
1177                TaskLaunchRequest {
1178                    task_id: None,
1179                    request: request.clone(),
1180                    approved_request: None,
1181                },
1182                make_context(executor, &request.turn_id, None),
1183            )
1184            .await
1185            .unwrap()
1186        {
1187            TaskStartOutcome::Pending { task_id, .. } => task_id,
1188            other => panic!("unexpected start outcome: {other:?}"),
1189        };
1190
1191        let _ = next_event(&handle).await;
1192        wait_until_entered(entered.as_ref()).await;
1193        handle.cancel(task_id.clone()).await.unwrap();
1194
1195        match next_event(&handle).await {
1196            TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.id, task_id),
1197            other => panic!("unexpected event after cancel: {other:?}"),
1198        }
1199
1200        assert!(handle.list_running().await.is_empty());
1201    }
1202
    /// With `DeliveryMode::Manual`, a finished background task still reports
    /// `Completed`, but its result bypasses the loop. There is no follow-up
    /// event (not even `ContinueRequested`, although the continue policy asks
    /// for one) and no pending loop resolution. Exactly one tool item waits in
    /// `drain_ready_items`.
    #[tokio::test]
    async fn async_manager_manual_delivery_keeps_results_out_of_loop_updates() {
        let release = StdArc::new(Notify::new());
        let entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "background",
            TestBehavior::Block {
                entered: entered.clone(),
                release: release.clone(),
                output: "manual-done",
            },
        )]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
            "background",
            RoutingDecision::Background,
        )]));
        let handle = manager.handle();
        let request = make_request("background", "turn-1", "call-1");

        let task_id = match manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    approved_request: None,
                },
                make_context(executor, &request.turn_id, None),
            )
            .await
            .unwrap()
        {
            TaskStartOutcome::Pending { task_id, .. } => task_id,
            other => panic!("unexpected start outcome: {other:?}"),
        };

        // Consume the launch event and wait until the tool is parked in
        // `execute` before changing the task's delivery settings.
        let _ = next_event(&handle).await;
        wait_until_entered(entered.as_ref()).await;
        // Request a continue *and* switch to manual delivery; the assertions
        // below check that manual delivery suppresses the continue request.
        handle
            .set_continue_policy(task_id.clone(), ContinuePolicy::RequestContinue)
            .await
            .unwrap();
        handle
            .set_delivery_mode(task_id, DeliveryMode::Manual)
            .await
            .unwrap();

        release.notify_waiters();
        match next_event(&handle).await {
            TaskEvent::Completed(_, result) => {
                assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
            }
            other => panic!("unexpected event: {other:?}"),
        }

        // Nothing else should be emitted: a 50ms wait on the event stream
        // must time out.
        assert!(
            timeout(Duration::from_millis(50), handle.next_event())
                .await
                .is_err()
        );
        // The result was not queued for the loop...
        assert!(
            manager
                .take_pending_loop_updates()
                .await
                .unwrap()
                .resolutions
                .is_empty()
        );

        // ...but is available for manual pickup as a single tool item.
        let ready_items = handle.drain_ready_items().await;
        assert_eq!(ready_items.len(), 1);
        assert_eq!(ready_items[0].kind, ItemKind::Tool);
        match &ready_items[0].parts[0] {
            Part::ToolResult(result) => {
                assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
            }
            other => panic!("unexpected ready item: {other:?}"),
        }
    }
1281
    /// Under the default delivery mode, a background task with
    /// `ContinuePolicy::RequestContinue` emits `Completed` and then
    /// `ContinueRequested`. It also queues one loop resolution and leaves
    /// nothing in the manual ready-item queue. This test never calls
    /// `set_delivery_mode`; presumably the default is `DeliveryMode::ToLoop`,
    /// which the test name implies.
    #[tokio::test]
    async fn async_manager_to_loop_delivery_can_request_continue() {
        let release = StdArc::new(Notify::new());
        let entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "background",
            TestBehavior::Block {
                entered: entered.clone(),
                release: release.clone(),
                output: "loop-done",
            },
        )]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
            "background",
            RoutingDecision::Background,
        )]));
        let handle = manager.handle();
        let request = make_request("background", "turn-1", "call-1");

        let task_id = match manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    approved_request: None,
                },
                // NOTE(review): the `CancellationController` is a temporary
                // dropped at the end of this statement, so nothing can ever
                // trigger this cancellation. Presumably the point is only to
                // exercise the `Some(..)` cancellation path.
                make_context(
                    executor,
                    &request.turn_id,
                    Some(TurnCancellation::new(
                        CancellationController::new().handle(),
                    )),
                ),
            )
            .await
            .unwrap()
        {
            TaskStartOutcome::Pending { task_id, .. } => task_id,
            other => panic!("unexpected start outcome: {other:?}"),
        };

        // Consume the launch event and wait until the tool is parked before
        // switching the continue policy.
        let _ = next_event(&handle).await;
        wait_until_entered(entered.as_ref()).await;
        handle
            .set_continue_policy(task_id, ContinuePolicy::RequestContinue)
            .await
            .unwrap();

        // Releasing the tool must yield `Completed` followed immediately by
        // `ContinueRequested`.
        release.notify_waiters();
        match next_event(&handle).await {
            TaskEvent::Completed(_, result) => {
                assert_eq!(result.output, ToolOutput::Text("loop-done".into()))
            }
            other => panic!("unexpected completion event: {other:?}"),
        }
        match next_event(&handle).await {
            TaskEvent::ContinueRequested => {}
            other => panic!("unexpected follow-up event: {other:?}"),
        }

        // The result is delivered to the loop, not to the manual queue.
        let updates = manager.take_pending_loop_updates().await.unwrap();
        assert_eq!(updates.resolutions.len(), 1);
        assert!(handle.drain_ready_items().await.is_empty());
    }
1346
1347    #[tokio::test]
1348    async fn wait_for_idle_returns_after_loop_updates_are_queued() {
1349        let release = StdArc::new(Notify::new());
1350        let entered = StdArc::new(AtomicBool::new(false));
1351        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1352            "background",
1353            TestBehavior::Block {
1354                entered: entered.clone(),
1355                release: release.clone(),
1356                output: "idle-done",
1357            },
1358        )]));
1359        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1360            "background",
1361            RoutingDecision::Background,
1362        )]));
1363        let handle = manager.handle();
1364        let request = make_request("background", "turn-1", "call-1");
1365
1366        let outcome = manager
1367            .start_task(
1368                TaskLaunchRequest {
1369                    task_id: None,
1370                    request: request.clone(),
1371                    approved_request: None,
1372                },
1373                make_context(executor, &request.turn_id, None),
1374            )
1375            .await
1376            .unwrap();
1377        assert!(matches!(outcome, TaskStartOutcome::Pending { .. }));
1378
1379        let _ = next_event(&handle).await;
1380        wait_until_entered(entered.as_ref()).await;
1381        release.notify_waiters();
1382
1383        timeout(Duration::from_secs(1), handle.wait_for_idle())
1384            .await
1385            .expect("wait_for_idle timed out");
1386
1387        let updates = manager.take_pending_loop_updates().await.unwrap();
1388        assert_eq!(updates.resolutions.len(), 1);
1389        match &updates.resolutions[0] {
1390            TaskResolution::Item(item) => match &item.parts[0] {
1391                Part::ToolResult(result) => {
1392                    assert_eq!(result.call_id, request.call_id);
1393                    assert_eq!(result.output, ToolOutput::Text("idle-done".into()));
1394                }
1395                other => panic!("unexpected tool item: {other:?}"),
1396            },
1397            other => panic!("unexpected pending update: {other:?}"),
1398        }
1399    }
1400}