1pub mod types;
33pub use types::*;
34
35use crate::agent::backend::LlmBackend;
36use crate::agent::{Message, Role, TokenUsage, ToolCallRecord, ToolCallRequest, ToolResultMessage};
37use crate::tools::ToolRegistry;
38use async_trait::async_trait;
39use futures::future::join_all;
40use serde_json::Value;
41use std::collections::HashSet;
42use std::sync::Arc;
43use std::time::Instant;
44use tokio::time::timeout;
45
46fn to_backend_message(msg: &ConversationMessage) -> Message {
61 let tool_result = if msg.role == Role::Tool {
62 msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
63 tool_call_id: id.clone(),
64 content: serde_json::from_str(&msg.content)
65 .unwrap_or(serde_json::Value::String(msg.content.clone())),
66 success: true,
67 })
68 } else {
69 None
70 };
71
72 Message {
73 role: msg.role.clone(),
74 content: msg.content.clone(),
75 tool_calls: msg.tool_calls.clone(),
76 tool_result,
77 }
78}
79
80pub struct ToolCoordinator {
105 backend: Arc<dyn LlmBackend>,
106 registry: Arc<ToolRegistry>,
107 config: ToolCallingConfig,
108}
109
110impl ToolCoordinator {
111 pub fn new(
113 backend: Arc<dyn LlmBackend>,
114 registry: Arc<ToolRegistry>,
115 config: ToolCallingConfig,
116 ) -> Self {
117 Self {
118 backend,
119 registry,
120 config,
121 }
122 }
123
124 pub async fn execute(
128 &self,
129 system_prompt: Option<&str>,
130 user_prompt: &str,
131 ) -> crate::Result<CoordinatorResult> {
132 let mut messages: Vec<ConversationMessage> = Vec::new();
133 if let Some(sys) = system_prompt {
134 messages.push(ConversationMessage::system(sys));
135 }
136 messages.push(ConversationMessage::user(user_prompt));
137 self.execute_with_history(messages).await
138 }
139
140 pub async fn execute_with_history(
146 &self,
147 mut messages: Vec<ConversationMessage>,
148 ) -> crate::Result<CoordinatorResult> {
149 let tool_defs = self.registry.get_definitions();
150 let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
151 let mut total_usage = TokenUsage::default();
152
153 for iteration in 0..self.config.max_iterations {
154 let backend_messages: Vec<Message> = messages.iter().map(to_backend_message).collect();
156
157 let response = self
159 .backend
160 .generate(&backend_messages, &tool_defs, None)
161 .await?;
162
163 if let Some(usage) = &response.usage {
165 total_usage.prompt_tokens += usage.prompt_tokens;
166 total_usage.completion_tokens += usage.completion_tokens;
167 total_usage.total_tokens += usage.total_tokens;
168 total_usage.reasoning_tokens += usage.reasoning_tokens;
169 total_usage.action_tokens += usage.action_tokens;
170 }
171
172 messages.push(ConversationMessage::assistant(
174 &response.content,
175 response.tool_calls.clone(),
176 ));
177
178 if response.tool_calls.is_empty() {
180 return Ok(CoordinatorResult {
181 content: response.content,
182 tool_calls: all_tool_calls,
183 iterations: iteration + 1,
184 finish_reason: FinishReason::Stop,
185 total_usage,
186 message_history: messages,
187 });
188 }
189
190 if response.content.is_empty() && response.tool_calls.is_empty() {
192 return Ok(CoordinatorResult {
193 content: String::new(),
194 tool_calls: all_tool_calls,
195 iterations: iteration + 1,
196 finish_reason: FinishReason::Stop,
197 total_usage,
198 message_history: messages,
199 });
200 }
201
202 for tc in &response.tool_calls {
204 if !self.registry.has_tool(&tc.name) {
205 return Ok(CoordinatorResult {
206 content: response.content,
207 tool_calls: all_tool_calls,
208 iterations: iteration + 1,
209 finish_reason: FinishReason::UnknownTool(tc.name.clone()),
210 total_usage,
211 message_history: messages,
212 });
213 }
214 }
215
216 let records = self.execute_tool_calls(&response.tool_calls).await?;
218
219 if self.config.stop_on_error {
221 if let Some(failed) = records.iter().find(|r| !r.success) {
222 let err_msg = failed
223 .result
224 .get("error")
225 .and_then(|v| v.as_str())
226 .unwrap_or("tool error")
227 .to_string();
228 return Ok(CoordinatorResult {
229 content: response.content,
230 tool_calls: all_tool_calls,
231 iterations: iteration + 1,
232 finish_reason: FinishReason::Error(err_msg),
233 total_usage,
234 message_history: messages,
235 });
236 }
237 }
238
239 for record in records {
241 messages.push(ConversationMessage::tool_result(&record.id, &record.result));
242 all_tool_calls.push(record);
243 }
244 }
245
246 Ok(CoordinatorResult {
248 content: messages
249 .last()
250 .map(|m| m.content.clone())
251 .unwrap_or_default(),
252 tool_calls: all_tool_calls,
253 iterations: self.config.max_iterations,
254 finish_reason: FinishReason::MaxIterations,
255 total_usage,
256 message_history: messages,
257 })
258 }
259
260 async fn execute_tool_calls(
265 &self,
266 calls: &[ToolCallRequest],
267 ) -> crate::Result<Vec<ToolCallRecord>> {
268 if self.config.parallel_execution {
269 self.execute_parallel(calls).await
270 } else {
271 self.execute_sequential(calls).await
272 }
273 }
274
275 async fn execute_parallel(
276 &self,
277 calls: &[ToolCallRequest],
278 ) -> crate::Result<Vec<ToolCallRecord>> {
279 let futures = calls.iter().map(|c| self.execute_single_tool(c));
280 let results = join_all(futures).await;
281
282 let mut records = Vec::with_capacity(results.len());
283 for (i, res) in results.into_iter().enumerate() {
284 match res {
285 Ok(record) => records.push(record),
286 Err(e) if self.config.stop_on_error => return Err(e),
287 Err(e) => {
288 let call = &calls[i];
290 records.push(ToolCallRecord {
291 id: call.id.clone(),
292 name: call.name.clone(),
293 arguments: call.arguments.clone(),
294 result: serde_json::json!({"error": e.to_string()}),
295 success: false,
296 duration_ms: 0,
297 });
298 }
299 }
300 }
301 Ok(records)
302 }
303
304 async fn execute_sequential(
305 &self,
306 calls: &[ToolCallRequest],
307 ) -> crate::Result<Vec<ToolCallRecord>> {
308 let mut records = Vec::with_capacity(calls.len());
309 for call in calls {
310 match self.execute_single_tool(call).await {
311 Ok(record) => records.push(record),
312 Err(e) if self.config.stop_on_error => return Err(e),
313 Err(e) => {
314 records.push(ToolCallRecord {
315 id: call.id.clone(),
316 name: call.name.clone(),
317 arguments: call.arguments.clone(),
318 result: serde_json::json!({"error": e.to_string()}),
319 success: false,
320 duration_ms: 0,
321 });
322 }
323 }
324 }
325 Ok(records)
326 }
327
328 async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
329 let start = Instant::now();
330
331 let result = timeout(
332 self.config.tool_timeout,
333 self.registry.execute(&call.name, call.arguments.clone()),
334 )
335 .await;
336
337 let duration_ms = start.elapsed().as_millis() as u64;
338
339 match result {
340 Ok(Ok(value)) => Ok(ToolCallRecord {
341 id: call.id.clone(),
342 name: call.name.clone(),
343 arguments: call.arguments.clone(),
344 result: value,
345 success: true,
346 duration_ms,
347 }),
348 Ok(Err(e)) => Ok(ToolCallRecord {
349 id: call.id.clone(),
350 name: call.name.clone(),
351 arguments: call.arguments.clone(),
352 result: serde_json::json!({"error": e.to_string()}),
353 success: false,
354 duration_ms,
355 }),
356 Err(_elapsed) => Ok(ToolCallRecord {
357 id: call.id.clone(),
358 name: call.name.clone(),
359 arguments: call.arguments.clone(),
360 result: serde_json::json!({"error": "tool execution timed out"}),
361 success: false,
362 duration_ms,
363 }),
364 }
365 }
366}
367
/// Agent types a [`ScheduledTask`] may name; any other `agent_type` is
/// rejected by `validate_task_schedule`.
///
/// The order here is the order in which valid types are listed in the
/// `ScheduleError::InvalidAgentType` display message.
pub const KNOWN_AGENT_TYPES: &[&str] =
    &["explore", "plan", "task", "reviewer", "designer", "librarian"];
