1use std::collections::{BTreeMap, VecDeque};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use agentkit_core::{
7 Item, MetadataMap, TaskId, ToolCallId, ToolResultPart, TurnCancellation, TurnId,
8};
9use agentkit_tools_core::{
10 ApprovalRequest, OwnedToolContext, ToolError, ToolExecutionOutcome, ToolExecutor, ToolRequest,
11};
12use async_trait::async_trait;
13use thiserror::Error;
14use tokio::sync::{Mutex, Notify, mpsc};
15use tokio::task::JoinHandle;
16
/// Whether a task is tied to the turn that started it.
///
/// In [`AsyncTaskManager`], foreground tasks are counted against their turn and
/// their results are delivered through [`TaskManager::wait_for_turn`];
/// background tasks deliver results according to their [`DeliveryMode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskKind {
    /// Counted against the originating turn.
    Foreground,
    /// Runs independently of the originating turn.
    Background,
}

/// What a background task asks of the agent loop when it completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ContinuePolicy {
    /// Only queue/report the result.
    NotifyOnly,
    /// Additionally emit [`TaskEvent::ContinueRequested`] (only when the task's
    /// delivery mode is [`DeliveryMode::ToLoop`]).
    RequestContinue,
}

/// Where the result of a background task is delivered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryMode {
    /// Queued for [`TaskManager::take_pending_loop_updates`].
    ToLoop,
    /// Collected for [`TaskManagerHandle::drain_ready_items`].
    Manual,
}
34
35#[derive(Clone, Debug, PartialEq, Eq)]
36pub struct TaskSnapshot {
37 pub id: TaskId,
38 pub turn_id: TurnId,
39 pub call_id: ToolCallId,
40 pub tool_name: String,
41 pub kind: TaskKind,
42 pub metadata: MetadataMap,
43}
44
45#[derive(Clone, Debug, PartialEq)]
46pub enum TaskEvent {
47 Started(TaskSnapshot),
48 Detached(TaskSnapshot),
49 Completed(TaskSnapshot, ToolResultPart),
50 Cancelled(TaskSnapshot),
51 Failed(TaskSnapshot, ToolError),
52 ContinueRequested,
53}
54
55#[derive(Clone, Debug, PartialEq)]
56pub struct TaskApproval {
57 pub task_id: TaskId,
58 pub tool_request: ToolRequest,
59 pub approval: ApprovalRequest,
60}
61
62#[derive(Clone, Debug, PartialEq)]
63pub enum TaskResolution {
64 Item(Item),
65 Approval(TaskApproval),
66}
67
68#[derive(Clone, Debug, PartialEq)]
69pub enum TaskStartOutcome {
70 Ready(Box<TaskResolution>),
71 Pending { task_id: TaskId, kind: TaskKind },
72}
73
74#[derive(Clone, Debug, PartialEq)]
75pub enum TurnTaskUpdate {
76 Resolution(Box<TaskResolution>),
77 Detached(TaskSnapshot),
78}
79
80#[derive(Clone, Debug, Default, PartialEq)]
81pub struct PendingLoopUpdates {
82 pub resolutions: VecDeque<TaskResolution>,
83}
84
85#[derive(Clone, Debug, Default)]
88pub enum TaskLaunchKind {
89 #[default]
91 Plain,
92 Approved(ApprovalRequest),
94}
95
96#[derive(Clone, Debug)]
97pub struct TaskLaunchRequest {
98 pub task_id: Option<TaskId>,
99 pub request: ToolRequest,
100 pub kind: TaskLaunchKind,
101}
102
103impl TaskLaunchRequest {
104 pub fn plain(task_id: Option<TaskId>, request: ToolRequest) -> Self {
106 Self {
107 task_id,
108 request,
109 kind: TaskLaunchKind::Plain,
110 }
111 }
112
113 pub fn approved(
115 task_id: Option<TaskId>,
116 request: ToolRequest,
117 approval: ApprovalRequest,
118 ) -> Self {
119 Self {
120 task_id,
121 request,
122 kind: TaskLaunchKind::Approved(approval),
123 }
124 }
125}
126
127#[derive(Clone)]
128pub struct TaskStartContext {
129 pub executor: Arc<dyn ToolExecutor>,
130 pub tool_context: OwnedToolContext,
131}
132
133#[derive(Debug, Error, Clone, PartialEq, Eq)]
134pub enum TaskManagerError {
135 #[error("task not found: {0}")]
136 NotFound(TaskId),
137 #[error("task manager internal error: {0}")]
138 Internal(String),
139}
140
141pub trait TaskRoutingPolicy: Send + Sync {
142 fn route(&self, request: &ToolRequest) -> RoutingDecision;
143}
144
145impl<F> TaskRoutingPolicy for F
146where
147 F: Fn(&ToolRequest) -> RoutingDecision + Send + Sync,
148{
149 fn route(&self, request: &ToolRequest) -> RoutingDecision {
150 self(request)
151 }
152}
153
/// Scheduling decision produced by a [`TaskRoutingPolicy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingDecision {
    /// Run as a foreground task of the current turn.
    Foreground,
    /// Run as a background task.
    Background,
    /// Start in the foreground and move to the background if still running
    /// after the given duration.
    ForegroundThenDetachAfter(Duration),
}
160
161struct DefaultRoutingPolicy;
162
163impl TaskRoutingPolicy for DefaultRoutingPolicy {
164 fn route(&self, _request: &ToolRequest) -> RoutingDecision {
165 RoutingDecision::Foreground
166 }
167}
168
169#[async_trait]
170pub trait TaskManager: Send + Sync {
171 async fn start_task(
172 &self,
173 request: TaskLaunchRequest,
174 ctx: TaskStartContext,
175 ) -> Result<TaskStartOutcome, TaskManagerError>;
176
177 async fn wait_for_turn(
178 &self,
179 turn_id: &TurnId,
180 cancellation: Option<TurnCancellation>,
181 ) -> Result<Option<TurnTaskUpdate>, TaskManagerError>;
182
183 async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError>;
184
185 async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError>;
186
187 fn handle(&self) -> TaskManagerHandle;
188}
189
190#[async_trait]
191trait TaskManagerControl: Send + Sync {
192 async fn next_event(&self) -> Option<TaskEvent>;
193 async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError>;
194 async fn list_running(&self) -> Vec<TaskSnapshot>;
195 async fn list_completed(&self) -> Vec<TaskSnapshot>;
196 async fn drain_ready_items(&self) -> Vec<Item>;
197 async fn set_continue_policy(
198 &self,
199 task_id: TaskId,
200 policy: ContinuePolicy,
201 ) -> Result<(), TaskManagerError>;
202 async fn set_delivery_mode(
203 &self,
204 task_id: TaskId,
205 mode: DeliveryMode,
206 ) -> Result<(), TaskManagerError>;
207 async fn wait_for_idle(&self);
208}
209
210#[derive(Clone)]
211pub struct TaskManagerHandle {
212 inner: Arc<dyn TaskManagerControl>,
213}
214
215impl TaskManagerHandle {
216 pub async fn next_event(&self) -> Option<TaskEvent> {
217 self.inner.next_event().await
218 }
219
220 pub async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
221 self.inner.cancel(task_id).await
222 }
223
224 pub async fn list_running(&self) -> Vec<TaskSnapshot> {
225 self.inner.list_running().await
226 }
227
228 pub async fn list_completed(&self) -> Vec<TaskSnapshot> {
229 self.inner.list_completed().await
230 }
231
232 pub async fn drain_ready_items(&self) -> Vec<Item> {
233 self.inner.drain_ready_items().await
234 }
235
236 pub async fn set_continue_policy(
237 &self,
238 task_id: TaskId,
239 policy: ContinuePolicy,
240 ) -> Result<(), TaskManagerError> {
241 self.inner.set_continue_policy(task_id, policy).await
242 }
243
244 pub async fn set_delivery_mode(
245 &self,
246 task_id: TaskId,
247 mode: DeliveryMode,
248 ) -> Result<(), TaskManagerError> {
249 self.inner.set_delivery_mode(task_id, mode).await
250 }
251
252 pub async fn wait_for_idle(&self) {
254 self.inner.wait_for_idle().await
255 }
256}
257
258pub struct SimpleTaskManager {
259 state: Arc<HandleState>,
260}
261
262impl SimpleTaskManager {
263 pub fn new() -> Self {
264 Self {
265 state: Arc::new(HandleState::default()),
266 }
267 }
268}
269
270impl Default for SimpleTaskManager {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276#[async_trait]
277impl TaskManager for SimpleTaskManager {
278 async fn start_task(
279 &self,
280 request: TaskLaunchRequest,
281 ctx: TaskStartContext,
282 ) -> Result<TaskStartOutcome, TaskManagerError> {
283 let task_id = request
284 .task_id
285 .clone()
286 .unwrap_or_else(|| self.state.next_task_id());
287 let outcome = match &request.kind {
288 TaskLaunchKind::Approved(approved) => {
289 ctx.executor
290 .execute_approved_owned(request.request.clone(), approved, ctx.tool_context)
291 .await
292 }
293 TaskLaunchKind::Plain => {
294 ctx.executor
295 .execute_owned(request.request.clone(), ctx.tool_context)
296 .await
297 }
298 };
299 Ok(TaskStartOutcome::Ready(Box::new(
300 map_outcome_to_resolution(Some(task_id), request.request, outcome),
301 )))
302 }
303
304 async fn wait_for_turn(
305 &self,
306 _turn_id: &TurnId,
307 _cancellation: Option<TurnCancellation>,
308 ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
309 Ok(None)
310 }
311
312 async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
313 Ok(PendingLoopUpdates::default())
314 }
315
316 async fn on_turn_interrupted(&self, _turn_id: &TurnId) -> Result<(), TaskManagerError> {
317 Ok(())
318 }
319
320 fn handle(&self) -> TaskManagerHandle {
321 TaskManagerHandle {
322 inner: self.state.clone(),
323 }
324 }
325}
326
327#[derive(Default)]
328struct HandleState {
329 next_task_index: AtomicU64,
330 events_rx: Mutex<Option<mpsc::UnboundedReceiver<TaskEvent>>>,
331}
332
333impl HandleState {
334 fn next_task_id(&self) -> TaskId {
335 let next = self.next_task_index.fetch_add(1, Ordering::SeqCst) + 1;
336 TaskId::new(format!("task-{}", next))
337 }
338}
339
340#[async_trait]
341impl TaskManagerControl for HandleState {
342 async fn next_event(&self) -> Option<TaskEvent> {
343 let mut rx = self.events_rx.lock().await;
344 match rx.as_mut() {
345 Some(inner) => inner.recv().await,
346 None => None,
347 }
348 }
349
350 async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
351 Err(TaskManagerError::NotFound(task_id))
352 }
353
354 async fn list_running(&self) -> Vec<TaskSnapshot> {
355 Vec::new()
356 }
357
358 async fn list_completed(&self) -> Vec<TaskSnapshot> {
359 Vec::new()
360 }
361
362 async fn drain_ready_items(&self) -> Vec<Item> {
363 Vec::new()
364 }
365
366 async fn set_continue_policy(
367 &self,
368 task_id: TaskId,
369 _policy: ContinuePolicy,
370 ) -> Result<(), TaskManagerError> {
371 Err(TaskManagerError::NotFound(task_id))
372 }
373
374 async fn set_delivery_mode(
375 &self,
376 task_id: TaskId,
377 _mode: DeliveryMode,
378 ) -> Result<(), TaskManagerError> {
379 Err(TaskManagerError::NotFound(task_id))
380 }
381
382 async fn wait_for_idle(&self) {}
383}
384
385pub struct AsyncTaskManager {
386 inner: Arc<AsyncInner>,
387 routing: Arc<dyn TaskRoutingPolicy>,
388}
389
390impl AsyncTaskManager {
391 pub fn new() -> Self {
392 let (event_tx, event_rx) = mpsc::unbounded_channel();
393 Self {
394 inner: Arc::new(AsyncInner {
395 state: Mutex::new(AsyncState::default()),
396 host_event_tx: event_tx,
397 host_event_rx: Mutex::new(event_rx),
398 notify: Notify::new(),
399 }),
400 routing: Arc::new(DefaultRoutingPolicy),
401 }
402 }
403
404 pub fn routing(mut self, policy: impl TaskRoutingPolicy + 'static) -> Self {
405 self.routing = Arc::new(policy);
406 self
407 }
408}
409
410impl Default for AsyncTaskManager {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
416#[derive(Default)]
417struct AsyncState {
418 next_task_index: u64,
419 tasks: BTreeMap<TaskId, TaskRecord>,
420 per_turn_running: BTreeMap<TurnId, usize>,
421 per_turn_updates: BTreeMap<TurnId, VecDeque<TurnTaskUpdate>>,
422 pending_loop_updates: VecDeque<TaskResolution>,
423 manual_ready_items: Vec<Item>,
424}
425
426struct TaskRecord {
427 snapshot: TaskSnapshot,
428 continue_policy: ContinuePolicy,
429 delivery_mode: DeliveryMode,
430 running: bool,
431 completed: bool,
432 join: Option<JoinHandle<()>>,
433}
434
435struct AsyncInner {
436 state: Mutex<AsyncState>,
437 host_event_tx: mpsc::UnboundedSender<TaskEvent>,
438 host_event_rx: Mutex<mpsc::UnboundedReceiver<TaskEvent>>,
439 notify: Notify,
440}
441
442impl AsyncInner {
443 async fn next_task_id(&self) -> TaskId {
444 let mut state = self.state.lock().await;
445 state.next_task_index += 1;
446 TaskId::new(format!("task-{}", state.next_task_index))
447 }
448}
449
450#[async_trait]
451impl TaskManager for AsyncTaskManager {
452 async fn start_task(
453 &self,
454 request: TaskLaunchRequest,
455 ctx: TaskStartContext,
456 ) -> Result<TaskStartOutcome, TaskManagerError> {
457 let route = self.routing.route(&request.request);
458 let task_id = match request.task_id.clone() {
459 Some(existing) => existing,
460 None => self.inner.next_task_id().await,
461 };
462 let initial_kind = match route {
463 RoutingDecision::Background => TaskKind::Background,
464 _ => TaskKind::Foreground,
465 };
466 let snapshot = TaskSnapshot {
467 id: task_id.clone(),
468 turn_id: request.request.turn_id.clone(),
469 call_id: request.request.call_id.clone(),
470 tool_name: request.request.tool_name.to_string(),
471 kind: initial_kind,
472 metadata: request.request.metadata.clone(),
473 };
474 let _ = self
475 .inner
476 .host_event_tx
477 .send(TaskEvent::Started(snapshot.clone()));
478
479 let mut state = self.inner.state.lock().await;
480 state.tasks.insert(
481 task_id.clone(),
482 TaskRecord {
483 snapshot: snapshot.clone(),
484 continue_policy: ContinuePolicy::NotifyOnly,
485 delivery_mode: DeliveryMode::ToLoop,
486 running: true,
487 completed: false,
488 join: None,
489 },
490 );
491 if initial_kind == TaskKind::Foreground {
492 *state
493 .per_turn_running
494 .entry(snapshot.turn_id.clone())
495 .or_default() += 1;
496 }
497 drop(state);
498
499 let event_tx = self.inner.host_event_tx.clone();
500 let inner = self.inner.clone();
501 let task_id_for_future = task_id.clone();
502 let turn_id = snapshot.turn_id.clone();
503 let kind = request.kind.clone();
504 let exec_request = request.request.clone();
505 let owned_ctx = ctx.tool_context.clone();
506 let executor = ctx.executor.clone();
507 let route_copy = route;
508 let join = tokio::spawn(async move {
509 if let RoutingDecision::ForegroundThenDetachAfter(duration) = route_copy {
510 let event_tx = event_tx.clone();
511 let inner = inner.clone();
512 let task_id = task_id_for_future.clone();
513 let turn_id = turn_id.clone();
514 tokio::spawn(async move {
515 tokio::time::sleep(duration).await;
516 let mut state = inner.state.lock().await;
517 let snapshot = if let Some(record) = state.tasks.get_mut(&task_id)
518 && record.running
519 && record.snapshot.kind == TaskKind::Foreground
520 {
521 record.snapshot.kind = TaskKind::Background;
522 Some(record.snapshot.clone())
523 } else {
524 None
525 };
526 if let Some(snapshot) = snapshot {
527 if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
528 *count = count.saturating_sub(1);
529 if *count == 0 {
530 state.per_turn_running.remove(&turn_id);
531 }
532 }
533 state
534 .per_turn_updates
535 .entry(turn_id.clone())
536 .or_default()
537 .push_back(TurnTaskUpdate::Detached(snapshot.clone()));
538 let _ = event_tx.send(TaskEvent::Detached(snapshot));
539 inner.notify.notify_waiters();
540 }
541 });
542 }
543
544 let outcome = match &kind {
545 TaskLaunchKind::Approved(approval) => {
546 executor
547 .execute_approved_owned(exec_request.clone(), approval, owned_ctx)
548 .await
549 }
550 TaskLaunchKind::Plain => {
551 executor
552 .execute_owned(exec_request.clone(), owned_ctx)
553 .await
554 }
555 };
556
557 let resolution =
558 map_outcome_to_resolution(Some(task_id_for_future.clone()), exec_request, outcome);
559 let completed_result = match &resolution {
560 TaskResolution::Item(item) => item.parts.iter().find_map(|part| match part {
561 agentkit_core::Part::ToolResult(result) => Some(result.clone()),
562 _ => None,
563 }),
564 TaskResolution::Approval(_) => None,
565 };
566
567 let (snapshot, should_request_continue) = {
568 let mut state = inner.state.lock().await;
569 let Some(record) = state.tasks.get_mut(&task_id_for_future) else {
570 return;
571 };
572 record.running = false;
573 record.completed = true;
574 let snapshot = record.snapshot.clone();
575 let continue_policy = record.continue_policy;
576 let delivery_mode = record.delivery_mode;
577 let current_kind = snapshot.kind;
578
579 if current_kind == TaskKind::Foreground {
580 if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
581 *count = count.saturating_sub(1);
582 if *count == 0 {
583 state.per_turn_running.remove(&turn_id);
584 }
585 }
586 state
587 .per_turn_updates
588 .entry(turn_id.clone())
589 .or_default()
590 .push_back(TurnTaskUpdate::Resolution(Box::new(resolution.clone())));
591 } else {
592 match &resolution {
593 TaskResolution::Item(_) if delivery_mode == DeliveryMode::ToLoop => {
594 state.pending_loop_updates.push_back(resolution.clone());
595 }
596 TaskResolution::Approval(_) if delivery_mode == DeliveryMode::ToLoop => {
597 state.pending_loop_updates.push_back(resolution.clone());
598 }
599 TaskResolution::Item(item) => {
600 state.manual_ready_items.push(item.clone());
601 }
602 TaskResolution::Approval(_) => {}
603 }
604 }
605
606 (
607 snapshot,
608 current_kind == TaskKind::Background
609 && delivery_mode == DeliveryMode::ToLoop
610 && continue_policy == ContinuePolicy::RequestContinue,
611 )
612 };
613
614 if let Some(result) = completed_result {
615 let _ = event_tx.send(TaskEvent::Completed(snapshot.clone(), result));
616 }
617 if should_request_continue {
618 let _ = event_tx.send(TaskEvent::ContinueRequested);
619 }
620 inner.notify.notify_waiters();
621 });
622
623 let mut state = self.inner.state.lock().await;
624 if let Some(record) = state.tasks.get_mut(&task_id) {
625 record.join = Some(join);
626 }
627 Ok(TaskStartOutcome::Pending {
628 task_id,
629 kind: initial_kind,
630 })
631 }
632
633 async fn wait_for_turn(
634 &self,
635 turn_id: &TurnId,
636 cancellation: Option<TurnCancellation>,
637 ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
638 loop {
639 {
640 let mut state = self.inner.state.lock().await;
641 if let Some(queue) = state.per_turn_updates.get_mut(turn_id)
642 && let Some(update) = queue.pop_front()
643 {
644 return Ok(Some(update));
645 }
646 if state
647 .per_turn_running
648 .get(turn_id)
649 .copied()
650 .unwrap_or_default()
651 == 0
652 {
653 return Ok(None);
654 }
655 }
656 if cancellation
657 .as_ref()
658 .is_some_and(TurnCancellation::is_cancelled)
659 {
660 return Ok(None);
661 }
662 if let Some(cancellation) = cancellation.as_ref() {
663 tokio::select! {
664 _ = self.inner.notify.notified() => {}
665 _ = cancellation.cancelled() => return Ok(None),
666 }
667 } else {
668 self.inner.notify.notified().await;
669 }
670 }
671 }
672
673 async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
674 let mut state = self.inner.state.lock().await;
675 Ok(PendingLoopUpdates {
676 resolutions: std::mem::take(&mut state.pending_loop_updates),
677 })
678 }
679
680 async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError> {
681 let mut state = self.inner.state.lock().await;
682 let interrupted: Vec<TaskId> = state
683 .tasks
684 .iter()
685 .filter_map(|(id, record)| {
686 (record.snapshot.turn_id == *turn_id
687 && record.snapshot.kind == TaskKind::Foreground
688 && record.running)
689 .then_some(id.clone())
690 })
691 .collect();
692 for task_id in interrupted {
693 if let Some(record) = state.tasks.get_mut(&task_id) {
694 record.running = false;
695 if let Some(join) = record.join.take() {
696 join.abort();
697 }
698 let snapshot = record.snapshot.clone();
699 let _ = self
700 .inner
701 .host_event_tx
702 .send(TaskEvent::Cancelled(snapshot));
703 }
704 }
705 state.per_turn_running.remove(turn_id);
706 self.inner.notify.notify_waiters();
707 Ok(())
708 }
709
710 fn handle(&self) -> TaskManagerHandle {
711 TaskManagerHandle {
712 inner: self.inner.clone(),
713 }
714 }
715}
716
717#[async_trait]
718impl TaskManagerControl for AsyncInner {
719 async fn next_event(&self) -> Option<TaskEvent> {
720 self.host_event_rx.lock().await.recv().await
721 }
722
723 async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
724 let mut state = self.state.lock().await;
725 let record = state
726 .tasks
727 .get_mut(&task_id)
728 .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
729 if let Some(join) = record.join.take() {
730 join.abort();
731 }
732 record.running = false;
733 let snapshot = record.snapshot.clone();
734 if record.snapshot.kind == TaskKind::Foreground
735 && let Some(count) = state.per_turn_running.get_mut(&snapshot.turn_id)
736 {
737 *count = count.saturating_sub(1);
738 if *count == 0 {
739 state.per_turn_running.remove(&snapshot.turn_id);
740 }
741 }
742 let _ = self.host_event_tx.send(TaskEvent::Cancelled(snapshot));
743 self.notify.notify_waiters();
744 Ok(())
745 }
746
747 async fn list_running(&self) -> Vec<TaskSnapshot> {
748 let state = self.state.lock().await;
749 state
750 .tasks
751 .values()
752 .filter(|record| record.running)
753 .map(|record| record.snapshot.clone())
754 .collect()
755 }
756
757 async fn list_completed(&self) -> Vec<TaskSnapshot> {
758 let state = self.state.lock().await;
759 state
760 .tasks
761 .values()
762 .filter(|record| record.completed)
763 .map(|record| record.snapshot.clone())
764 .collect()
765 }
766
767 async fn drain_ready_items(&self) -> Vec<Item> {
768 let mut state = self.state.lock().await;
769 std::mem::take(&mut state.manual_ready_items)
770 }
771
772 async fn set_continue_policy(
773 &self,
774 task_id: TaskId,
775 policy: ContinuePolicy,
776 ) -> Result<(), TaskManagerError> {
777 let mut state = self.state.lock().await;
778 let record = state
779 .tasks
780 .get_mut(&task_id)
781 .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
782 record.continue_policy = policy;
783 Ok(())
784 }
785
786 async fn set_delivery_mode(
787 &self,
788 task_id: TaskId,
789 mode: DeliveryMode,
790 ) -> Result<(), TaskManagerError> {
791 let mut state = self.state.lock().await;
792 let record = state
793 .tasks
794 .get_mut(&task_id)
795 .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
796 record.delivery_mode = mode;
797 Ok(())
798 }
799
800 async fn wait_for_idle(&self) {
801 loop {
802 {
803 let state = self.state.lock().await;
804 if !state.tasks.values().any(|r| r.running) {
805 return;
806 }
807 }
808 self.notify.notified().await;
809 }
810 }
811}
812
813fn map_outcome_to_resolution(
814 task_id: Option<TaskId>,
815 request: ToolRequest,
816 outcome: ToolExecutionOutcome,
817) -> TaskResolution {
818 match outcome {
819 ToolExecutionOutcome::Completed(result) => TaskResolution::Item(Item {
820 id: None,
821 kind: agentkit_core::ItemKind::Tool,
822 parts: vec![agentkit_core::Part::ToolResult(result.result)],
823 metadata: result.metadata,
824 usage: None,
825 finish_reason: None,
826 created_at: None,
827 }),
828 ToolExecutionOutcome::Interrupted(
829 agentkit_tools_core::ToolInterruption::ApprovalRequired(mut approval),
830 ) => {
831 let task_id = task_id.unwrap_or_default();
832 approval.task_id = Some(task_id.clone());
833 TaskResolution::Approval(TaskApproval {
834 task_id,
835 tool_request: request,
836 approval,
837 })
838 }
839 ToolExecutionOutcome::Failed(error) => TaskResolution::Item(Item {
840 id: None,
841 kind: agentkit_core::ItemKind::Tool,
842 parts: vec![agentkit_core::Part::ToolResult(ToolResultPart {
843 call_id: request.call_id,
844 output: agentkit_core::ToolOutput::Text(error.to_string()),
845 is_error: true,
846 metadata: request.metadata,
847 })],
848 metadata: MetadataMap::new(),
849 usage: None,
850 finish_reason: None,
851 created_at: None,
852 }),
853 }
854}
855
856#[cfg(test)]
857mod tests {
858 use std::collections::BTreeMap;
859 use std::sync::Arc as StdArc;
860 use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
861
862 use agentkit_core::{
863 CancellationController, ItemKind, Part, SessionId, ToolOutput, TurnCancellation,
864 };
865 use agentkit_tools_core::{
866 ApprovalReason, PermissionChecker, PermissionDecision, ToolAnnotations, ToolInterruption,
867 ToolName, ToolResult, ToolSpec,
868 };
869 use serde_json::json;
870 use tokio::sync::Notify;
871 use tokio::time::{Duration, timeout};
872
873 use super::*;
874
875 struct AllowAllPermissions;
876
877 impl PermissionChecker for AllowAllPermissions {
878 fn evaluate(
879 &self,
880 _request: &dyn agentkit_tools_core::PermissionRequest,
881 ) -> PermissionDecision {
882 PermissionDecision::Allow
883 }
884 }
885
886 #[derive(Clone)]
887 enum TestBehavior {
888 Block {
889 entered: StdArc<AtomicBool>,
890 release: StdArc<Notify>,
891 output: &'static str,
892 },
893 Approval,
894 }
895
896 #[derive(Clone)]
897 struct TestExecutor {
898 behaviors: BTreeMap<String, TestBehavior>,
899 }
900
901 impl TestExecutor {
902 fn new(behaviors: impl IntoIterator<Item = (impl Into<String>, TestBehavior)>) -> Self {
903 Self {
904 behaviors: behaviors
905 .into_iter()
906 .map(|(name, behavior)| (name.into(), behavior))
907 .collect(),
908 }
909 }
910 }
911
912 #[async_trait]
913 impl ToolExecutor for TestExecutor {
914 fn specs(&self) -> Vec<ToolSpec> {
915 self.behaviors
916 .keys()
917 .map(|name| ToolSpec {
918 name: ToolName::new(name),
919 description: format!("test tool {name}"),
920 input_schema: json!({
921 "type": "object",
922 "properties": {},
923 "additionalProperties": false
924 }),
925 output_schema: None,
926 annotations: ToolAnnotations::default(),
927 metadata: MetadataMap::new(),
928 })
929 .collect()
930 }
931
932 async fn execute(
933 &self,
934 request: ToolRequest,
935 _ctx: &mut agentkit_tools_core::ToolContext<'_>,
936 ) -> ToolExecutionOutcome {
937 match self.behaviors.get(request.tool_name.0.as_str()) {
938 Some(TestBehavior::Block {
939 entered,
940 release,
941 output,
942 }) => {
943 entered.store(true, AtomicOrdering::SeqCst);
944 release.notified().await;
945 ToolExecutionOutcome::Completed(ToolResult {
946 result: ToolResultPart {
947 call_id: request.call_id,
948 output: ToolOutput::Text((*output).into()),
949 is_error: false,
950 metadata: request.metadata,
951 },
952 duration: None,
953 metadata: MetadataMap::new(),
954 })
955 }
956 Some(TestBehavior::Approval) => ToolExecutionOutcome::Interrupted(
957 ToolInterruption::ApprovalRequired(ApprovalRequest {
958 task_id: None,
959 call_id: Some(request.call_id.clone()),
960 id: "approval:test".into(),
961 request_kind: "tool.test".into(),
962 reason: ApprovalReason::SensitivePath,
963 summary: "requires approval".into(),
964 metadata: MetadataMap::new(),
965 }),
966 ),
967 None => ToolExecutionOutcome::Failed(ToolError::Unavailable(
968 request.tool_name.0.clone(),
969 )),
970 }
971 }
972 }
973
974 struct NameRoutingPolicy {
975 routes: BTreeMap<String, RoutingDecision>,
976 }
977
978 impl NameRoutingPolicy {
979 fn new(routes: impl IntoIterator<Item = (impl Into<String>, RoutingDecision)>) -> Self {
980 Self {
981 routes: routes
982 .into_iter()
983 .map(|(name, decision)| (name.into(), decision))
984 .collect(),
985 }
986 }
987 }
988
989 impl TaskRoutingPolicy for NameRoutingPolicy {
990 fn route(&self, request: &ToolRequest) -> RoutingDecision {
991 self.routes
992 .get(request.tool_name.0.as_str())
993 .copied()
994 .unwrap_or(RoutingDecision::Foreground)
995 }
996 }
997
998 fn make_request(tool_name: &str, turn_id: &str, call_id: &str) -> ToolRequest {
999 ToolRequest {
1000 call_id: ToolCallId::new(call_id),
1001 tool_name: ToolName::new(tool_name),
1002 input: json!({}),
1003 session_id: SessionId::new("session-1"),
1004 turn_id: TurnId::new(turn_id),
1005 metadata: MetadataMap::new(),
1006 }
1007 }
1008
1009 fn make_context(
1010 executor: Arc<dyn ToolExecutor>,
1011 turn_id: &TurnId,
1012 cancellation: Option<TurnCancellation>,
1013 ) -> TaskStartContext {
1014 TaskStartContext {
1015 executor,
1016 tool_context: OwnedToolContext {
1017 session_id: SessionId::new("session-1"),
1018 turn_id: turn_id.clone(),
1019 metadata: MetadataMap::new(),
1020 permissions: Arc::new(AllowAllPermissions),
1021 resources: Arc::new(()),
1022 cancellation,
1023 execution_scope: None,
1024 approved_request: None,
1025 },
1026 }
1027 }
1028
1029 async fn next_event(handle: &TaskManagerHandle) -> TaskEvent {
1030 timeout(Duration::from_secs(1), handle.next_event())
1031 .await
1032 .expect("timed out waiting for task event")
1033 .expect("task event stream ended unexpectedly")
1034 }
1035
1036 async fn wait_until_entered(entered: &AtomicBool) {
1037 timeout(Duration::from_secs(1), async {
1038 while !entered.load(AtomicOrdering::SeqCst) {
1039 tokio::task::yield_now().await;
1040 }
1041 })
1042 .await
1043 .expect("task never entered execution");
1044 }
1045
1046 #[tokio::test]
1047 async fn simple_task_manager_executes_inline_and_assigns_task_ids() {
1048 let manager = SimpleTaskManager::new();
1049 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1050 "needs-approval",
1051 TestBehavior::Approval,
1052 )]));
1053 let request = make_request("needs-approval", "turn-1", "call-1");
1054
1055 let outcome = manager
1056 .start_task(
1057 TaskLaunchRequest {
1058 task_id: None,
1059 request: request.clone(),
1060 kind: TaskLaunchKind::Plain,
1061 },
1062 make_context(executor, &request.turn_id, None),
1063 )
1064 .await
1065 .unwrap();
1066
1067 match outcome {
1068 TaskStartOutcome::Ready(resolution) => match *resolution {
1069 TaskResolution::Approval(task) => {
1070 assert!(!task.task_id.0.is_empty());
1071 assert_eq!(task.approval.task_id.as_ref(), Some(&task.task_id));
1072 assert_eq!(task.tool_request.call_id, request.call_id);
1073 }
1074 other => panic!("unexpected task resolution: {other:?}"),
1075 },
1076 other => panic!("unexpected start outcome: {other:?}"),
1077 }
1078
1079 assert!(manager.handle().list_running().await.is_empty());
1080 }
1081
1082 #[tokio::test]
1083 async fn async_manager_interrupt_cancels_foreground_only() {
1084 let fg_release = StdArc::new(Notify::new());
1085 let fg_entered = StdArc::new(AtomicBool::new(false));
1086 let bg_release = StdArc::new(Notify::new());
1087 let bg_entered = StdArc::new(AtomicBool::new(false));
1088 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([
1089 (
1090 "foreground",
1091 TestBehavior::Block {
1092 entered: fg_entered.clone(),
1093 release: fg_release.clone(),
1094 output: "foreground-done",
1095 },
1096 ),
1097 (
1098 "background",
1099 TestBehavior::Block {
1100 entered: bg_entered.clone(),
1101 release: bg_release.clone(),
1102 output: "background-done",
1103 },
1104 ),
1105 ]));
1106 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([
1107 ("foreground", RoutingDecision::Foreground),
1108 ("background", RoutingDecision::Background),
1109 ]));
1110 let handle = manager.handle();
1111 let turn_id = TurnId::new("turn-1");
1112
1113 let foreground = manager
1114 .start_task(
1115 TaskLaunchRequest {
1116 task_id: None,
1117 request: make_request("foreground", "turn-1", "call-fg"),
1118 kind: TaskLaunchKind::Plain,
1119 },
1120 make_context(executor.clone(), &turn_id, None),
1121 )
1122 .await
1123 .unwrap();
1124 let background = manager
1125 .start_task(
1126 TaskLaunchRequest {
1127 task_id: None,
1128 request: make_request("background", "turn-1", "call-bg"),
1129 kind: TaskLaunchKind::Plain,
1130 },
1131 make_context(executor.clone(), &turn_id, None),
1132 )
1133 .await
1134 .unwrap();
1135
1136 assert!(matches!(
1137 foreground,
1138 TaskStartOutcome::Pending {
1139 kind: TaskKind::Foreground,
1140 ..
1141 }
1142 ));
1143 let background_id = match background {
1144 TaskStartOutcome::Pending {
1145 task_id,
1146 kind: TaskKind::Background,
1147 } => task_id,
1148 other => panic!("unexpected background outcome: {other:?}"),
1149 };
1150
1151 let _ = next_event(&handle).await;
1152 let _ = next_event(&handle).await;
1153 wait_until_entered(fg_entered.as_ref()).await;
1154 wait_until_entered(bg_entered.as_ref()).await;
1155
1156 manager.on_turn_interrupted(&turn_id).await.unwrap();
1157
1158 match next_event(&handle).await {
1159 TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.tool_name, "foreground"),
1160 other => panic!("unexpected event after interrupt: {other:?}"),
1161 }
1162
1163 let running = handle.list_running().await;
1164 assert_eq!(running.len(), 1);
1165 assert_eq!(running[0].id, background_id);
1166 assert_eq!(running[0].tool_name, "background");
1167
1168 bg_release.notify_waiters();
1169 match next_event(&handle).await {
1170 TaskEvent::Completed(snapshot, result) => {
1171 assert_eq!(snapshot.id, background_id);
1172 assert_eq!(result.output, ToolOutput::Text("background-done".into()));
1173 }
1174 other => panic!("unexpected completion event: {other:?}"),
1175 }
1176 }
1177
1178 #[tokio::test]
1179 async fn async_manager_can_cancel_background_tasks_by_id() {
1180 let release = StdArc::new(Notify::new());
1181 let entered = StdArc::new(AtomicBool::new(false));
1182 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1183 "background",
1184 TestBehavior::Block {
1185 entered: entered.clone(),
1186 release,
1187 output: "done",
1188 },
1189 )]));
1190 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1191 "background",
1192 RoutingDecision::Background,
1193 )]));
1194 let handle = manager.handle();
1195 let request = make_request("background", "turn-1", "call-1");
1196
1197 let task_id = match manager
1198 .start_task(
1199 TaskLaunchRequest {
1200 task_id: None,
1201 request: request.clone(),
1202 kind: TaskLaunchKind::Plain,
1203 },
1204 make_context(executor, &request.turn_id, None),
1205 )
1206 .await
1207 .unwrap()
1208 {
1209 TaskStartOutcome::Pending { task_id, .. } => task_id,
1210 other => panic!("unexpected start outcome: {other:?}"),
1211 };
1212
1213 let _ = next_event(&handle).await;
1214 wait_until_entered(entered.as_ref()).await;
1215 handle.cancel(task_id.clone()).await.unwrap();
1216
1217 match next_event(&handle).await {
1218 TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.id, task_id),
1219 other => panic!("unexpected event after cancel: {other:?}"),
1220 }
1221
1222 assert!(handle.list_running().await.is_empty());
1223 }
1224
1225 #[tokio::test]
1226 async fn async_manager_manual_delivery_keeps_results_out_of_loop_updates() {
1227 let release = StdArc::new(Notify::new());
1228 let entered = StdArc::new(AtomicBool::new(false));
1229 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1230 "background",
1231 TestBehavior::Block {
1232 entered: entered.clone(),
1233 release: release.clone(),
1234 output: "manual-done",
1235 },
1236 )]));
1237 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1238 "background",
1239 RoutingDecision::Background,
1240 )]));
1241 let handle = manager.handle();
1242 let request = make_request("background", "turn-1", "call-1");
1243
1244 let task_id = match manager
1245 .start_task(
1246 TaskLaunchRequest {
1247 task_id: None,
1248 request: request.clone(),
1249 kind: TaskLaunchKind::Plain,
1250 },
1251 make_context(executor, &request.turn_id, None),
1252 )
1253 .await
1254 .unwrap()
1255 {
1256 TaskStartOutcome::Pending { task_id, .. } => task_id,
1257 other => panic!("unexpected start outcome: {other:?}"),
1258 };
1259
1260 let _ = next_event(&handle).await;
1261 wait_until_entered(entered.as_ref()).await;
1262 handle
1263 .set_continue_policy(task_id.clone(), ContinuePolicy::RequestContinue)
1264 .await
1265 .unwrap();
1266 handle
1267 .set_delivery_mode(task_id, DeliveryMode::Manual)
1268 .await
1269 .unwrap();
1270
1271 release.notify_waiters();
1272 match next_event(&handle).await {
1273 TaskEvent::Completed(_, result) => {
1274 assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
1275 }
1276 other => panic!("unexpected event: {other:?}"),
1277 }
1278
1279 assert!(
1280 timeout(Duration::from_millis(50), handle.next_event())
1281 .await
1282 .is_err()
1283 );
1284 assert!(
1285 manager
1286 .take_pending_loop_updates()
1287 .await
1288 .unwrap()
1289 .resolutions
1290 .is_empty()
1291 );
1292
1293 let ready_items = handle.drain_ready_items().await;
1294 assert_eq!(ready_items.len(), 1);
1295 assert_eq!(ready_items[0].kind, ItemKind::Tool);
1296 match &ready_items[0].parts[0] {
1297 Part::ToolResult(result) => {
1298 assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
1299 }
1300 other => panic!("unexpected ready item: {other:?}"),
1301 }
1302 }
1303
1304 #[tokio::test]
1305 async fn async_manager_to_loop_delivery_can_request_continue() {
1306 let release = StdArc::new(Notify::new());
1307 let entered = StdArc::new(AtomicBool::new(false));
1308 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1309 "background",
1310 TestBehavior::Block {
1311 entered: entered.clone(),
1312 release: release.clone(),
1313 output: "loop-done",
1314 },
1315 )]));
1316 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1317 "background",
1318 RoutingDecision::Background,
1319 )]));
1320 let handle = manager.handle();
1321 let request = make_request("background", "turn-1", "call-1");
1322
1323 let task_id = match manager
1324 .start_task(
1325 TaskLaunchRequest {
1326 task_id: None,
1327 request: request.clone(),
1328 kind: TaskLaunchKind::Plain,
1329 },
1330 make_context(
1331 executor,
1332 &request.turn_id,
1333 Some(TurnCancellation::new(
1334 CancellationController::new().handle(),
1335 )),
1336 ),
1337 )
1338 .await
1339 .unwrap()
1340 {
1341 TaskStartOutcome::Pending { task_id, .. } => task_id,
1342 other => panic!("unexpected start outcome: {other:?}"),
1343 };
1344
1345 let _ = next_event(&handle).await;
1346 wait_until_entered(entered.as_ref()).await;
1347 handle
1348 .set_continue_policy(task_id, ContinuePolicy::RequestContinue)
1349 .await
1350 .unwrap();
1351
1352 release.notify_waiters();
1353 match next_event(&handle).await {
1354 TaskEvent::Completed(_, result) => {
1355 assert_eq!(result.output, ToolOutput::Text("loop-done".into()))
1356 }
1357 other => panic!("unexpected completion event: {other:?}"),
1358 }
1359 match next_event(&handle).await {
1360 TaskEvent::ContinueRequested => {}
1361 other => panic!("unexpected follow-up event: {other:?}"),
1362 }
1363
1364 let updates = manager.take_pending_loop_updates().await.unwrap();
1365 assert_eq!(updates.resolutions.len(), 1);
1366 assert!(handle.drain_ready_items().await.is_empty());
1367 }
1368
1369 #[tokio::test]
1370 async fn wait_for_idle_returns_after_loop_updates_are_queued() {
1371 let release = StdArc::new(Notify::new());
1372 let entered = StdArc::new(AtomicBool::new(false));
1373 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1374 "background",
1375 TestBehavior::Block {
1376 entered: entered.clone(),
1377 release: release.clone(),
1378 output: "idle-done",
1379 },
1380 )]));
1381 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1382 "background",
1383 RoutingDecision::Background,
1384 )]));
1385 let handle = manager.handle();
1386 let request = make_request("background", "turn-1", "call-1");
1387
1388 let outcome = manager
1389 .start_task(
1390 TaskLaunchRequest {
1391 task_id: None,
1392 request: request.clone(),
1393 kind: TaskLaunchKind::Plain,
1394 },
1395 make_context(executor, &request.turn_id, None),
1396 )
1397 .await
1398 .unwrap();
1399 assert!(matches!(outcome, TaskStartOutcome::Pending { .. }));
1400
1401 let _ = next_event(&handle).await;
1402 wait_until_entered(entered.as_ref()).await;
1403 release.notify_waiters();
1404
1405 timeout(Duration::from_secs(1), handle.wait_for_idle())
1406 .await
1407 .expect("wait_for_idle timed out");
1408
1409 let updates = manager.take_pending_loop_updates().await.unwrap();
1410 assert_eq!(updates.resolutions.len(), 1);
1411 match &updates.resolutions[0] {
1412 TaskResolution::Item(item) => match &item.parts[0] {
1413 Part::ToolResult(result) => {
1414 assert_eq!(result.call_id, request.call_id);
1415 assert_eq!(result.output, ToolOutput::Text("idle-done".into()));
1416 }
1417 other => panic!("unexpected tool item: {other:?}"),
1418 },
1419 other => panic!("unexpected pending update: {other:?}"),
1420 }
1421 }
1422}