Skip to main content

agentkit_task_manager/
lib.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use agentkit_core::{
7    Item, MetadataMap, TaskId, ToolCallId, ToolResultPart, TurnCancellation, TurnId,
8};
9use agentkit_tools_core::{
10    ApprovalRequest, OwnedToolContext, ToolError, ToolExecutionOutcome, ToolExecutor, ToolRequest,
11};
12use async_trait::async_trait;
13use thiserror::Error;
14use tokio::sync::{Mutex, Notify, mpsc};
15use tokio::task::JoinHandle;
16
/// Whether a task holds up the turn that launched it.
///
/// In `AsyncTaskManager`, only foreground tasks are counted as running for
/// their turn; background tasks run independently of it.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum TaskKind {
    /// Counted against its turn until it completes, is cancelled, or detaches.
    Foreground,
    /// Runs independently of the turn that started it.
    Background,
}

/// What a finished background task asks of the host.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ContinuePolicy {
    /// Deliver the result only.
    NotifyOnly,
    /// Also emit `TaskEvent::ContinueRequested` when the result is queued for the loop.
    RequestContinue,
}

/// Where a finished background task's result is routed.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DeliveryMode {
    /// Queue it for `TaskManager::take_pending_loop_updates`.
    ToLoop,
    /// Hold tool items for `TaskManagerHandle::drain_ready_items`.
    Manual,
}
34
35#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct TaskSnapshot {
37    pub id: TaskId,
38    pub turn_id: TurnId,
39    pub call_id: ToolCallId,
40    pub tool_name: String,
41    pub kind: TaskKind,
42    pub metadata: MetadataMap,
43}
44
45#[derive(Clone, Debug, PartialEq)]
46pub enum TaskEvent {
47    Started(TaskSnapshot),
48    Detached(TaskSnapshot),
49    Completed(TaskSnapshot, ToolResultPart),
50    Cancelled(TaskSnapshot),
51    Failed(TaskSnapshot, ToolError),
52    ContinueRequested,
53}
54
55#[derive(Clone, Debug, PartialEq)]
56pub struct TaskApproval {
57    pub task_id: TaskId,
58    pub tool_request: ToolRequest,
59    pub approval: ApprovalRequest,
60}
61
62#[derive(Clone, Debug, PartialEq)]
63pub enum TaskResolution {
64    Item(Item),
65    Approval(TaskApproval),
66}
67
68#[derive(Clone, Debug, PartialEq)]
69pub enum TaskStartOutcome {
70    Ready(Box<TaskResolution>),
71    Pending { task_id: TaskId, kind: TaskKind },
72}
73
74#[derive(Clone, Debug, PartialEq)]
75pub enum TurnTaskUpdate {
76    Resolution(Box<TaskResolution>),
77    Detached(TaskSnapshot),
78}
79
80#[derive(Clone, Debug, Default, PartialEq)]
81pub struct PendingLoopUpdates {
82    pub resolutions: VecDeque<TaskResolution>,
83}
84
85/// How a task should be invoked. Mutually exclusive between plain execution
86/// and resuming after approval.
87#[derive(Clone, Debug, Default)]
88pub enum TaskLaunchKind {
89    /// Execute the tool with the active permission policy.
90    #[default]
91    Plain,
92    /// Re-execute a previously-interrupted call after the user approved it.
93    Approved(ApprovalRequest),
94}
95
96#[derive(Clone, Debug)]
97pub struct TaskLaunchRequest {
98    pub task_id: Option<TaskId>,
99    pub request: ToolRequest,
100    pub kind: TaskLaunchKind,
101}
102
103impl TaskLaunchRequest {
104    /// Plain launch (no prior approval / auth).
105    pub fn plain(task_id: Option<TaskId>, request: ToolRequest) -> Self {
106        Self {
107            task_id,
108            request,
109            kind: TaskLaunchKind::Plain,
110        }
111    }
112
113    /// Resume after the user approved the call.
114    pub fn approved(
115        task_id: Option<TaskId>,
116        request: ToolRequest,
117        approval: ApprovalRequest,
118    ) -> Self {
119        Self {
120            task_id,
121            request,
122            kind: TaskLaunchKind::Approved(approval),
123        }
124    }
125}
126
127#[derive(Clone)]
128pub struct TaskStartContext {
129    pub executor: Arc<dyn ToolExecutor>,
130    pub tool_context: OwnedToolContext,
131}
132
133#[derive(Debug, Error, Clone, PartialEq, Eq)]
134pub enum TaskManagerError {
135    #[error("task not found: {0}")]
136    NotFound(TaskId),
137    #[error("task manager internal error: {0}")]
138    Internal(String),
139}
140
141pub trait TaskRoutingPolicy: Send + Sync {
142    fn route(&self, request: &ToolRequest) -> RoutingDecision;
143}
144
145impl<F> TaskRoutingPolicy for F
146where
147    F: Fn(&ToolRequest) -> RoutingDecision + Send + Sync,
148{
149    fn route(&self, request: &ToolRequest) -> RoutingDecision {
150        self(request)
151    }
152}
153
154#[derive(Clone, Copy, Debug, PartialEq, Eq)]
155pub enum RoutingDecision {
156    Foreground,
157    Background,
158    ForegroundThenDetachAfter(Duration),
159}
160
161struct DefaultRoutingPolicy;
162
163impl TaskRoutingPolicy for DefaultRoutingPolicy {
164    fn route(&self, _request: &ToolRequest) -> RoutingDecision {
165        RoutingDecision::Foreground
166    }
167}
168
/// Launches tool calls as tasks and reports their results back to the agent loop.
#[async_trait]
pub trait TaskManager: Send + Sync {
    /// Starts executing `request` with the executor and tool context from `ctx`.
    ///
    /// `SimpleTaskManager` runs the call to completion and returns
    /// `TaskStartOutcome::Ready`; `AsyncTaskManager` spawns it and returns
    /// `TaskStartOutcome::Pending`.
    async fn start_task(
        &self,
        request: TaskLaunchRequest,
        ctx: TaskStartContext,
    ) -> Result<TaskStartOutcome, TaskManagerError>;

