1use std::collections::{BTreeMap, VecDeque};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use agentkit_core::{
7 Item, MetadataMap, TaskId, ToolCallId, ToolResultPart, TurnCancellation, TurnId,
8};
9use agentkit_tools_core::{
10 ApprovalRequest, AuthRequest, OwnedToolContext, ToolError, ToolExecutionOutcome, ToolExecutor,
11 ToolRequest,
12};
13use async_trait::async_trait;
14use thiserror::Error;
15use tokio::sync::{Mutex, Notify, mpsc};
16use tokio::task::JoinHandle;
17
/// Whether a task holds up the turn that launched it.
///
/// In [`AsyncTaskManager`], foreground tasks are counted against their turn and their results
/// are returned by [`TaskManager::wait_for_turn`]; background tasks are not counted and their
/// results are routed according to the task's [`DeliveryMode`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TaskKind {
    /// The launching turn waits for this task.
    Foreground,
    /// The task runs independently of the launching turn.
    Background,
}

/// Whether a finished background task should prompt the host to resume the agent loop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ContinuePolicy {
    /// Only report completion through task events. [`AsyncTaskManager`] assigns this to every
    /// new task.
    NotifyOnly,
    /// Also emit [`TaskEvent::ContinueRequested`] when the task completes as a background task
    /// with [`DeliveryMode::ToLoop`].
    RequestContinue,
}

/// Where a background task's resolution goes once it finishes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryMode {
    /// Queue the resolution for [`TaskManager::take_pending_loop_updates`].
    /// [`AsyncTaskManager`] assigns this to every new task.
    ToLoop,
    /// Keep the result out of the loop: result items are collected through
    /// [`TaskManagerHandle::drain_ready_items`]. In [`AsyncTaskManager`], approval and auth
    /// resolutions of such tasks are not queued anywhere.
    Manual,
}

/// Point-in-time description of a task, as carried by [`TaskEvent`]s and task listings.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TaskSnapshot {
    /// Caller-supplied or manager-generated task id.
    pub id: TaskId,
    /// Turn of the tool request that launched the task.
    pub turn_id: TurnId,
    /// Tool call executed by the task.
    pub call_id: ToolCallId,
    /// Name of the tool being executed.
    pub tool_name: String,
    /// Current kind; changes from `Foreground` to `Background` when the task detaches.
    pub kind: TaskKind,
    /// Metadata copied from the tool request.
    pub metadata: MetadataMap,
}

/// Task lifecycle notification, read through [`TaskManagerHandle::next_event`].
#[derive(Clone, Debug, PartialEq)]
pub enum TaskEvent {
    /// A task was launched.
    Started(TaskSnapshot),
    /// A foreground task outlived its detach timeout and became a background task.
    Detached(TaskSnapshot),
    /// A task finished with a tool result (failed tool calls also produce an error result).
    Completed(TaskSnapshot, ToolResultPart),
    /// A task was cancelled, or interrupted together with its turn.
    Cancelled(TaskSnapshot),
    /// A task failed. Not emitted by `SimpleTaskManager` or `AsyncTaskManager`, which report
    /// tool failures as error tool results instead.
    Failed(TaskSnapshot, ToolError),
    /// A background task asked the host to continue the agent loop.
    ContinueRequested,
}
55
/// A task whose tool call stopped because it requires approval.
#[derive(Clone, Debug, PartialEq)]
pub struct TaskApproval {
    /// Task that produced the request; the same id is copied into `approval.task_id`.
    pub task_id: TaskId,
    /// The tool request that was interrupted.
    pub tool_request: ToolRequest,
    /// The approval request returned by the tool executor.
    pub approval: ApprovalRequest,
}

/// A task whose tool call stopped because it requires authentication.
#[derive(Clone, Debug, PartialEq)]
pub struct TaskAuth {
    /// Task that produced the request; the same id is copied into `auth.task_id`.
    pub task_id: TaskId,
    /// The tool request that was interrupted.
    pub tool_request: ToolRequest,
    /// The auth request returned by the tool executor.
    pub auth: AuthRequest,
}

/// Final result of running a task: a tool item, or an interruption the host must resolve.
#[derive(Clone, Debug, PartialEq)]
pub enum TaskResolution {
    /// A tool item holding the tool result (tool failures become error results).
    Item(Item),
    /// The tool call requires approval.
    Approval(TaskApproval),
    /// The tool call requires authentication.
    Auth(TaskAuth),
}

/// Result of [`TaskManager::start_task`].
#[derive(Clone, Debug, PartialEq)]
pub enum TaskStartOutcome {
    /// The task already finished and its resolution is available now
    /// (what `SimpleTaskManager` always returns).
    Ready(Box<TaskResolution>),
    /// The task keeps running and reports its result later
    /// (what `AsyncTaskManager` always returns).
    Pending {
        /// Id of the running task.
        task_id: TaskId,
        /// Kind the task started with.
        kind: TaskKind,
    },
}

/// Update returned by [`TaskManager::wait_for_turn`] about a foreground task of that turn.
#[derive(Clone, Debug, PartialEq)]
pub enum TurnTaskUpdate {
    /// A foreground task finished with this resolution.
    Resolution(Box<TaskResolution>),
    /// A foreground task became a background task; its result will not arrive through the turn.
    Detached(TaskSnapshot),
}

/// Background-task resolutions taken from the manager for delivery to the agent loop.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct PendingLoopUpdates {
    /// Queued resolutions, oldest first.
    pub resolutions: VecDeque<TaskResolution>,
}

/// Input to [`TaskManager::start_task`].
#[derive(Clone, Debug)]
pub struct TaskLaunchRequest {
    /// Id to use for the task; when `None` the manager generates one (`task-N`).
    pub task_id: Option<TaskId>,
    /// Tool call to execute.
    pub request: ToolRequest,
    /// When set, the call runs through `execute_approved_owned` with this approval.
    pub approved_request: Option<ApprovalRequest>,
}
100
/// Execution resources for a task: the executor to run it and the tool context it runs with.
#[derive(Clone)]
pub struct TaskStartContext {
    /// Executor that runs the tool call.
    pub executor: Arc<dyn ToolExecutor>,
    /// Owned tool context handed to the executor.
    pub tool_context: OwnedToolContext,
}