381
/// One unit of work to hand to an agent.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScheduledTask {
    /// Identifier that must be unique within a schedule.
    pub id: String,
    /// Agent type to run the task; expected to be one of `KNOWN_AGENT_TYPES`.
    pub agent_type: String,
    /// Free-text instructions for the agent; must not be blank.
    pub assignment: String,
}
389
390#[derive(Debug, Clone, PartialEq, Eq)]
392pub enum ScheduleError {
393 EmptyTaskList,
394 InvalidAgentType(String),
395 DuplicateTaskId(String),
396}
397
398impl std::fmt::Display for ScheduleError {
399 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 match self {
401 ScheduleError::EmptyTaskList => write!(f, "task list must not be empty"),
402 ScheduleError::InvalidAgentType(agent) => write!(
403 f,
404 "unknown agent type '{agent}'. Valid types: {}",
405 KNOWN_AGENT_TYPES.join(", ")
406 ),
407 ScheduleError::DuplicateTaskId(id) => write!(f, "duplicate task id '{id}'"),
408 }
409 }
410}
411
412pub fn validate_task_schedule(tasks: &[ScheduledTask]) -> Result<(), ScheduleError> {
415 if tasks.is_empty() {
416 return Err(ScheduleError::EmptyTaskList);
417 }
418
419 let mut seen_ids = HashSet::with_capacity(tasks.len());
420 for task in tasks {
421 if !KNOWN_AGENT_TYPES.contains(&task.agent_type.as_str()) {
422 return Err(ScheduleError::InvalidAgentType(task.agent_type.clone()));
423 }
424 if !seen_ids.insert(task.id.clone()) {
425 return Err(ScheduleError::DuplicateTaskId(task.id.clone()));
426 }
427 if task.assignment.trim().is_empty() {
428 return Err(ScheduleError::InvalidAgentType(
429 "assignment must be non-empty".into(),
430 ));
431 }
432 }
433 Ok(())
434}
435
436#[derive(Debug, Clone, PartialEq, Eq)]
438pub struct ScheduledTaskResult {
439 pub id: String,
440 pub output: Value,
441}
442
443#[async_trait]
445pub trait TaskRunner: Send + Sync {
446 async fn run(&self, task: &ScheduledTask) -> crate::Result<Value>;
447}
448
449pub struct TaskScheduleCoordinator<R> {
451 runner: Arc<R>,
452}
453
454impl<R: TaskRunner> TaskScheduleCoordinator<R> {
455 pub fn new(runner: Arc<R>) -> Self {
456 Self { runner }
457 }
458
459 pub async fn schedule(
461 &self,
462 tasks: &[ScheduledTask],
463 ) -> Result<Vec<ScheduledTaskResult>, ScheduleError> {
464 validate_task_schedule(tasks)?;
465
466 let mut results = Vec::with_capacity(tasks.len());
467 for task in tasks {
468 let output = self.runner.run(task).await.map_err(|e| {
469 ScheduleError::InvalidAgentType(format!("task '{}' failed: {e}", task.id))
470 })?;
471 results.push(ScheduledTaskResult {
472 id: task.id.clone(),
473 output,
474 });
475 }
476 Ok(results)
477 }
478}
479
// Unit tests for `ToolCoordinator` (driven by the crate's mock LLM backend) and
// for `TaskScheduleCoordinator` (driven by an in-file `MockTaskRunner`).
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// With no tools registered, a plain text reply ends the loop after one
    /// backend turn.
    #[tokio::test]
    async fn execute_with_empty_registry_returns_model_response() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_text("Hello, world!"));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "Say hello")
            .await
            .expect("coordinator should not error");

        assert_eq!(result.content, "Hello, world!");
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 1);
        assert!(result.tool_calls.is_empty());
        // History = user prompt + assistant reply (no system prompt was given).
        assert_eq!(result.message_history.len(), 2);
    }

    /// Pins the `ToolCallingConfig::default()` values so changes are deliberate.
    #[test]
    fn tool_calling_config_defaults_are_sensible() {
        use std::time::Duration;
        let cfg = ToolCallingConfig::default();
        assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
        assert!(
            cfg.parallel_execution,
            "parallel_execution should default to true"
        );
        assert_eq!(
            cfg.tool_timeout,
            Duration::from_secs(30),
            "tool_timeout default changed"
        );
        assert!(!cfg.stop_on_error, "stop_on_error should default to false");
    }

    /// A backend that requests a tool on every turn must be cut off at
    /// `max_iterations`, with one successful record per executed turn.
    #[tokio::test]
    async fn coordinator_result_captures_finish_reason_max_iterations() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        // Minimal tool that always succeeds.
        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }
            fn description(&self) -> &str {
                "does nothing"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        // Script more tool-call replies (15) than the iteration cap (3) so the
        // coordinator, not the mock, is what ends the loop.
        let responses: Vec<MockResponse> = (0..15)
            .map(|_| MockResponse::tool_call("noop", serde_json::json!({})))
            .collect();
        let backend = Arc::new(MockBackend::new(responses));

        let mut registry = ToolRegistry::new();
        registry.register(std::sync::Arc::new(NoOpTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            max_iterations: 3,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "loop forever")
            .await
            .expect("coordinator should not hard-error");

        assert_eq!(
            result.finish_reason,
            FinishReason::MaxIterations,
            "expected MaxIterations, got {:?}",
            result.finish_reason
        );
        assert_eq!(result.iterations, 3);
        // One `noop` call executed per iteration.
        assert_eq!(result.tool_calls.len(), 3);
        assert!(result.tool_calls.iter().all(|tc| tc.success));
    }

    /// A call to an unregistered tool must surface as
    /// `FinishReason::UnknownTool` on the first turn, without executing anything.
    #[tokio::test]
    async fn test_unknown_tool_validation_returns_unknown_tool_finish_reason() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_tool_call(
            "call_ghost",
            "definitely_not_registered",
            serde_json::json!({}),
            "should not reach this",
        ));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "use a ghost tool")
            .await
            .expect("unknown tool should surface as a coordinator result, not a hard error");

        assert_eq!(
            result.finish_reason,
            FinishReason::UnknownTool("definitely_not_registered".into())
        );
        assert!(
            result.tool_calls.is_empty(),
            "unknown tool must not be executed"
        );
        assert_eq!(result.iterations, 1);
    }

    /// With `stop_on_error`, a failing tool ends the run with
    /// `FinishReason::Error` carrying the tool's error text, after one turn.
    #[tokio::test]
    async fn test_stop_on_error_halts_on_failed_tool_execution() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        // Tool whose execution always returns an error.
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn name(&self) -> &str {
                "fail_me"
            }
            fn description(&self) -> &str {
                "always fails"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Err(crate::PawanError::Tool("intentional failure".into()))
            }
        }

        // The text reply would only be consumed if the loop kept going.
        let backend = Arc::new(MockBackend::new(vec![
            MockResponse::tool_call("fail_me", serde_json::json!({})),
            MockResponse::text("unreachable"),
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailingTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            stop_on_error: true,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "trigger failure")
            .await
            .expect("stop_on_error should return Ok with Error finish reason");

        match &result.finish_reason {
            FinishReason::Error(msg) => {
                assert!(
                    msg.contains("intentional failure"),
                    "error message should propagate from tool, got: {}",
                    msg
                );
            }
            other => panic!("expected FinishReason::Error, got {:?}", other),
        }
        assert_eq!(result.iterations, 1);
    }

    /// A tool exceeding `tool_timeout` is recorded as a failed call and the
    /// loop continues to the next backend turn.
    #[tokio::test]
    async fn test_tool_timeout_records_failed_tool_call() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;
        use std::time::Duration;

        // Sleeps 2 s, far longer than the 50 ms timeout configured below; the
        // timeout cancels it, so the test does not actually wait 2 s.
        struct SlowTool;

        #[async_trait]
        impl Tool for SlowTool {
            fn name(&self) -> &str {
                "slow_tool"
            }
            fn description(&self) -> &str {
                "sleeps longer than the coordinator timeout"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                tokio::time::sleep(Duration::from_secs(2)).await;
                Ok(serde_json::json!({"ok": true}))
            }
        }

        let backend = Arc::new(MockBackend::new(vec![
            MockResponse::tool_call("slow_tool", serde_json::json!({})),
            MockResponse::text("done after timeout"),
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            tool_timeout: Duration::from_millis(50),
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "run slow tool")
            .await
            .expect("timeout should be absorbed into a failed tool record");

        assert_eq!(result.tool_calls.len(), 1);
        let record = &result.tool_calls[0];
        assert!(
            !record.success,
            "timed-out tool must be marked unsuccessful"
        );
        assert_eq!(
            record.result.get("error").and_then(|v| v.as_str()),
            Some("tool execution timed out")
        );
        assert_eq!(result.finish_reason, FinishReason::Stop);
        // Turn 1: tool call (times out); turn 2: final text reply.
        assert_eq!(result.iterations, 2);
    }

    /// `execute(Some(system), user)` builds history as
    /// [system, user, assistant].
    #[tokio::test]
    async fn test_execute_with_system_prompt_prepends_system_message() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_text("acknowledged"));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(Some("be concise"), "hello")
            .await
            .expect("execute should succeed");

        assert_eq!(result.message_history.len(), 3);
        assert_eq!(result.message_history[0].role, Role::System);
        assert_eq!(result.message_history[0].content, "be concise");
        assert_eq!(result.message_history[1].role, Role::User);
        assert_eq!(result.message_history[1].content, "hello");
        assert_eq!(result.message_history[2].role, Role::Assistant);
    }

    /// Token usage reported by backend responses is accumulated into
    /// `total_usage`, field by field.
    #[tokio::test]
    async fn test_token_usage_captured_from_backend_response() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }
            fn description(&self) -> &str {
                "does nothing"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        // Only the final reply carries explicit usage; the totals asserted below
        // equal it exactly, which presumes `MockResponse::tool_call` reports no
        // usage (TODO confirm against the mock).
        let backend = Arc::new(MockBackend::new(vec![
            MockResponse::tool_call("noop", serde_json::json!({})),
            MockResponse::TextWithUsage {
                text: "done".into(),
                usage: TokenUsage {
                    prompt_tokens: 20,
                    completion_tokens: 8,
                    total_tokens: 28,
                    reasoning_tokens: 3,
                    action_tokens: 5,
                },
            },
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(NoOpTool));
        let registry = Arc::new(registry);
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "count tokens")
            .await
            .expect("execute should succeed");

        assert_eq!(result.total_usage.prompt_tokens, 20);
        assert_eq!(result.total_usage.completion_tokens, 8);
        assert_eq!(result.total_usage.total_tokens, 28);
        assert_eq!(result.total_usage.reasoning_tokens, 3);
        assert_eq!(result.total_usage.action_tokens, 5);
        assert_eq!(result.iterations, 2);
    }

    /// With `parallel_execution`, every tool requested in one turn is executed
    /// and recorded.
    #[tokio::test]
    async fn test_parallel_execution_dispatches_multiple_tools_in_one_turn() {
        use crate::agent::backend::mock::MockBackend;
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        // Tool whose name is `suffix` and whose output echoes that name.
        struct EchoTool {
            suffix: &'static str,
        }

        #[async_trait]
        impl Tool for EchoTool {
            fn name(&self) -> &str {
                self.suffix
            }
            fn description(&self) -> &str {
                "echoes a suffix"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({ "tool": self.suffix }))
            }
        }

        // The mock requests both tools in a single turn. It presumably follows
        // with a plain reply, given `iterations == 2` below.
        let backend = Arc::new(MockBackend::with_multiple_tool_calls(vec![
            ("call_a", "echo_a", serde_json::json!({})),
            ("call_b", "echo_b", serde_json::json!({})),
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool { suffix: "echo_a" }));
        registry.register(Arc::new(EchoTool { suffix: "echo_b" }));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            parallel_execution: true,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "run both")
            .await
            .expect("parallel tool execution should succeed");

        assert_eq!(result.tool_calls.len(), 2);
        assert!(result.tool_calls.iter().all(|r| r.success));
        let names: Vec<&str> = result.tool_calls.iter().map(|r| r.name.as_str()).collect();
        assert!(names.contains(&"echo_a"));
        assert!(names.contains(&"echo_b"));
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 2);
    }

    // ---- TaskScheduleCoordinator tests ----
    // `MockTaskRunner` records the id of every task it is asked to run, so the
    // tests can check both dispatch order and that rejected schedules dispatch
    // nothing.
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;

    struct MockTaskRunner {
        // Ids of tasks passed to `run`, in call order.
        dispatched: Mutex<Vec<String>>,
    }

    impl MockTaskRunner {
        fn new() -> Self {
            Self {
                dispatched: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of the task ids dispatched so far.
        fn dispatched_ids(&self) -> Vec<String> {
            self.dispatched.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl TaskRunner for MockTaskRunner {
        // Records the dispatch and echoes the task's fields back as JSON. The
        // lock guard is dropped before the function returns; it never spans an
        // `.await`.
        async fn run(&self, task: &ScheduledTask) -> crate::Result<Value> {
            self.dispatched.lock().unwrap().push(task.id.clone());
            Ok(json!({
                "id": task.id,
                "agent": task.agent_type,
                "assignment": task.assignment,
            }))
        }
    }

    /// An empty schedule fails validation and nothing is dispatched.
    #[tokio::test]
    async fn schedule_empty_task_list_rejects_without_dispatch() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let err = coordinator
            .schedule(&[])
            .await
            .expect_err("empty task list should fail validation");

        assert_eq!(err, ScheduleError::EmptyTaskList);
        assert!(runner.dispatched_ids().is_empty());
    }

    /// An agent type outside `KNOWN_AGENT_TYPES` is rejected before dispatch.
    #[tokio::test]
    async fn schedule_invalid_agent_type_rejects_without_dispatch() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let tasks = [ScheduledTask {
            id: "AuthProbe".into(),
            agent_type: "not_a_real_agent".into(),
            assignment: "probe auth".into(),
        }];

        let err = coordinator
            .schedule(&tasks)
            .await
            .expect_err("invalid agent type should fail validation");

        assert_eq!(
            err,
            ScheduleError::InvalidAgentType("not_a_real_agent".into())
        );
        assert!(runner.dispatched_ids().is_empty());
    }

    /// A repeated task id is rejected before any task, including the first
    /// occurrence, is dispatched.
    #[tokio::test]
    async fn schedule_duplicate_task_ids_rejects_without_dispatch() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let tasks = [
            ScheduledTask {
                id: "DupId".into(),
                agent_type: "explore".into(),
                assignment: "first".into(),
            },
            ScheduledTask {
                id: "DupId".into(),
                agent_type: "plan".into(),
                assignment: "second".into(),
            },
        ];

        let err = coordinator
            .schedule(&tasks)
            .await
            .expect_err("duplicate ids should fail validation");

        assert_eq!(err, ScheduleError::DuplicateTaskId("DupId".into()));
        assert!(runner.dispatched_ids().is_empty());
    }

    /// A valid schedule dispatches every task in order and returns one result
    /// per task, in the same order.
    #[tokio::test]
    async fn schedule_valid_tasks_dispatches_via_mock_runner() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let tasks = [
            ScheduledTask {
                id: "Alpha".into(),
                agent_type: "explore".into(),
                assignment: "scan src/".into(),
            },
            ScheduledTask {
                id: "Beta".into(),
                agent_type: "plan".into(),
                assignment: "draft refactor".into(),
            },
        ];

        let results = coordinator
            .schedule(&tasks)
            .await
            .expect("valid schedule should succeed");

        assert_eq!(runner.dispatched_ids(), vec!["Alpha", "Beta"]);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "Alpha");
        assert_eq!(results[0].output["agent"], "explore");
        assert_eq!(results[1].id, "Beta");
        assert_eq!(results[1].output["agent"], "plan");
    }
}