    /// Waits for the next update about the foreground tasks of `turn_id`.
    ///
    /// `Ok(None)` means there is nothing (more) to wait for; `AsyncTaskManager`
    /// also returns it when `cancellation` fires.
    async fn wait_for_turn(
        &self,
        turn_id: &TurnId,
        cancellation: Option<TurnCancellation>,
    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError>;

    /// Removes and returns the background results queued for the agent loop.
    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError>;

    /// Notifies the manager that `turn_id` was interrupted; `AsyncTaskManager`
    /// cancels that turn's still-running foreground tasks.
    async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError>;

    /// Returns a cloneable handle for observing and controlling tasks.
    fn handle(&self) -> TaskManagerHandle;
}

/// Control surface behind `TaskManagerHandle`, used as
/// `Arc<dyn TaskManagerControl>`; each manager implements it on its shared state.
#[async_trait]
trait TaskManagerControl: Send + Sync {
    /// Next lifecycle event, or `None` when no event stream is available.
    async fn next_event(&self) -> Option<TaskEvent>;
    /// Cancels a task; `NotFound` when the id is unknown.
    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError>;
    /// Snapshots of tasks still executing.
    async fn list_running(&self) -> Vec<TaskSnapshot>;
    /// Snapshots of tasks that finished (cancelled tasks are not included).
    async fn list_completed(&self) -> Vec<TaskSnapshot>;
    /// Removes and returns items of finished `DeliveryMode::Manual` background tasks.
    async fn drain_ready_items(&self) -> Vec<Item>;
    /// Changes a task's continue policy; `NotFound` when the id is unknown.
    async fn set_continue_policy(
        &self,
        task_id: TaskId,
        policy: ContinuePolicy,
    ) -> Result<(), TaskManagerError>;
    /// Changes a task's delivery mode; `NotFound` when the id is unknown.
    async fn set_delivery_mode(
        &self,
        task_id: TaskId,
        mode: DeliveryMode,
    ) -> Result<(), TaskManagerError>;
    /// Resolves once no task is running.
    async fn wait_for_idle(&self);
}
209
210#[derive(Clone)]
211pub struct TaskManagerHandle {
212    inner: Arc<dyn TaskManagerControl>,
213}
214
215impl TaskManagerHandle {
216    pub async fn next_event(&self) -> Option<TaskEvent> {
217        self.inner.next_event().await
218    }
219
220    pub async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
221        self.inner.cancel(task_id).await
222    }
223
224    pub async fn list_running(&self) -> Vec<TaskSnapshot> {
225        self.inner.list_running().await
226    }
227
228    pub async fn list_completed(&self) -> Vec<TaskSnapshot> {
229        self.inner.list_completed().await
230    }
231
232    pub async fn drain_ready_items(&self) -> Vec<Item> {
233        self.inner.drain_ready_items().await
234    }
235
236    pub async fn set_continue_policy(
237        &self,
238        task_id: TaskId,
239        policy: ContinuePolicy,
240    ) -> Result<(), TaskManagerError> {
241        self.inner.set_continue_policy(task_id, policy).await
242    }
243
244    pub async fn set_delivery_mode(
245        &self,
246        task_id: TaskId,
247        mode: DeliveryMode,
248    ) -> Result<(), TaskManagerError> {
249        self.inner.set_delivery_mode(task_id, mode).await
250    }
251
252    /// Wait until all running tasks have completed.
253    pub async fn wait_for_idle(&self) {
254        self.inner.wait_for_idle().await
255    }
256}
257
258pub struct SimpleTaskManager {
259    state: Arc<HandleState>,
260}
261
262impl SimpleTaskManager {
263    pub fn new() -> Self {
264        Self {
265            state: Arc::new(HandleState::default()),
266        }
267    }
268}
269
270impl Default for SimpleTaskManager {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276#[async_trait]
277impl TaskManager for SimpleTaskManager {
278    async fn start_task(
279        &self,
280        request: TaskLaunchRequest,
281        ctx: TaskStartContext,
282    ) -> Result<TaskStartOutcome, TaskManagerError> {
283        let task_id = request
284            .task_id
285            .clone()
286            .unwrap_or_else(|| self.state.next_task_id());
287        let outcome = match &request.kind {
288            TaskLaunchKind::Approved(approved) => {
289                ctx.executor
290                    .execute_approved_owned(request.request.clone(), approved, ctx.tool_context)
291                    .await
292            }
293            TaskLaunchKind::Plain => {
294                ctx.executor
295                    .execute_owned(request.request.clone(), ctx.tool_context)
296                    .await
297            }
298        };
299        Ok(TaskStartOutcome::Ready(Box::new(
300            map_outcome_to_resolution(Some(task_id), request.request, outcome),
301        )))
302    }
303
304    async fn wait_for_turn(
305        &self,
306        _turn_id: &TurnId,
307        _cancellation: Option<TurnCancellation>,
308    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
309        Ok(None)
310    }
311
312    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
313        Ok(PendingLoopUpdates::default())
314    }
315
316    async fn on_turn_interrupted(&self, _turn_id: &TurnId) -> Result<(), TaskManagerError> {
317        Ok(())
318    }
319
320    fn handle(&self) -> TaskManagerHandle {
321        TaskManagerHandle {
322            inner: self.state.clone(),
323        }
324    }
325}
326
327#[derive(Default)]
328struct HandleState {
329    next_task_index: AtomicU64,
330    events_rx: Mutex<Option<mpsc::UnboundedReceiver<TaskEvent>>>,
331}
332
333impl HandleState {
334    fn next_task_id(&self) -> TaskId {
335        let next = self.next_task_index.fetch_add(1, Ordering::SeqCst) + 1;
336        TaskId::new(format!("task-{}", next))
337    }
338}
339
340#[async_trait]
341impl TaskManagerControl for HandleState {
342    async fn next_event(&self) -> Option<TaskEvent> {
343        let mut rx = self.events_rx.lock().await;
344        match rx.as_mut() {
345            Some(inner) => inner.recv().await,
346            None => None,
347        }
348    }
349
350    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
351        Err(TaskManagerError::NotFound(task_id))
352    }
353
354    async fn list_running(&self) -> Vec<TaskSnapshot> {
355        Vec::new()
356    }
357
358    async fn list_completed(&self) -> Vec<TaskSnapshot> {
359        Vec::new()
360    }
361
362    async fn drain_ready_items(&self) -> Vec<Item> {
363        Vec::new()
364    }
365
366    async fn set_continue_policy(
367        &self,
368        task_id: TaskId,
369        _policy: ContinuePolicy,
370    ) -> Result<(), TaskManagerError> {
371        Err(TaskManagerError::NotFound(task_id))
372    }
373
374    async fn set_delivery_mode(
375        &self,
376        task_id: TaskId,
377        _mode: DeliveryMode,
378    ) -> Result<(), TaskManagerError> {
379        Err(TaskManagerError::NotFound(task_id))
380    }
381
382    async fn wait_for_idle(&self) {}
383}
384
385pub struct AsyncTaskManager {
386    inner: Arc<AsyncInner>,
387    routing: Arc<dyn TaskRoutingPolicy>,
388}
389
390impl AsyncTaskManager {
391    pub fn new() -> Self {
392        let (event_tx, event_rx) = mpsc::unbounded_channel();
393        Self {
394            inner: Arc::new(AsyncInner {
395                state: Mutex::new(AsyncState::default()),
396                host_event_tx: event_tx,
397                host_event_rx: Mutex::new(event_rx),
398                notify: Notify::new(),
399            }),
400            routing: Arc::new(DefaultRoutingPolicy),
401        }
402    }
403
404    pub fn routing(mut self, policy: impl TaskRoutingPolicy + 'static) -> Self {
405        self.routing = Arc::new(policy);
406        self
407    }
408}
409
410impl Default for AsyncTaskManager {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
416#[derive(Default)]
417struct AsyncState {
418    next_task_index: u64,
419    tasks: BTreeMap<TaskId, TaskRecord>,
420    per_turn_running: BTreeMap<TurnId, usize>,
421    per_turn_updates: BTreeMap<TurnId, VecDeque<TurnTaskUpdate>>,
422    pending_loop_updates: VecDeque<TaskResolution>,
423    manual_ready_items: Vec<Item>,
424}
425
426struct TaskRecord {
427    snapshot: TaskSnapshot,
428    continue_policy: ContinuePolicy,
429    delivery_mode: DeliveryMode,
430    running: bool,
431    completed: bool,
432    join: Option<JoinHandle<()>>,
433}
434
435struct AsyncInner {
436    state: Mutex<AsyncState>,
437    host_event_tx: mpsc::UnboundedSender<TaskEvent>,
438    host_event_rx: Mutex<mpsc::UnboundedReceiver<TaskEvent>>,
439    notify: Notify,
440}
441
442impl AsyncInner {
443    async fn next_task_id(&self) -> TaskId {
444        let mut state = self.state.lock().await;
445        state.next_task_index += 1;
446        TaskId::new(format!("task-{}", state.next_task_index))
447    }
448}
449
450#[async_trait]
451impl TaskManager for AsyncTaskManager {
452    async fn start_task(
453        &self,
454        request: TaskLaunchRequest,
455        ctx: TaskStartContext,
456    ) -> Result<TaskStartOutcome, TaskManagerError> {
457        let route = self.routing.route(&request.request);
458        let task_id = match request.task_id.clone() {
459            Some(existing) => existing,
460            None => self.inner.next_task_id().await,
461        };
462        let initial_kind = match route {
463            RoutingDecision::Background => TaskKind::Background,
464            _ => TaskKind::Foreground,
465        };
466        let snapshot = TaskSnapshot {
467            id: task_id.clone(),
468            turn_id: request.request.turn_id.clone(),
469            call_id: request.request.call_id.clone(),
470            tool_name: request.request.tool_name.to_string(),
471            kind: initial_kind,
472            metadata: request.request.metadata.clone(),
473        };
474        let _ = self
475            .inner
476            .host_event_tx
477            .send(TaskEvent::Started(snapshot.clone()));
478
479        let mut state = self.inner.state.lock().await;
480        state.tasks.insert(
481            task_id.clone(),
482            TaskRecord {
483                snapshot: snapshot.clone(),
484                continue_policy: ContinuePolicy::NotifyOnly,
485                delivery_mode: DeliveryMode::ToLoop,
486                running: true,
487                completed: false,
488                join: None,
489            },
490        );
491        if initial_kind == TaskKind::Foreground {
492            *state
493                .per_turn_running
494                .entry(snapshot.turn_id.clone())
495                .or_default() += 1;
496        }
497        drop(state);
498
499        let event_tx = self.inner.host_event_tx.clone();
500        let inner = self.inner.clone();
501        let task_id_for_future = task_id.clone();
502        let turn_id = snapshot.turn_id.clone();
503        let kind = request.kind.clone();
504        let exec_request = request.request.clone();
505        let owned_ctx = ctx.tool_context.clone();
506        let executor = ctx.executor.clone();
507        let route_copy = route;
508        let join = tokio::spawn(async move {
509            if let RoutingDecision::ForegroundThenDetachAfter(duration) = route_copy {
510                let event_tx = event_tx.clone();
511                let inner = inner.clone();
512                let task_id = task_id_for_future.clone();
513                let turn_id = turn_id.clone();
514                tokio::spawn(async move {
515                    tokio::time::sleep(duration).await;
516                    let mut state = inner.state.lock().await;
517                    let snapshot = if let Some(record) = state.tasks.get_mut(&task_id)
518                        && record.running
519                        && record.snapshot.kind == TaskKind::Foreground
520                    {
521                        record.snapshot.kind = TaskKind::Background;
522                        Some(record.snapshot.clone())
523                    } else {
524                        None
525                    };
526                    if let Some(snapshot) = snapshot {
527                        if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
528                            *count = count.saturating_sub(1);
529                            if *count == 0 {
530                                state.per_turn_running.remove(&turn_id);
531                            }
532                        }
533                        state
534                            .per_turn_updates
535                            .entry(turn_id.clone())
536                            .or_default()
537                            .push_back(TurnTaskUpdate::Detached(snapshot.clone()));
538                        let _ = event_tx.send(TaskEvent::Detached(snapshot));
539                        inner.notify.notify_waiters();
540                    }
541                });
542            }
543
544            let outcome = match &kind {
545                TaskLaunchKind::Approved(approval) => {
546                    executor
547                        .execute_approved_owned(exec_request.clone(), approval, owned_ctx)
548                        .await
549                }
550                TaskLaunchKind::Plain => {
551                    executor
552                        .execute_owned(exec_request.clone(), owned_ctx)
553                        .await
554                }
555            };
556
557            let resolution =
558                map_outcome_to_resolution(Some(task_id_for_future.clone()), exec_request, outcome);
559            let completed_result = match &resolution {
560                TaskResolution::Item(item) => item.parts.iter().find_map(|part| match part {
561                    agentkit_core::Part::ToolResult(result) => Some(result.clone()),
562                    _ => None,
563                }),
564                TaskResolution::Approval(_) => None,
565            };
566
567            let (snapshot, should_request_continue) = {
568                let mut state = inner.state.lock().await;
569                let Some(record) = state.tasks.get_mut(&task_id_for_future) else {
570                    return;
571                };
572                record.running = false;
573                record.completed = true;
574                let snapshot = record.snapshot.clone();
575                let continue_policy = record.continue_policy;
576                let delivery_mode = record.delivery_mode;
577                let current_kind = snapshot.kind;
578
579                if current_kind == TaskKind::Foreground {
580                    if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
581                        *count = count.saturating_sub(1);
582                        if *count == 0 {
583                            state.per_turn_running.remove(&turn_id);
584                        }
585                    }
586                    state
587                        .per_turn_updates
588                        .entry(turn_id.clone())
589                        .or_default()
590                        .push_back(TurnTaskUpdate::Resolution(Box::new(resolution.clone())));
591                } else {
592                    match &resolution {
593                        TaskResolution::Item(_) if delivery_mode == DeliveryMode::ToLoop => {
594                            state.pending_loop_updates.push_back(resolution.clone());
595                        }
596                        TaskResolution::Approval(_) if delivery_mode == DeliveryMode::ToLoop => {
597                            state.pending_loop_updates.push_back(resolution.clone());
598                        }
599                        TaskResolution::Item(item) => {
600                            state.manual_ready_items.push(item.clone());
601                        }
602                        TaskResolution::Approval(_) => {}
603                    }
604                }
605
606                (
607                    snapshot,
608                    current_kind == TaskKind::Background
609                        && delivery_mode == DeliveryMode::ToLoop
610                        && continue_policy == ContinuePolicy::RequestContinue,
611                )
612            };
613
614            if let Some(result) = completed_result {
615                let _ = event_tx.send(TaskEvent::Completed(snapshot.clone(), result));
616            }
617            if should_request_continue {
618                let _ = event_tx.send(TaskEvent::ContinueRequested);
619            }
620            inner.notify.notify_waiters();
621        });
622
623        let mut state = self.inner.state.lock().await;
624        if let Some(record) = state.tasks.get_mut(&task_id) {
625            record.join = Some(join);
626        }
627        Ok(TaskStartOutcome::Pending {
628            task_id,
629            kind: initial_kind,
630        })
631    }
632
633    async fn wait_for_turn(
634        &self,
635        turn_id: &TurnId,
636        cancellation: Option<TurnCancellation>,
637    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
638        loop {
639            {
640                let mut state = self.inner.state.lock().await;
641                if let Some(queue) = state.per_turn_updates.get_mut(turn_id)
642                    && let Some(update) = queue.pop_front()
643                {
644                    return Ok(Some(update));
645                }
646                if state
647                    .per_turn_running
648                    .get(turn_id)
649                    .copied()
650                    .unwrap_or_default()
651                    == 0
652                {
653                    return Ok(None);
654                }
655            }
656            if cancellation
657                .as_ref()
658                .is_some_and(TurnCancellation::is_cancelled)
659            {
660                return Ok(None);
661            }
662            if let Some(cancellation) = cancellation.as_ref() {
663                tokio::select! {
664                    _ = self.inner.notify.notified() => {}
665                    _ = cancellation.cancelled() => return Ok(None),
666                }
667            } else {
668                self.inner.notify.notified().await;
669            }
670        }
671    }
672
673    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
674        let mut state = self.inner.state.lock().await;
675        Ok(PendingLoopUpdates {
676            resolutions: std::mem::take(&mut state.pending_loop_updates),
677        })
678    }
679
680    async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError> {
681        let mut state = self.inner.state.lock().await;
682        let interrupted: Vec<TaskId> = state
683            .tasks
684            .iter()
685            .filter_map(|(id, record)| {
686                (record.snapshot.turn_id == *turn_id
687                    && record.snapshot.kind == TaskKind::Foreground
688                    && record.running)
689                    .then_some(id.clone())
690            })
691            .collect();
692        for task_id in interrupted {
693            if let Some(record) = state.tasks.get_mut(&task_id) {
694                record.running = false;
695                if let Some(join) = record.join.take() {
696                    join.abort();
697                }
698                let snapshot = record.snapshot.clone();
699                let _ = self
700                    .inner
701                    .host_event_tx
702                    .send(TaskEvent::Cancelled(snapshot));
703            }
704        }
705        state.per_turn_running.remove(turn_id);
706        self.inner.notify.notify_waiters();
707        Ok(())
708    }
709
710    fn handle(&self) -> TaskManagerHandle {
711        TaskManagerHandle {
712            inner: self.inner.clone(),
713        }
714    }
715}
716
717#[async_trait]
718impl TaskManagerControl for AsyncInner {
719    async fn next_event(&self) -> Option<TaskEvent> {
720        self.host_event_rx.lock().await.recv().await
721    }
722
723    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
724        let mut state = self.state.lock().await;
725        let record = state
726            .tasks
727            .get_mut(&task_id)
728            .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
729        if let Some(join) = record.join.take() {
730            join.abort();
731        }
732        record.running = false;
733        let snapshot = record.snapshot.clone();
734        if record.snapshot.kind == TaskKind::Foreground
735            && let Some(count) = state.per_turn_running.get_mut(&snapshot.turn_id)
736        {
737            *count = count.saturating_sub(1);
738            if *count == 0 {
739                state.per_turn_running.remove(&snapshot.turn_id);
740            }
741        }
742        let _ = self.host_event_tx.send(TaskEvent::Cancelled(snapshot));
743        self.notify.notify_waiters();
744        Ok(())
745    }
746
747    async fn list_running(&self) -> Vec<TaskSnapshot> {
748        let state = self.state.lock().await;
749        state
750            .tasks
751            .values()
752            .filter(|record| record.running)
753            .map(|record| record.snapshot.clone())
754            .collect()
755    }
756
757    async fn list_completed(&self) -> Vec<TaskSnapshot> {
758        let state = self.state.lock().await;
759        state
760            .tasks
761            .values()
762            .filter(|record| record.completed)
763            .map(|record| record.snapshot.clone())
764            .collect()
765    }
766
767    async fn drain_ready_items(&self) -> Vec<Item> {
768        let mut state = self.state.lock().await;
769        std::mem::take(&mut state.manual_ready_items)
770    }
771
772    async fn set_continue_policy(
773        &self,
774        task_id: TaskId,
775        policy: ContinuePolicy,
776    ) -> Result<(), TaskManagerError> {
777        let mut state = self.state.lock().await;
778        let record = state
779            .tasks
780            .get_mut(&task_id)
781            .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
782        record.continue_policy = policy;
783        Ok(())
784    }
785
786    async fn set_delivery_mode(
787        &self,
788        task_id: TaskId,
789        mode: DeliveryMode,
790    ) -> Result<(), TaskManagerError> {
791        let mut state = self.state.lock().await;
792        let record = state
793            .tasks
794            .get_mut(&task_id)
795            .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
796        record.delivery_mode = mode;
797        Ok(())
798    }
799
800    async fn wait_for_idle(&self) {
801        loop {
802            {
803                let state = self.state.lock().await;
804                if !state.tasks.values().any(|r| r.running) {
805                    return;
806                }
807            }
808            self.notify.notified().await;
809        }
810    }
811}
812
813fn map_outcome_to_resolution(
814    task_id: Option<TaskId>,
815    request: ToolRequest,
816    outcome: ToolExecutionOutcome,
817) -> TaskResolution {
818    match outcome {
819        ToolExecutionOutcome::Completed(result) => TaskResolution::Item(Item {
820            id: None,
821            kind: agentkit_core::ItemKind::Tool,
822            parts: vec![agentkit_core::Part::ToolResult(result.result)],
823            metadata: result.metadata,
824            usage: None,
825            finish_reason: None,
826            created_at: None,
827        }),
828        ToolExecutionOutcome::Interrupted(
829            agentkit_tools_core::ToolInterruption::ApprovalRequired(mut approval),
830        ) => {
831            let task_id = task_id.unwrap_or_default();
832            approval.task_id = Some(task_id.clone());
833            TaskResolution::Approval(TaskApproval {
834                task_id,
835                tool_request: request,
836                approval,
837            })
838        }
839        ToolExecutionOutcome::Failed(error) => TaskResolution::Item(Item {
840            id: None,
841            kind: agentkit_core::ItemKind::Tool,
842            parts: vec![agentkit_core::Part::ToolResult(ToolResultPart {
843                call_id: request.call_id,
844                output: agentkit_core::ToolOutput::Text(error.to_string()),
845                is_error: true,
846                metadata: request.metadata,
847            })],
848            metadata: MetadataMap::new(),
849            usage: None,
850            finish_reason: None,
851            created_at: None,
852        }),
853    }
854}
855
856#[cfg(test)]
857mod tests {
858    use std::collections::BTreeMap;
859    use std::sync::Arc as StdArc;
860    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
861
862    use agentkit_core::{
863        CancellationController, ItemKind, Part, SessionId, ToolOutput, TurnCancellation,
864    };
865    use agentkit_tools_core::{
866        ApprovalReason, PermissionChecker, PermissionDecision, ToolAnnotations, ToolInterruption,
867        ToolName, ToolResult, ToolSpec,
868    };
869    use serde_json::json;
870    use tokio::sync::Notify;
871    use tokio::time::{Duration, timeout};
872
873    use super::*;
874
875    struct AllowAllPermissions;
876
877    impl PermissionChecker for AllowAllPermissions {
878        fn evaluate(
879            &self,
880            _request: &dyn agentkit_tools_core::PermissionRequest,
881        ) -> PermissionDecision {
882            PermissionDecision::Allow
883        }
884    }
885
886    #[derive(Clone)]
887    enum TestBehavior {
888        Block {
889            entered: StdArc<AtomicBool>,
890            release: StdArc<Notify>,
891            output: &'static str,
892        },
893        Approval,
894    }
895
896    #[derive(Clone)]
897    struct TestExecutor {
898        behaviors: BTreeMap<String, TestBehavior>,
899    }
900
901    impl TestExecutor {
902        fn new(behaviors: impl IntoIterator<Item = (impl Into<String>, TestBehavior)>) -> Self {
903            Self {
904                behaviors: behaviors
905                    .into_iter()
906                    .map(|(name, behavior)| (name.into(), behavior))
907                    .collect(),
908            }
909        }
910    }
911
912    #[async_trait]
913    impl ToolExecutor for TestExecutor {
914        fn specs(&self) -> Vec<ToolSpec> {
915            self.behaviors
916                .keys()
917                .map(|name| ToolSpec {
918                    name: ToolName::new(name),
919                    description: format!("test tool {name}"),
920                    input_schema: json!({
921                        "type": "object",
922                        "properties": {},
923                        "additionalProperties": false
924                    }),
925                    output_schema: None,
926                    annotations: ToolAnnotations::default(),
927                    metadata: MetadataMap::new(),
928                })
929                .collect()
930        }
931
932        async fn execute(
933            &self,
934            request: ToolRequest,
935            _ctx: &mut agentkit_tools_core::ToolContext<'_>,
936        ) -> ToolExecutionOutcome {
937            match self.behaviors.get(request.tool_name.0.as_str()) {
938                Some(TestBehavior::Block {
939                    entered,
940                    release,
941                    output,
942                }) => {
943                    entered.store(true, AtomicOrdering::SeqCst);
944                    release.notified().await;
945                    ToolExecutionOutcome::Completed(ToolResult {
946                        result: ToolResultPart {
947                            call_id: request.call_id,
948                            output: ToolOutput::Text((*output).into()),
949                            is_error: false,
950                            metadata: request.metadata,
951                        },
952                        duration: None,
953                        metadata: MetadataMap::new(),
954                    })
955                }
956                Some(TestBehavior::Approval) => ToolExecutionOutcome::Interrupted(
957                    ToolInterruption::ApprovalRequired(ApprovalRequest {
958                        task_id: None,
959                        call_id: Some(request.call_id.clone()),
960                        id: "approval:test".into(),
961                        request_kind: "tool.test".into(),
962                        reason: ApprovalReason::SensitivePath,
963                        summary: "requires approval".into(),
964                        metadata: MetadataMap::new(),
965                    }),
966                ),
967                None => ToolExecutionOutcome::Failed(ToolError::Unavailable(
968                    request.tool_name.0.clone(),
969                )),
970            }
971        }
972    }
973
    /// Test routing policy that chooses a [`RoutingDecision`] from the tool name alone.
    struct NameRoutingPolicy {
        /// Decision per tool name. Unlisted tools fall back to
        /// `RoutingDecision::Foreground` in `route`.
        routes: BTreeMap<String, RoutingDecision>,
    }
977
978    impl NameRoutingPolicy {
979        fn new(routes: impl IntoIterator<Item = (impl Into<String>, RoutingDecision)>) -> Self {
980            Self {
981                routes: routes
982                    .into_iter()
983                    .map(|(name, decision)| (name.into(), decision))
984                    .collect(),
985            }
986        }
987    }
988
989    impl TaskRoutingPolicy for NameRoutingPolicy {
990        fn route(&self, request: &ToolRequest) -> RoutingDecision {
991            self.routes
992                .get(request.tool_name.0.as_str())
993                .copied()
994                .unwrap_or(RoutingDecision::Foreground)
995        }
996    }
997
998    fn make_request(tool_name: &str, turn_id: &str, call_id: &str) -> ToolRequest {
999        ToolRequest {
1000            call_id: ToolCallId::new(call_id),
1001            tool_name: ToolName::new(tool_name),
1002            input: json!({}),
1003            session_id: SessionId::new("session-1"),
1004            turn_id: TurnId::new(turn_id),
1005            metadata: MetadataMap::new(),
1006        }
1007    }
1008
1009    fn make_context(
1010        executor: Arc<dyn ToolExecutor>,
1011        turn_id: &TurnId,
1012        cancellation: Option<TurnCancellation>,
1013    ) -> TaskStartContext {
1014        TaskStartContext {
1015            executor,
1016            tool_context: OwnedToolContext {
1017                session_id: SessionId::new("session-1"),
1018                turn_id: turn_id.clone(),
1019                metadata: MetadataMap::new(),
1020                permissions: Arc::new(AllowAllPermissions),
1021                resources: Arc::new(()),
1022                cancellation,
1023                execution_scope: None,
1024                approved_request: None,
1025            },
1026        }
1027    }
1028
1029    async fn next_event(handle: &TaskManagerHandle) -> TaskEvent {
1030        timeout(Duration::from_secs(1), handle.next_event())
1031            .await
1032            .expect("timed out waiting for task event")
1033            .expect("task event stream ended unexpectedly")
1034    }
1035
1036    async fn wait_until_entered(entered: &AtomicBool) {
1037        timeout(Duration::from_secs(1), async {
1038            while !entered.load(AtomicOrdering::SeqCst) {
1039                tokio::task::yield_now().await;
1040            }
1041        })
1042        .await
1043        .expect("task never entered execution");
1044    }
1045
    /// `SimpleTaskManager` resolves a tool call straight away. An approval
    /// interruption comes back as a `Ready` approval resolution that carries a
    /// freshly assigned task id, and no task is left running afterwards.
    #[tokio::test]
    async fn simple_task_manager_executes_inline_and_assigns_task_ids() {
        let manager = SimpleTaskManager::new();
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "needs-approval",
            TestBehavior::Approval,
        )]));
        let request = make_request("needs-approval", "turn-1", "call-1");

        // No caller-supplied task id, so the manager has to generate one.
        let outcome = manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(executor, &request.turn_id, None),
            )
            .await
            .unwrap();

        match outcome {
            TaskStartOutcome::Ready(resolution) => match *resolution {
                TaskResolution::Approval(task) => {
                    assert!(!task.task_id.0.is_empty());
                    // The executor emits the approval with `task_id: None`. The
                    // manager is expected to fill in the id it assigned.
                    assert_eq!(task.approval.task_id.as_ref(), Some(&task.task_id));
                    assert_eq!(task.tool_request.call_id, request.call_id);
                }
                other => panic!("unexpected task resolution: {other:?}"),
            },
            other => panic!("unexpected start outcome: {other:?}"),
        }

        // The call was fully resolved, so nothing should still be tracked as running.
        assert!(manager.handle().list_running().await.is_empty());
    }