/// Errors reported by task managers and their handles.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum TaskManagerError {
    /// No task with this id can be addressed (`SimpleTaskManager` handles always report this).
    #[error("task not found: {0}")]
    NotFound(TaskId),
    /// Unexpected internal failure, with a description.
    #[error("task manager internal error: {0}")]
    Internal(String),
}
114
/// Decides how [`AsyncTaskManager`] runs each tool request.
pub trait TaskRoutingPolicy: Send + Sync {
    /// Chooses the execution mode for `request`.
    fn route(&self, request: &ToolRequest) -> RoutingDecision;
}

// Lets any thread-safe closure or function `Fn(&ToolRequest) -> RoutingDecision` act as a
// routing policy.
impl<F> TaskRoutingPolicy for F
where
    F: Fn(&ToolRequest) -> RoutingDecision + Send + Sync,
{
    fn route(&self, request: &ToolRequest) -> RoutingDecision {
        self(request)
    }
}

/// Execution mode chosen by a [`TaskRoutingPolicy`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RoutingDecision {
    /// Run as a foreground task that the turn waits for.
    Foreground,
    /// Run as a background task from the start.
    Background,
    /// Start in the foreground; if still running after the duration, detach to the background.
    ForegroundThenDetachAfter(Duration),
}

/// Policy used when none is configured: every request runs in the foreground.
struct DefaultRoutingPolicy;

impl TaskRoutingPolicy for DefaultRoutingPolicy {
    fn route(&self, _request: &ToolRequest) -> RoutingDecision {
        RoutingDecision::Foreground
    }
}
142
/// Runs tool calls as tasks on behalf of the agent loop.
#[async_trait]
pub trait TaskManager: Send + Sync {
    /// Starts executing `request` with the executor and tool context in `ctx`.
    ///
    /// Returns `Ready` when the task finished during the call, or `Pending` when it keeps
    /// running and reports its result later.
    async fn start_task(
        &self,
        request: TaskLaunchRequest,
        ctx: TaskStartContext,
    ) -> Result<TaskStartOutcome, TaskManagerError>;

    /// Waits for the next update about foreground tasks of `turn_id`.
    ///
    /// `Ok(None)` means there is nothing (more) to wait for. In `AsyncTaskManager`, it is also
    /// returned when `cancellation` fires.
    async fn wait_for_turn(
        &self,
        turn_id: &TurnId,
        cancellation: Option<TurnCancellation>,
    ) -> Result<Option<TurnTaskUpdate>, TaskManagerError>;

    /// Takes every background resolution queued for the loop, leaving the queue empty.
    async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError>;

    /// Informs the manager that `turn_id` was interrupted. `AsyncTaskManager` cancels the
    /// turn's running foreground tasks.
    async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError>;