1081
    /// Interrupting a turn cancels that turn's foreground task but leaves its
    /// background task running. The background task can still complete normally.
    #[tokio::test]
    async fn async_manager_interrupt_cancels_foreground_only() {
        // Both tools block until their `Notify` fires. `fg_release` is never
        // notified because the foreground task should be cancelled instead.
        let fg_release = StdArc::new(Notify::new());
        let fg_entered = StdArc::new(AtomicBool::new(false));
        let bg_release = StdArc::new(Notify::new());
        let bg_entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([
            (
                "foreground",
                TestBehavior::Block {
                    entered: fg_entered.clone(),
                    release: fg_release.clone(),
                    output: "foreground-done",
                },
            ),
            (
                "background",
                TestBehavior::Block {
                    entered: bg_entered.clone(),
                    release: bg_release.clone(),
                    output: "background-done",
                },
            ),
        ]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([
            ("foreground", RoutingDecision::Foreground),
            ("background", RoutingDecision::Background),
        ]));
        let handle = manager.handle();
        let turn_id = TurnId::new("turn-1");

        // Launch one task per routing decision, both in the same turn.
        let foreground = manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: make_request("foreground", "turn-1", "call-fg"),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(executor.clone(), &turn_id, None),
            )
            .await
            .unwrap();
        let background = manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: make_request("background", "turn-1", "call-bg"),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(executor.clone(), &turn_id, None),
            )
            .await
            .unwrap();

        assert!(matches!(
            foreground,
            TaskStartOutcome::Pending {
                kind: TaskKind::Foreground,
                ..
            }
        ));
        let background_id = match background {
            TaskStartOutcome::Pending {
                task_id,
                kind: TaskKind::Background,
            } => task_id,
            other => panic!("unexpected background outcome: {other:?}"),
        };

        // Discard the two events emitted after launch (presumably one start
        // notification per task; their content and order are not checked here).
        // Then make sure both tools are actually blocked inside `execute`.
        let _ = next_event(&handle).await;
        let _ = next_event(&handle).await;
        wait_until_entered(fg_entered.as_ref()).await;
        wait_until_entered(bg_entered.as_ref()).await;

        manager.on_turn_interrupted(&turn_id).await.unwrap();

        // Only the foreground task should report cancellation.
        match next_event(&handle).await {
            TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.tool_name, "foreground"),
            other => panic!("unexpected event after interrupt: {other:?}"),
        }

        // The background task survived the interrupt and is the only task still running.
        let running = handle.list_running().await;
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].id, background_id);
        assert_eq!(running[0].tool_name, "background");

        // Releasing the background tool lets it complete with its scripted output.
        bg_release.notify_waiters();
        match next_event(&handle).await {
            TaskEvent::Completed(snapshot, result) => {
                assert_eq!(snapshot.id, background_id);
                assert_eq!(result.output, ToolOutput::Text("background-done".into()));
            }
            other => panic!("unexpected completion event: {other:?}"),
        }
    }
1177
    /// A running background task can be cancelled through the handle using its
    /// task id. Cancellation emits `Cancelled` and removes the task from the
    /// running list.
    #[tokio::test]
    async fn async_manager_can_cancel_background_tasks_by_id() {
        // The release `Notify` is moved into the behavior and never fired: the
        // tool stays blocked until it is cancelled.
        let release = StdArc::new(Notify::new());
        let entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "background",
            TestBehavior::Block {
                entered: entered.clone(),
                release,
                output: "done",
            },
        )]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
            "background",
            RoutingDecision::Background,
        )]));
        let handle = manager.handle();
        let request = make_request("background", "turn-1", "call-1");

        let task_id = match manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(executor, &request.turn_id, None),
            )
            .await
            .unwrap()
        {
            TaskStartOutcome::Pending { task_id, .. } => task_id,
            other => panic!("unexpected start outcome: {other:?}"),
        };

        // Skip the first post-launch event and wait until the tool is blocked in
        // `execute` before cancelling it.
        let _ = next_event(&handle).await;
        wait_until_entered(entered.as_ref()).await;
        handle.cancel(task_id.clone()).await.unwrap();

        match next_event(&handle).await {
            TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.id, task_id),
            other => panic!("unexpected event after cancel: {other:?}"),
        }

        assert!(handle.list_running().await.is_empty());
    }