    /// Returns a cloneable handle for observing and controlling this manager's tasks.
    fn handle(&self) -> TaskManagerHandle;
}
163
/// Object-safe operations behind [`TaskManagerHandle`], implemented by each manager's shared
/// state.
#[async_trait]
trait TaskManagerControl: Send + Sync {
    /// Waits for the next task lifecycle event.
    async fn next_event(&self) -> Option<TaskEvent>;
    /// Cancels the task with `task_id`.
    async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError>;
    /// Snapshots of the tasks that are currently running.
    async fn list_running(&self) -> Vec<TaskSnapshot>;
    /// Snapshots of the tasks that ran to completion.
    async fn list_completed(&self) -> Vec<TaskSnapshot>;
    /// Takes the result items collected from `DeliveryMode::Manual` tasks.
    async fn drain_ready_items(&self) -> Vec<Item>;
    /// Changes the [`ContinuePolicy`] of a task.
    async fn set_continue_policy(
        &self,
        task_id: TaskId,
        policy: ContinuePolicy,
    ) -> Result<(), TaskManagerError>;
    /// Changes the [`DeliveryMode`] of a task.
    async fn set_delivery_mode(
        &self,
        task_id: TaskId,
        mode: DeliveryMode,
    ) -> Result<(), TaskManagerError>;
    /// Waits until no task is running.
    async fn wait_for_idle(&self);
}
183
184#[derive(Clone)]
185pub struct TaskManagerHandle {
186 inner: Arc<dyn TaskManagerControl>,
187}
188
189impl TaskManagerHandle {
190 pub async fn next_event(&self) -> Option<TaskEvent> {
191 self.inner.next_event().await
192 }
193
194 pub async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
195 self.inner.cancel(task_id).await
196 }
197
198 pub async fn list_running(&self) -> Vec<TaskSnapshot> {
199 self.inner.list_running().await
200 }
201
202 pub async fn list_completed(&self) -> Vec<TaskSnapshot> {
203 self.inner.list_completed().await
204 }
205
206 pub async fn drain_ready_items(&self) -> Vec<Item> {
207 self.inner.drain_ready_items().await
208 }
209
210 pub async fn set_continue_policy(
211 &self,
212 task_id: TaskId,
213 policy: ContinuePolicy,
214 ) -> Result<(), TaskManagerError> {
215 self.inner.set_continue_policy(task_id, policy).await
216 }
217
218 pub async fn set_delivery_mode(
219 &self,
220 task_id: TaskId,
221 mode: DeliveryMode,
222 ) -> Result<(), TaskManagerError> {
223 self.inner.set_delivery_mode(task_id, mode).await
224 }
225
226 pub async fn wait_for_idle(&self) {
228 self.inner.wait_for_idle().await
229 }
230}
231
232pub struct SimpleTaskManager {
233 state: Arc<HandleState>,
234}
235
236impl SimpleTaskManager {
237 pub fn new() -> Self {
238 Self {
239 state: Arc::new(HandleState::default()),
240 }
241 }
242}
243
244impl Default for SimpleTaskManager {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250#[async_trait]
251impl TaskManager for SimpleTaskManager {
252 async fn start_task(
253 &self,
254 request: TaskLaunchRequest,
255 ctx: TaskStartContext,
256 ) -> Result<TaskStartOutcome, TaskManagerError> {
257 let task_id = request
258 .task_id
259 .clone()
260 .unwrap_or_else(|| self.state.next_task_id());
261 let outcome = match request.approved_request.as_ref() {
262 Some(approved) => {
263 ctx.executor
264 .execute_approved_owned(request.request.clone(), approved, ctx.tool_context)
265 .await
266 }
267 None => {
268 ctx.executor
269 .execute_owned(request.request.clone(), ctx.tool_context)
270 .await
271 }
272 };
273 Ok(TaskStartOutcome::Ready(Box::new(
274 map_outcome_to_resolution(Some(task_id), request.request, outcome),
275 )))
276 }
277
278 async fn wait_for_turn(
279 &self,
280 _turn_id: &TurnId,
281 _cancellation: Option<TurnCancellation>,
282 ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
283 Ok(None)
284 }
285
286 async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
287 Ok(PendingLoopUpdates::default())
288 }
289
290 async fn on_turn_interrupted(&self, _turn_id: &TurnId) -> Result<(), TaskManagerError> {
291 Ok(())
292 }
293
294 fn handle(&self) -> TaskManagerHandle {
295 TaskManagerHandle {
296 inner: self.state.clone(),
297 }
298 }
299}
300
301#[derive(Default)]
302struct HandleState {
303 next_task_index: AtomicU64,
304 events_rx: Mutex<Option<mpsc::UnboundedReceiver<TaskEvent>>>,
305}
306
307impl HandleState {
308 fn next_task_id(&self) -> TaskId {
309 let next = self.next_task_index.fetch_add(1, Ordering::SeqCst) + 1;
310 TaskId::new(format!("task-{}", next))
311 }
312}
313
314#[async_trait]
315impl TaskManagerControl for HandleState {
316 async fn next_event(&self) -> Option<TaskEvent> {
317 let mut rx = self.events_rx.lock().await;
318 match rx.as_mut() {
319 Some(inner) => inner.recv().await,
320 None => None,
321 }
322 }
323
324 async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
325 Err(TaskManagerError::NotFound(task_id))
326 }
327
328 async fn list_running(&self) -> Vec<TaskSnapshot> {
329 Vec::new()
330 }
331
332 async fn list_completed(&self) -> Vec<TaskSnapshot> {
333 Vec::new()
334 }
335
336 async fn drain_ready_items(&self) -> Vec<Item> {
337 Vec::new()
338 }
339
340 async fn set_continue_policy(
341 &self,
342 task_id: TaskId,
343 _policy: ContinuePolicy,
344 ) -> Result<(), TaskManagerError> {
345 Err(TaskManagerError::NotFound(task_id))
346 }
347
348 async fn set_delivery_mode(
349 &self,
350 task_id: TaskId,
351 _mode: DeliveryMode,
352 ) -> Result<(), TaskManagerError> {
353 Err(TaskManagerError::NotFound(task_id))
354 }
355
356 async fn wait_for_idle(&self) {}
357}
358
359pub struct AsyncTaskManager {
360 inner: Arc<AsyncInner>,
361 routing: Arc<dyn TaskRoutingPolicy>,
362}
363
364impl AsyncTaskManager {
365 pub fn new() -> Self {
366 let (event_tx, event_rx) = mpsc::unbounded_channel();
367 Self {
368 inner: Arc::new(AsyncInner {
369 state: Mutex::new(AsyncState::default()),
370 host_event_tx: event_tx,
371 host_event_rx: Mutex::new(event_rx),
372 notify: Notify::new(),
373 }),
374 routing: Arc::new(DefaultRoutingPolicy),
375 }
376 }
377
378 pub fn routing(mut self, policy: impl TaskRoutingPolicy + 'static) -> Self {
379 self.routing = Arc::new(policy);
380 self
381 }
382}
383
384impl Default for AsyncTaskManager {
385 fn default() -> Self {
386 Self::new()
387 }
388}
389
/// Mutable bookkeeping for [`AsyncTaskManager`], guarded by `AsyncInner::state`.
#[derive(Default)]
struct AsyncState {
    /// Number of task ids generated so far; the next id is `task-{next_task_index + 1}`.
    next_task_index: u64,
    /// Every task started by the manager, keyed by id. Records are kept after the task ends.
    tasks: BTreeMap<TaskId, TaskRecord>,
    /// Number of foreground tasks still running per turn; entries are removed at zero.
    per_turn_running: BTreeMap<TurnId, usize>,
    /// Foreground resolutions and detach notices waiting to be returned by `wait_for_turn`.
    per_turn_updates: BTreeMap<TurnId, VecDeque<TurnTaskUpdate>>,
    /// Background resolutions waiting for `take_pending_loop_updates`.
    pending_loop_updates: VecDeque<TaskResolution>,
    /// Result items of `DeliveryMode::Manual` background tasks, waiting for `drain_ready_items`.
    manual_ready_items: Vec<Item>,
}

/// Per-task state tracked by [`AsyncTaskManager`].
struct TaskRecord {
    /// Latest snapshot; `kind` is rewritten when the task detaches to the background.
    snapshot: TaskSnapshot,
    /// Whether completion should emit `TaskEvent::ContinueRequested` (starts as `NotifyOnly`).
    continue_policy: ContinuePolicy,
    /// Where a background resolution is delivered (starts as `ToLoop`).
    delivery_mode: DeliveryMode,
    /// `true` from start until the task completes, is cancelled, or its turn is interrupted.
    running: bool,
    /// `true` only when the task ran to completion; cancellation does not set it.
    completed: bool,
    /// Handle of the spawned Tokio task, used for aborting. `None` until it is stored right
    /// after spawning, and again once it has been taken to abort the task.
    join: Option<JoinHandle<()>>,
}