1224
    /// With `DeliveryMode::Manual`, a completed result is announced through a
    /// `Completed` event but is not queued for the loop. It is only available
    /// through `drain_ready_items`, even though the task also requested a
    /// continue.
    #[tokio::test]
    async fn async_manager_manual_delivery_keeps_results_out_of_loop_updates() {
        let release = StdArc::new(Notify::new());
        let entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "background",
            TestBehavior::Block {
                entered: entered.clone(),
                release: release.clone(),
                output: "manual-done",
            },
        )]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
            "background",
            RoutingDecision::Background,
        )]));
        let handle = manager.handle();
        let request = make_request("background", "turn-1", "call-1");

        let task_id = match manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(executor, &request.turn_id, None),
            )
            .await
            .unwrap()
        {
            TaskStartOutcome::Pending { task_id, .. } => task_id,
            other => panic!("unexpected start outcome: {other:?}"),
        };

        // Reconfigure delivery while the tool is still blocked, so the new
        // settings are in place before it completes.
        let _ = next_event(&handle).await;
        wait_until_entered(entered.as_ref()).await;
        handle
            .set_continue_policy(task_id.clone(), ContinuePolicy::RequestContinue)
            .await
            .unwrap();
        handle
            .set_delivery_mode(task_id, DeliveryMode::Manual)
            .await
            .unwrap();

        release.notify_waiters();
        match next_event(&handle).await {
            TaskEvent::Completed(_, result) => {
                assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
            }
            other => panic!("unexpected event: {other:?}"),
        }

        // No follow-up event should arrive within 50ms. The loop-delivery test
        // expects `ContinueRequested` at this point; manual delivery should not
        // emit it.
        assert!(
            timeout(Duration::from_millis(50), handle.next_event())
                .await
                .is_err()
        );
        // Nothing was queued for the agent loop...
        assert!(
            manager
                .take_pending_loop_updates()
                .await
                .unwrap()
                .resolutions
                .is_empty()
        );

        // ...instead the result is held as a ready tool item for manual draining.
        let ready_items = handle.drain_ready_items().await;
        assert_eq!(ready_items.len(), 1);
        assert_eq!(ready_items[0].kind, ItemKind::Tool);
        match &ready_items[0].parts[0] {
            Part::ToolResult(result) => {
                assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
            }
            other => panic!("unexpected ready item: {other:?}"),
        }
    }
1303
    /// When the delivery mode is left at its default (which this test expects to
    /// deliver to the loop) and the continue policy is `RequestContinue`,
    /// completion is followed by `ContinueRequested`. The result is queued as a
    /// pending loop update rather than as a manually drained item.
    #[tokio::test]
    async fn async_manager_to_loop_delivery_can_request_continue() {
        let release = StdArc::new(Notify::new());
        let entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "background",
            TestBehavior::Block {
                entered: entered.clone(),
                release: release.clone(),
                output: "loop-done",
            },
        )]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
            "background",
            RoutingDecision::Background,
        )]));
        let handle = manager.handle();
        let request = make_request("background", "turn-1", "call-1");

        // Unlike the other tests, supply a turn cancellation handle. It comes from
        // a throwaway `CancellationController` that the test does not keep.
        let task_id = match manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(
                    executor,
                    &request.turn_id,
                    Some(TurnCancellation::new(
                        CancellationController::new().handle(),
                    )),
                ),
            )
            .await
            .unwrap()
        {
            TaskStartOutcome::Pending { task_id, .. } => task_id,
            other => panic!("unexpected start outcome: {other:?}"),
        };

        let _ = next_event(&handle).await;
        wait_until_entered(entered.as_ref()).await;
        handle
            .set_continue_policy(task_id, ContinuePolicy::RequestContinue)
            .await
            .unwrap();

        release.notify_waiters();
        match next_event(&handle).await {
            TaskEvent::Completed(_, result) => {
                assert_eq!(result.output, ToolOutput::Text("loop-done".into()))
            }
            other => panic!("unexpected completion event: {other:?}"),
        }
        // `RequestContinue` should produce a continue request right after completion.
        match next_event(&handle).await {
            TaskEvent::ContinueRequested => {}
            other => panic!("unexpected follow-up event: {other:?}"),
        }

        // The result goes to the loop, not to the manual ready-item queue.
        let updates = manager.take_pending_loop_updates().await.unwrap();
        assert_eq!(updates.resolutions.len(), 1);
        assert!(handle.drain_ready_items().await.is_empty());
    }