/// State shared between [`AsyncTaskManager`], the tasks it spawns, and its handles.
struct AsyncInner {
    /// Task bookkeeping.
    state: Mutex<AsyncState>,
    /// Sender side of the lifecycle event stream.
    host_event_tx: mpsc::UnboundedSender<TaskEvent>,
    /// Receiver read by `next_event`; the lock lets only one caller receive at a time.
    host_event_rx: Mutex<mpsc::UnboundedReceiver<TaskEvent>>,
    /// Signalled with `notify_waiters` after state changes that waiters should re-check.
    notify: Notify,
}
415
416impl AsyncInner {
417 async fn next_task_id(&self) -> TaskId {
418 let mut state = self.state.lock().await;
419 state.next_task_index += 1;
420 TaskId::new(format!("task-{}", state.next_task_index))
421 }
422}
423
424#[async_trait]
425impl TaskManager for AsyncTaskManager {
426 async fn start_task(
427 &self,
428 request: TaskLaunchRequest,
429 ctx: TaskStartContext,
430 ) -> Result<TaskStartOutcome, TaskManagerError> {
431 let route = self.routing.route(&request.request);
432 let task_id = match request.task_id.clone() {
433 Some(existing) => existing,
434 None => self.inner.next_task_id().await,
435 };
436 let initial_kind = match route {
437 RoutingDecision::Background => TaskKind::Background,
438 _ => TaskKind::Foreground,
439 };
440 let snapshot = TaskSnapshot {
441 id: task_id.clone(),
442 turn_id: request.request.turn_id.clone(),
443 call_id: request.request.call_id.clone(),
444 tool_name: request.request.tool_name.to_string(),
445 kind: initial_kind,
446 metadata: request.request.metadata.clone(),
447 };
448 let _ = self
449 .inner
450 .host_event_tx
451 .send(TaskEvent::Started(snapshot.clone()));
452
453 let mut state = self.inner.state.lock().await;
454 state.tasks.insert(
455 task_id.clone(),
456 TaskRecord {
457 snapshot: snapshot.clone(),
458 continue_policy: ContinuePolicy::NotifyOnly,
459 delivery_mode: DeliveryMode::ToLoop,
460 running: true,
461 completed: false,
462 join: None,
463 },
464 );
465 if initial_kind == TaskKind::Foreground {
466 *state
467 .per_turn_running
468 .entry(snapshot.turn_id.clone())
469 .or_default() += 1;
470 }
471 drop(state);
472
473 let event_tx = self.inner.host_event_tx.clone();
474 let inner = self.inner.clone();
475 let task_id_for_future = task_id.clone();
476 let turn_id = snapshot.turn_id.clone();
477 let approved = request.approved_request.clone();
478 let exec_request = request.request.clone();
479 let owned_ctx = ctx.tool_context.clone();
480 let executor = ctx.executor.clone();
481 let route_copy = route;
482 let join = tokio::spawn(async move {
483 if let RoutingDecision::ForegroundThenDetachAfter(duration) = route_copy {
484 let event_tx = event_tx.clone();
485 let inner = inner.clone();
486 let task_id = task_id_for_future.clone();
487 let turn_id = turn_id.clone();
488 tokio::spawn(async move {
489 tokio::time::sleep(duration).await;
490 let mut state = inner.state.lock().await;
491 let snapshot = if let Some(record) = state.tasks.get_mut(&task_id)
492 && record.running
493 && record.snapshot.kind == TaskKind::Foreground
494 {
495 record.snapshot.kind = TaskKind::Background;
496 Some(record.snapshot.clone())
497 } else {
498 None
499 };
500 if let Some(snapshot) = snapshot {
501 if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
502 *count = count.saturating_sub(1);
503 if *count == 0 {
504 state.per_turn_running.remove(&turn_id);
505 }
506 }
507 state
508 .per_turn_updates
509 .entry(turn_id.clone())
510 .or_default()
511 .push_back(TurnTaskUpdate::Detached(snapshot.clone()));
512 let _ = event_tx.send(TaskEvent::Detached(snapshot));
513 inner.notify.notify_waiters();
514 }
515 });
516 }
517
518 let outcome = match approved.as_ref() {
519 Some(approval) => {
520 executor
521 .execute_approved_owned(exec_request.clone(), approval, owned_ctx)
522 .await
523 }
524 None => {
525 executor
526 .execute_owned(exec_request.clone(), owned_ctx)
527 .await
528 }
529 };
530
531 let resolution =
532 map_outcome_to_resolution(Some(task_id_for_future.clone()), exec_request, outcome);
533 let completed_result = match &resolution {
534 TaskResolution::Item(item) => item.parts.iter().find_map(|part| match part {
535 agentkit_core::Part::ToolResult(result) => Some(result.clone()),
536 _ => None,
537 }),
538 TaskResolution::Approval(_) | TaskResolution::Auth(_) => None,
539 };
540
541 let (snapshot, should_request_continue) = {
542 let mut state = inner.state.lock().await;
543 let Some(record) = state.tasks.get_mut(&task_id_for_future) else {
544 return;
545 };
546 record.running = false;
547 record.completed = true;
548 let snapshot = record.snapshot.clone();
549 let continue_policy = record.continue_policy;
550 let delivery_mode = record.delivery_mode;
551 let current_kind = snapshot.kind;
552
553 if current_kind == TaskKind::Foreground {
554 if let Some(count) = state.per_turn_running.get_mut(&turn_id) {
555 *count = count.saturating_sub(1);
556 if *count == 0 {
557 state.per_turn_running.remove(&turn_id);
558 }
559 }
560 state
561 .per_turn_updates
562 .entry(turn_id.clone())
563 .or_default()
564 .push_back(TurnTaskUpdate::Resolution(Box::new(resolution.clone())));
565 } else {
566 match &resolution {
567 TaskResolution::Item(_) if delivery_mode == DeliveryMode::ToLoop => {
568 state.pending_loop_updates.push_back(resolution.clone());
569 }
570 TaskResolution::Approval(_) | TaskResolution::Auth(_)
571 if delivery_mode == DeliveryMode::ToLoop =>
572 {
573 state.pending_loop_updates.push_back(resolution.clone());
574 }
575 TaskResolution::Item(item) => {
576 state.manual_ready_items.push(item.clone());
577 }
578 TaskResolution::Approval(_) | TaskResolution::Auth(_) => {}
579 }
580 }
581
582 (
583 snapshot,
584 current_kind == TaskKind::Background
585 && delivery_mode == DeliveryMode::ToLoop
586 && continue_policy == ContinuePolicy::RequestContinue,
587 )
588 };
589
590 if let Some(result) = completed_result {
591 let _ = event_tx.send(TaskEvent::Completed(snapshot.clone(), result));
592 }
593 if should_request_continue {
594 let _ = event_tx.send(TaskEvent::ContinueRequested);
595 }
596 inner.notify.notify_waiters();
597 });
598
599 let mut state = self.inner.state.lock().await;
600 if let Some(record) = state.tasks.get_mut(&task_id) {
601 record.join = Some(join);
602 }
603 Ok(TaskStartOutcome::Pending {
604 task_id,
605 kind: initial_kind,
606 })
607 }
608
609 async fn wait_for_turn(
610 &self,
611 turn_id: &TurnId,
612 cancellation: Option<TurnCancellation>,
613 ) -> Result<Option<TurnTaskUpdate>, TaskManagerError> {
614 loop {
615 {
616 let mut state = self.inner.state.lock().await;
617 if let Some(queue) = state.per_turn_updates.get_mut(turn_id)
618 && let Some(update) = queue.pop_front()
619 {
620 return Ok(Some(update));
621 }
622 if state
623 .per_turn_running
624 .get(turn_id)
625 .copied()
626 .unwrap_or_default()
627 == 0
628 {
629 return Ok(None);
630 }
631 }
632 if cancellation
633 .as_ref()
634 .is_some_and(TurnCancellation::is_cancelled)
635 {
636 return Ok(None);
637 }
638 if let Some(cancellation) = cancellation.as_ref() {
639 tokio::select! {
640 _ = self.inner.notify.notified() => {}
641 _ = cancellation.cancelled() => return Ok(None),
642 }
643 } else {
644 self.inner.notify.notified().await;
645 }
646 }
647 }
648
649 async fn take_pending_loop_updates(&self) -> Result<PendingLoopUpdates, TaskManagerError> {
650 let mut state = self.inner.state.lock().await;
651 Ok(PendingLoopUpdates {
652 resolutions: std::mem::take(&mut state.pending_loop_updates),
653 })
654 }
655
656 async fn on_turn_interrupted(&self, turn_id: &TurnId) -> Result<(), TaskManagerError> {
657 let mut state = self.inner.state.lock().await;
658 let interrupted: Vec<TaskId> = state
659 .tasks
660 .iter()
661 .filter_map(|(id, record)| {
662 (record.snapshot.turn_id == *turn_id
663 && record.snapshot.kind == TaskKind::Foreground
664 && record.running)
665 .then_some(id.clone())
666 })
667 .collect();
668 for task_id in interrupted {
669 if let Some(record) = state.tasks.get_mut(&task_id) {
670 record.running = false;
671 if let Some(join) = record.join.take() {
672 join.abort();
673 }
674 let snapshot = record.snapshot.clone();
675 let _ = self
676 .inner
677 .host_event_tx
678 .send(TaskEvent::Cancelled(snapshot));
679 }
680 }
681 state.per_turn_running.remove(turn_id);
682 self.inner.notify.notify_waiters();
683 Ok(())
684 }
685
686 fn handle(&self) -> TaskManagerHandle {
687 TaskManagerHandle {
688 inner: self.inner.clone(),
689 }
690 }
691}
692
693#[async_trait]
694impl TaskManagerControl for AsyncInner {
695 async fn next_event(&self) -> Option<TaskEvent> {
696 self.host_event_rx.lock().await.recv().await
697 }
698
699 async fn cancel(&self, task_id: TaskId) -> Result<(), TaskManagerError> {
700 let mut state = self.state.lock().await;
701 let record = state
702 .tasks
703 .get_mut(&task_id)
704 .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
705 if let Some(join) = record.join.take() {
706 join.abort();
707 }
708 record.running = false;
709 let snapshot = record.snapshot.clone();
710 if record.snapshot.kind == TaskKind::Foreground
711 && let Some(count) = state.per_turn_running.get_mut(&snapshot.turn_id)
712 {
713 *count = count.saturating_sub(1);
714 if *count == 0 {
715 state.per_turn_running.remove(&snapshot.turn_id);
716 }
717 }
718 let _ = self.host_event_tx.send(TaskEvent::Cancelled(snapshot));
719 self.notify.notify_waiters();
720 Ok(())
721 }
722
723 async fn list_running(&self) -> Vec<TaskSnapshot> {
724 let state = self.state.lock().await;
725 state
726 .tasks
727 .values()
728 .filter(|record| record.running)
729 .map(|record| record.snapshot.clone())
730 .collect()
731 }
732
733 async fn list_completed(&self) -> Vec<TaskSnapshot> {
734 let state = self.state.lock().await;
735 state
736 .tasks
737 .values()
738 .filter(|record| record.completed)
739 .map(|record| record.snapshot.clone())
740 .collect()
741 }
742
743 async fn drain_ready_items(&self) -> Vec<Item> {
744 let mut state = self.state.lock().await;
745 std::mem::take(&mut state.manual_ready_items)
746 }
747
748 async fn set_continue_policy(
749 &self,
750 task_id: TaskId,
751 policy: ContinuePolicy,
752 ) -> Result<(), TaskManagerError> {
753 let mut state = self.state.lock().await;
754 let record = state
755 .tasks
756 .get_mut(&task_id)
757 .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
758 record.continue_policy = policy;
759 Ok(())
760 }
761
762 async fn set_delivery_mode(
763 &self,
764 task_id: TaskId,
765 mode: DeliveryMode,
766 ) -> Result<(), TaskManagerError> {
767 let mut state = self.state.lock().await;
768 let record = state
769 .tasks
770 .get_mut(&task_id)
771 .ok_or_else(|| TaskManagerError::NotFound(task_id.clone()))?;
772 record.delivery_mode = mode;
773 Ok(())
774 }
775
776 async fn wait_for_idle(&self) {
777 loop {
778 {
779 let state = self.state.lock().await;
780 if !state.tasks.values().any(|r| r.running) {
781 return;
782 }
783 }
784 self.notify.notified().await;
785 }
786 }
787}
788
789fn map_outcome_to_resolution(
790 task_id: Option<TaskId>,
791 request: ToolRequest,
792 outcome: ToolExecutionOutcome,
793) -> TaskResolution {
794 match outcome {
795 ToolExecutionOutcome::Completed(result) => TaskResolution::Item(Item {
796 id: None,
797 kind: agentkit_core::ItemKind::Tool,
798 parts: vec![agentkit_core::Part::ToolResult(result.result)],
799 metadata: result.metadata,
800 }),
801 ToolExecutionOutcome::Interrupted(
802 agentkit_tools_core::ToolInterruption::ApprovalRequired(mut approval),
803 ) => {
804 let task_id = task_id.unwrap_or_default();
805 approval.task_id = Some(task_id.clone());
806 TaskResolution::Approval(TaskApproval {
807 task_id,
808 tool_request: request,
809 approval,
810 })
811 }
812 ToolExecutionOutcome::Interrupted(agentkit_tools_core::ToolInterruption::AuthRequired(
813 mut auth,
814 )) => {
815 let task_id = task_id.unwrap_or_default();
816 auth.task_id = Some(task_id.clone());
817 TaskResolution::Auth(TaskAuth {
818 task_id,
819 tool_request: request,
820 auth,
821 })
822 }
823 ToolExecutionOutcome::Failed(error) => TaskResolution::Item(Item {
824 id: None,
825 kind: agentkit_core::ItemKind::Tool,
826 parts: vec![agentkit_core::Part::ToolResult(ToolResultPart {
827 call_id: request.call_id,
828 output: agentkit_core::ToolOutput::Text(error.to_string()),
829 is_error: true,
830 metadata: request.metadata,
831 })],
832 metadata: MetadataMap::new(),
833 }),
834 }
835}
836
837#[cfg(test)]
838mod tests {
839 use std::collections::BTreeMap;
840 use std::sync::Arc as StdArc;
841 use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
842
843 use agentkit_core::{
844 CancellationController, ItemKind, Part, SessionId, ToolOutput, TurnCancellation,
845 };
846 use agentkit_tools_core::{
847 ApprovalReason, PermissionChecker, PermissionDecision, ToolAnnotations, ToolInterruption,
848 ToolName, ToolResult, ToolSpec,
849 };
850 use serde_json::json;
851 use tokio::sync::Notify;
852 use tokio::time::{Duration, timeout};
853
854 use super::*;
855
856 struct AllowAllPermissions;
857
858 impl PermissionChecker for AllowAllPermissions {
859 fn evaluate(
860 &self,
861 _request: &dyn agentkit_tools_core::PermissionRequest,
862 ) -> PermissionDecision {
863 PermissionDecision::Allow
864 }
865 }
866
867 #[derive(Clone)]
868 enum TestBehavior {
869 Block {
870 entered: StdArc<AtomicBool>,
871 release: StdArc<Notify>,
872 output: &'static str,
873 },
874 Approval,
875 }
876
877 #[derive(Clone)]
878 struct TestExecutor {
879 behaviors: BTreeMap<String, TestBehavior>,
880 }
881
882 impl TestExecutor {
883 fn new(behaviors: impl IntoIterator<Item = (impl Into<String>, TestBehavior)>) -> Self {
884 Self {
885 behaviors: behaviors
886 .into_iter()
887 .map(|(name, behavior)| (name.into(), behavior))
888 .collect(),
889 }
890 }
891 }
892
893 #[async_trait]
894 impl ToolExecutor for TestExecutor {
895 fn specs(&self) -> Vec<ToolSpec> {
896 self.behaviors
897 .keys()
898 .map(|name| ToolSpec {
899 name: ToolName::new(name),
900 description: format!("test tool {name}"),
901 input_schema: json!({
902 "type": "object",
903 "properties": {},
904 "additionalProperties": false
905 }),
906 annotations: ToolAnnotations::default(),
907 metadata: MetadataMap::new(),
908 })
909 .collect()
910 }
911
912 async fn execute(
913 &self,
914 request: ToolRequest,
915 _ctx: &mut agentkit_tools_core::ToolContext<'_>,
916 ) -> ToolExecutionOutcome {
917 match self.behaviors.get(request.tool_name.0.as_str()) {
918 Some(TestBehavior::Block {
919 entered,
920 release,
921 output,
922 }) => {
923 entered.store(true, AtomicOrdering::SeqCst);
924 release.notified().await;
925 ToolExecutionOutcome::Completed(ToolResult {
926 result: ToolResultPart {
927 call_id: request.call_id,
928 output: ToolOutput::Text((*output).into()),
929 is_error: false,
930 metadata: request.metadata,
931 },
932 duration: None,
933 metadata: MetadataMap::new(),
934 })
935 }
936 Some(TestBehavior::Approval) => ToolExecutionOutcome::Interrupted(
937 ToolInterruption::ApprovalRequired(ApprovalRequest {
938 task_id: None,
939 call_id: Some(request.call_id.clone()),
940 id: "approval:test".into(),
941 request_kind: "tool.test".into(),
942 reason: ApprovalReason::SensitivePath,
943 summary: "requires approval".into(),
944 metadata: MetadataMap::new(),
945 }),
946 ),
947 None => ToolExecutionOutcome::Failed(ToolError::Unavailable(
948 request.tool_name.0.clone(),
949 )),
950 }
951 }
952 }
953
954 struct NameRoutingPolicy {
955 routes: BTreeMap<String, RoutingDecision>,
956 }
957
958 impl NameRoutingPolicy {
959 fn new(routes: impl IntoIterator<Item = (impl Into<String>, RoutingDecision)>) -> Self {
960 Self {
961 routes: routes
962 .into_iter()
963 .map(|(name, decision)| (name.into(), decision))
964 .collect(),
965 }
966 }
967 }
968
969 impl TaskRoutingPolicy for NameRoutingPolicy {
970 fn route(&self, request: &ToolRequest) -> RoutingDecision {
971 self.routes
972 .get(request.tool_name.0.as_str())
973 .copied()
974 .unwrap_or(RoutingDecision::Foreground)
975 }
976 }
977
978 fn make_request(tool_name: &str, turn_id: &str, call_id: &str) -> ToolRequest {
979 ToolRequest {
980 call_id: ToolCallId::new(call_id),
981 tool_name: ToolName::new(tool_name),
982 input: json!({}),
983 session_id: SessionId::new("session-1"),
984 turn_id: TurnId::new(turn_id),
985 metadata: MetadataMap::new(),
986 }
987 }
988
989 fn make_context(
990 executor: Arc<dyn ToolExecutor>,
991 turn_id: &TurnId,
992 cancellation: Option<TurnCancellation>,
993 ) -> TaskStartContext {
994 TaskStartContext {
995 executor,
996 tool_context: OwnedToolContext {
997 session_id: SessionId::new("session-1"),
998 turn_id: turn_id.clone(),
999 metadata: MetadataMap::new(),
1000 permissions: Arc::new(AllowAllPermissions),
1001 resources: Arc::new(()),
1002 cancellation,
1003 },
1004 }
1005 }
1006
1007 async fn next_event(handle: &TaskManagerHandle) -> TaskEvent {
1008 timeout(Duration::from_secs(1), handle.next_event())
1009 .await
1010 .expect("timed out waiting for task event")
1011 .expect("task event stream ended unexpectedly")
1012 }
1013
1014 async fn wait_until_entered(entered: &AtomicBool) {
1015 timeout(Duration::from_secs(1), async {
1016 while !entered.load(AtomicOrdering::SeqCst) {
1017 tokio::task::yield_now().await;
1018 }
1019 })
1020 .await
1021 .expect("task never entered execution");
1022 }
1023
1024 #[tokio::test]
1025 async fn simple_task_manager_executes_inline_and_assigns_task_ids() {
1026 let manager = SimpleTaskManager::new();
1027 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1028 "needs-approval",
1029 TestBehavior::Approval,
1030 )]));
1031 let request = make_request("needs-approval", "turn-1", "call-1");
1032
1033 let outcome = manager
1034 .start_task(
1035 TaskLaunchRequest {
1036 task_id: None,
1037 request: request.clone(),
1038 approved_request: None,
1039 },
1040 make_context(executor, &request.turn_id, None),
1041 )
1042 .await
1043 .unwrap();
1044
1045 match outcome {
1046 TaskStartOutcome::Ready(resolution) => match *resolution {
1047 TaskResolution::Approval(task) => {
1048 assert!(!task.task_id.0.is_empty());
1049 assert_eq!(task.approval.task_id.as_ref(), Some(&task.task_id));
1050 assert_eq!(task.tool_request.call_id, request.call_id);
1051 }
1052 other => panic!("unexpected task resolution: {other:?}"),
1053 },
1054 other => panic!("unexpected start outcome: {other:?}"),
1055 }
1056
1057 assert!(manager.handle().list_running().await.is_empty());
1058 }
1059
1060 #[tokio::test]
1061 async fn async_manager_interrupt_cancels_foreground_only() {
1062 let fg_release = StdArc::new(Notify::new());
1063 let fg_entered = StdArc::new(AtomicBool::new(false));
1064 let bg_release = StdArc::new(Notify::new());
1065 let bg_entered = StdArc::new(AtomicBool::new(false));
1066 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([
1067 (
1068 "foreground",
1069 TestBehavior::Block {
1070 entered: fg_entered.clone(),
1071 release: fg_release.clone(),
1072 output: "foreground-done",
1073 },
1074 ),
1075 (
1076 "background",
1077 TestBehavior::Block {
1078 entered: bg_entered.clone(),
1079 release: bg_release.clone(),
1080 output: "background-done",
1081 },
1082 ),
1083 ]));
1084 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([
1085 ("foreground", RoutingDecision::Foreground),
1086 ("background", RoutingDecision::Background),
1087 ]));
1088 let handle = manager.handle();
1089 let turn_id = TurnId::new("turn-1");
1090
1091 let foreground = manager
1092 .start_task(
1093 TaskLaunchRequest {
1094 task_id: None,
1095 request: make_request("foreground", "turn-1", "call-fg"),
1096 approved_request: None,
1097 },
1098 make_context(executor.clone(), &turn_id, None),
1099 )
1100 .await
1101 .unwrap();
1102 let background = manager
1103 .start_task(
1104 TaskLaunchRequest {
1105 task_id: None,
1106 request: make_request("background", "turn-1", "call-bg"),
1107 approved_request: None,
1108 },
1109 make_context(executor.clone(), &turn_id, None),
1110 )
1111 .await
1112 .unwrap();
1113
1114 assert!(matches!(
1115 foreground,
1116 TaskStartOutcome::Pending {
1117 kind: TaskKind::Foreground,
1118 ..
1119 }
1120 ));
1121 let background_id = match background {
1122 TaskStartOutcome::Pending {
1123 task_id,
1124 kind: TaskKind::Background,
1125 } => task_id,
1126 other => panic!("unexpected background outcome: {other:?}"),
1127 };
1128
1129 let _ = next_event(&handle).await;
1130 let _ = next_event(&handle).await;
1131 wait_until_entered(fg_entered.as_ref()).await;
1132 wait_until_entered(bg_entered.as_ref()).await;
1133
1134 manager.on_turn_interrupted(&turn_id).await.unwrap();
1135
1136 match next_event(&handle).await {
1137 TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.tool_name, "foreground"),
1138 other => panic!("unexpected event after interrupt: {other:?}"),
1139 }
1140
1141 let running = handle.list_running().await;
1142 assert_eq!(running.len(), 1);
1143 assert_eq!(running[0].id, background_id);
1144 assert_eq!(running[0].tool_name, "background");
1145
1146 bg_release.notify_waiters();
1147 match next_event(&handle).await {
1148 TaskEvent::Completed(snapshot, result) => {
1149 assert_eq!(snapshot.id, background_id);
1150 assert_eq!(result.output, ToolOutput::Text("background-done".into()));
1151 }
1152 other => panic!("unexpected completion event: {other:?}"),
1153 }
1154 }
1155
1156 #[tokio::test]
1157 async fn async_manager_can_cancel_background_tasks_by_id() {
1158 let release = StdArc::new(Notify::new());
1159 let entered = StdArc::new(AtomicBool::new(false));
1160 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1161 "background",
1162 TestBehavior::Block {
1163 entered: entered.clone(),
1164 release,
1165 output: "done",
1166 },
1167 )]));
1168 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1169 "background",
1170 RoutingDecision::Background,
1171 )]));
1172 let handle = manager.handle();
1173 let request = make_request("background", "turn-1", "call-1");
1174
1175 let task_id = match manager
1176 .start_task(
1177 TaskLaunchRequest {
1178 task_id: None,
1179 request: request.clone(),
1180 approved_request: None,
1181 },
1182 make_context(executor, &request.turn_id, None),
1183 )
1184 .await
1185 .unwrap()
1186 {
1187 TaskStartOutcome::Pending { task_id, .. } => task_id,
1188 other => panic!("unexpected start outcome: {other:?}"),
1189 };
1190
1191 let _ = next_event(&handle).await;
1192 wait_until_entered(entered.as_ref()).await;
1193 handle.cancel(task_id.clone()).await.unwrap();
1194
1195 match next_event(&handle).await {
1196 TaskEvent::Cancelled(snapshot) => assert_eq!(snapshot.id, task_id),
1197 other => panic!("unexpected event after cancel: {other:?}"),
1198 }
1199
1200 assert!(handle.list_running().await.is_empty());
1201 }
1202
1203 #[tokio::test]
1204 async fn async_manager_manual_delivery_keeps_results_out_of_loop_updates() {
1205 let release = StdArc::new(Notify::new());
1206 let entered = StdArc::new(AtomicBool::new(false));
1207 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1208 "background",
1209 TestBehavior::Block {
1210 entered: entered.clone(),
1211 release: release.clone(),
1212 output: "manual-done",
1213 },
1214 )]));
1215 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1216 "background",
1217 RoutingDecision::Background,
1218 )]));
1219 let handle = manager.handle();
1220 let request = make_request("background", "turn-1", "call-1");
1221
1222 let task_id = match manager
1223 .start_task(
1224 TaskLaunchRequest {
1225 task_id: None,
1226 request: request.clone(),
1227 approved_request: None,
1228 },
1229 make_context(executor, &request.turn_id, None),
1230 )
1231 .await
1232 .unwrap()
1233 {
1234 TaskStartOutcome::Pending { task_id, .. } => task_id,
1235 other => panic!("unexpected start outcome: {other:?}"),
1236 };
1237
1238 let _ = next_event(&handle).await;
1239 wait_until_entered(entered.as_ref()).await;
1240 handle
1241 .set_continue_policy(task_id.clone(), ContinuePolicy::RequestContinue)
1242 .await
1243 .unwrap();
1244 handle
1245 .set_delivery_mode(task_id, DeliveryMode::Manual)
1246 .await
1247 .unwrap();
1248
1249 release.notify_waiters();
1250 match next_event(&handle).await {
1251 TaskEvent::Completed(_, result) => {
1252 assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
1253 }
1254 other => panic!("unexpected event: {other:?}"),
1255 }
1256
1257 assert!(
1258 timeout(Duration::from_millis(50), handle.next_event())
1259 .await
1260 .is_err()
1261 );
1262 assert!(
1263 manager
1264 .take_pending_loop_updates()
1265 .await
1266 .unwrap()
1267 .resolutions
1268 .is_empty()
1269 );
1270
1271 let ready_items = handle.drain_ready_items().await;
1272 assert_eq!(ready_items.len(), 1);
1273 assert_eq!(ready_items[0].kind, ItemKind::Tool);
1274 match &ready_items[0].parts[0] {
1275 Part::ToolResult(result) => {
1276 assert_eq!(result.output, ToolOutput::Text("manual-done".into()))
1277 }
1278 other => panic!("unexpected ready item: {other:?}"),
1279 }
1280 }
1281
1282 #[tokio::test]
1283 async fn async_manager_to_loop_delivery_can_request_continue() {
1284 let release = StdArc::new(Notify::new());
1285 let entered = StdArc::new(AtomicBool::new(false));
1286 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1287 "background",
1288 TestBehavior::Block {
1289 entered: entered.clone(),
1290 release: release.clone(),
1291 output: "loop-done",
1292 },
1293 )]));
1294 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1295 "background",
1296 RoutingDecision::Background,
1297 )]));
1298 let handle = manager.handle();
1299 let request = make_request("background", "turn-1", "call-1");
1300
1301 let task_id = match manager
1302 .start_task(
1303 TaskLaunchRequest {
1304 task_id: None,
1305 request: request.clone(),
1306 approved_request: None,
1307 },
1308 make_context(
1309 executor,
1310 &request.turn_id,
1311 Some(TurnCancellation::new(
1312 CancellationController::new().handle(),
1313 )),
1314 ),
1315 )
1316 .await
1317 .unwrap()
1318 {
1319 TaskStartOutcome::Pending { task_id, .. } => task_id,
1320 other => panic!("unexpected start outcome: {other:?}"),
1321 };
1322
1323 let _ = next_event(&handle).await;
1324 wait_until_entered(entered.as_ref()).await;
1325 handle
1326 .set_continue_policy(task_id, ContinuePolicy::RequestContinue)
1327 .await
1328 .unwrap();
1329
1330 release.notify_waiters();
1331 match next_event(&handle).await {
1332 TaskEvent::Completed(_, result) => {
1333 assert_eq!(result.output, ToolOutput::Text("loop-done".into()))
1334 }
1335 other => panic!("unexpected completion event: {other:?}"),
1336 }
1337 match next_event(&handle).await {
1338 TaskEvent::ContinueRequested => {}
1339 other => panic!("unexpected follow-up event: {other:?}"),
1340 }
1341
1342 let updates = manager.take_pending_loop_updates().await.unwrap();
1343 assert_eq!(updates.resolutions.len(), 1);
1344 assert!(handle.drain_ready_items().await.is_empty());
1345 }
1346
1347 #[tokio::test]
1348 async fn wait_for_idle_returns_after_loop_updates_are_queued() {
1349 let release = StdArc::new(Notify::new());
1350 let entered = StdArc::new(AtomicBool::new(false));
1351 let executor: Arc<dyn ToolExecutor> = Arc::new(TestExecutor::new([(
1352 "background",
1353 TestBehavior::Block {
1354 entered: entered.clone(),
1355 release: release.clone(),
1356 output: "idle-done",
1357 },
1358 )]));
1359 let manager = AsyncTaskManager::new().routing(NameRoutingPolicy::new([(
1360 "background",
1361 RoutingDecision::Background,
1362 )]));
1363 let handle = manager.handle();
1364 let request = make_request("background", "turn-1", "call-1");
1365
1366 let outcome = manager
1367 .start_task(
1368 TaskLaunchRequest {
1369 task_id: None,
1370 request: request.clone(),
1371 approved_request: None,
1372 },
1373 make_context(executor, &request.turn_id, None),
1374 )
1375 .await
1376 .unwrap();
1377 assert!(matches!(outcome, TaskStartOutcome::Pending { .. }));
1378
1379 let _ = next_event(&handle).await;
1380 wait_until_entered(entered.as_ref()).await;
1381 release.notify_waiters();
1382
1383 timeout(Duration::from_secs(1), handle.wait_for_idle())
1384 .await
1385 .expect("wait_for_idle timed out");
1386
1387 let updates = manager.take_pending_loop_updates().await.unwrap();
1388 assert_eq!(updates.resolutions.len(), 1);
1389 match &updates.resolutions[0] {
1390 TaskResolution::Item(item) => match &item.parts[0] {
1391 Part::ToolResult(result) => {
1392 assert_eq!(result.call_id, request.call_id);
1393 assert_eq!(result.output, ToolOutput::Text("idle-done".into()));
1394 }
1395 other => panic!("unexpected tool item: {other:?}"),
1396 },
1397 other => panic!("unexpected pending update: {other:?}"),
1398 }
1399 }
1400}