1368
    /// `wait_for_idle` resolves only after a completed background task's result
    /// has been queued as a loop update, so the update can be taken immediately
    /// afterwards.
    #[tokio::test]
    async fn wait_for_idle_returns_after_loop_updates_are_queued() {
        let release = StdArc::new(Notify::new());
        let entered = StdArc::new(AtomicBool::new(false));
        let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
            "background",
            TestBehavior::Block {
                entered: entered.clone(),
                release: release.clone(),
                output: "idle-done",
            },
        )]));
        let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
            "background",
            RoutingDecision::Background,
        )]));
        let handle = manager.handle();
        let request = make_request("background", "turn-1", "call-1");

        let outcome = manager
            .start_task(
                TaskLaunchRequest {
                    task_id: None,
                    request: request.clone(),
                    kind: TaskLaunchKind::Plain,
                },
                make_context(executor, &request.turn_id, None),
            )
            .await
            .unwrap();
        assert!(matches!(outcome, TaskStartOutcome::Pending { .. }));

        // Let the tool finish, then rely on `wait_for_idle` (not on events) to
        // know when its result has been processed.
        let _ = next_event(&handle).await;
        wait_until_entered(entered.as_ref()).await;
        release.notify_waiters();

        timeout(Duration::from_secs(1), handle.wait_for_idle())
            .await
            .expect("wait_for_idle timed out");

        // Exactly one loop update: a tool item whose first part is this call's result.
        let updates = manager.take_pending_loop_updates().await.unwrap();
        assert_eq!(updates.resolutions.len(), 1);
        match &updates.resolutions[0] {
            TaskResolution::Item(item) => match &item.parts[0] {
                Part::ToolResult(result) => {
                    assert_eq!(result.call_id, request.call_id);
                    assert_eq!(result.output, ToolOutput::Text("idle-done".into()));
                }
                other => panic!("unexpected tool item: {other:?}"),
            },
            other => panic!("unexpected pending update: {other:?}"),
        }
    }
1422}