1pub mod types;
33pub use types::*;
34
35use crate::agent::backend::LlmBackend;
36use crate::agent::{Message, Role, TokenUsage, ToolCallRecord, ToolCallRequest, ToolResultMessage};
37use crate::tools::ToolRegistry;
38use async_trait::async_trait;
39use futures::future::join_all;
40use serde_json::Value;
41use std::collections::HashSet;
42use std::sync::Arc;
43use std::time::Instant;
44use tokio::time::timeout;
45
46fn to_backend_message(msg: &ConversationMessage) -> Message {
61 let tool_result = if msg.role == Role::Tool {
62 msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
63 tool_call_id: id.clone(),
64 content: serde_json::from_str(&msg.content)
65 .unwrap_or(serde_json::Value::String(msg.content.clone())),
66 success: true,
67 })
68 } else {
69 None
70 };
71
72 Message {
73 role: msg.role.clone(),
74 content: msg.content.clone(),
75 tool_calls: msg.tool_calls.clone(),
76 tool_result,
77 }
78}
79
80pub struct ToolCoordinator {
105 backend: Arc<dyn LlmBackend>,
106 registry: Arc<ToolRegistry>,
107 config: ToolCallingConfig,
108}
109
110impl ToolCoordinator {
111 pub fn new(
113 backend: Arc<dyn LlmBackend>,
114 registry: Arc<ToolRegistry>,
115 config: ToolCallingConfig,
116 ) -> Self {
117 Self {
118 backend,
119 registry,
120 config,
121 }
122 }
123
124 pub async fn execute(
128 &self,
129 system_prompt: Option<&str>,
130 user_prompt: &str,
131 ) -> crate::Result<CoordinatorResult> {
132 let mut messages: Vec<ConversationMessage> = Vec::new();
133 if let Some(sys) = system_prompt {
134 messages.push(ConversationMessage::system(sys));
135 }
136 messages.push(ConversationMessage::user(user_prompt));
137 self.execute_with_history(messages).await
138 }
139
140 pub async fn execute_with_history(
146 &self,
147 mut messages: Vec<ConversationMessage>,
148 ) -> crate::Result<CoordinatorResult> {
149 let tool_defs = self.registry.get_definitions();
150 let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
151 let mut total_usage = TokenUsage::default();
152
153 for iteration in 0..self.config.max_iterations {
154 let backend_messages: Vec<Message> = messages.iter().map(to_backend_message).collect();
156
157 let response = self
159 .backend
160 .generate(&backend_messages, &tool_defs, None)
161 .await?;
162
163 if let Some(usage) = &response.usage {
165 total_usage.prompt_tokens += usage.prompt_tokens;
166 total_usage.completion_tokens += usage.completion_tokens;
167 total_usage.total_tokens += usage.total_tokens;
168 total_usage.reasoning_tokens += usage.reasoning_tokens;
169 total_usage.action_tokens += usage.action_tokens;
170 }
171
172 messages.push(ConversationMessage::assistant(
174 &response.content,
175 response.tool_calls.clone(),
176 ));
177
178 if response.tool_calls.is_empty() {
180 return Ok(CoordinatorResult {
181 content: response.content,
182 tool_calls: all_tool_calls,
183 iterations: iteration + 1,
184 finish_reason: FinishReason::Stop,
185 total_usage,
186 message_history: messages,
187 });
188 }
189
190 if response.content.is_empty() && response.tool_calls.is_empty() {
192 return Ok(CoordinatorResult {
193 content: String::new(),
194 tool_calls: all_tool_calls,
195 iterations: iteration + 1,
196 finish_reason: FinishReason::Stop,
197 total_usage,
198 message_history: messages,
199 });
200 }
201
202 for tc in &response.tool_calls {
204 if !self.registry.has_tool(&tc.name) {
205 return Ok(CoordinatorResult {
206 content: response.content,
207 tool_calls: all_tool_calls,
208 iterations: iteration + 1,
209 finish_reason: FinishReason::UnknownTool(tc.name.clone()),
210 total_usage,
211 message_history: messages,
212 });
213 }
214 }
215
216 let records = self.execute_tool_calls(&response.tool_calls).await?;
218
219 if self.config.stop_on_error {
221 if let Some(failed) = records.iter().find(|r| !r.success) {
222 let err_msg = failed
223 .result
224 .get("error")
225 .and_then(|v| v.as_str())
226 .unwrap_or("tool error")
227 .to_string();
228 return Ok(CoordinatorResult {
229 content: response.content,
230 tool_calls: all_tool_calls,
231 iterations: iteration + 1,
232 finish_reason: FinishReason::Error(err_msg),
233 total_usage,
234 message_history: messages,
235 });
236 }
237 }
238
239 for record in records {
241 messages.push(ConversationMessage::tool_result(&record.id, &record.result));
242 all_tool_calls.push(record);
243 }
244 }
245
246 Ok(CoordinatorResult {
248 content: messages
249 .last()
250 .map(|m| m.content.clone())
251 .unwrap_or_default(),
252 tool_calls: all_tool_calls,
253 iterations: self.config.max_iterations,
254 finish_reason: FinishReason::MaxIterations,
255 total_usage,
256 message_history: messages,
257 })
258 }
259
260 async fn execute_tool_calls(
265 &self,
266 calls: &[ToolCallRequest],
267 ) -> crate::Result<Vec<ToolCallRecord>> {
268 if self.config.parallel_execution {
269 self.execute_parallel(calls).await
270 } else {
271 self.execute_sequential(calls).await
272 }
273 }
274
275 async fn execute_parallel(
276 &self,
277 calls: &[ToolCallRequest],
278 ) -> crate::Result<Vec<ToolCallRecord>> {
279 let futures = calls.iter().map(|c| self.execute_single_tool(c));
280 let results = join_all(futures).await;
281
282 let mut records = Vec::with_capacity(results.len());
283 for (i, res) in results.into_iter().enumerate() {
284 match res {
285 Ok(record) => records.push(record),
286 Err(e) if self.config.stop_on_error => return Err(e),
287 Err(e) => {
288 let call = &calls[i];
290 records.push(ToolCallRecord {
291 id: call.id.clone(),
292 name: call.name.clone(),
293 arguments: call.arguments.clone(),
294 result: serde_json::json!({"error": e.to_string()}),
295 success: false,
296 duration_ms: 0,
297 });
298 }
299 }
300 }
301 Ok(records)
302 }
303
304 async fn execute_sequential(
305 &self,
306 calls: &[ToolCallRequest],
307 ) -> crate::Result<Vec<ToolCallRecord>> {
308 let mut records = Vec::with_capacity(calls.len());
309 for call in calls {
310 match self.execute_single_tool(call).await {
311 Ok(record) => records.push(record),
312 Err(e) if self.config.stop_on_error => return Err(e),
313 Err(e) => {
314 records.push(ToolCallRecord {
315 id: call.id.clone(),
316 name: call.name.clone(),
317 arguments: call.arguments.clone(),
318 result: serde_json::json!({"error": e.to_string()}),
319 success: false,
320 duration_ms: 0,
321 });
322 }
323 }
324 }
325 Ok(records)
326 }
327
328 async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
329 let start = Instant::now();
330
331 let result = timeout(
332 self.config.tool_timeout,
333 self.registry.execute(&call.name, call.arguments.clone()),
334 )
335 .await;
336
337 let duration_ms = start.elapsed().as_millis() as u64;
338
339 match result {
340 Ok(Ok(value)) => Ok(ToolCallRecord {
341 id: call.id.clone(),
342 name: call.name.clone(),
343 arguments: call.arguments.clone(),
344 result: value,
345 success: true,
346 duration_ms,
347 }),
348 Ok(Err(e)) => Ok(ToolCallRecord {
349 id: call.id.clone(),
350 name: call.name.clone(),
351 arguments: call.arguments.clone(),
352 result: serde_json::json!({"error": e.to_string()}),
353 success: false,
354 duration_ms,
355 }),
356 Err(_elapsed) => Ok(ToolCallRecord {
357 id: call.id.clone(),
358 name: call.name.clone(),
359 arguments: call.arguments.clone(),
360 result: serde_json::json!({"error": "tool execution timed out"}),
361 success: false,
362 duration_ms,
363 }),
364 }
365 }
366}
367
/// Agent types that [`validate_task_schedule`] accepts for
/// [`ScheduledTask::agent_type`].
///
/// The same list is printed in the `ScheduleError::InvalidAgentType` message.
pub const KNOWN_AGENT_TYPES: &[&str] = &[
    "explore",
    "plan",
    "task",
    "reviewer",
    "designer",
    "librarian",
];
381
/// One unit of work in a schedule, dispatched through a [`TaskRunner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTask {
    /// Task identifier; must be unique within a schedule.
    pub id: String,
    /// Agent kind to run; must be one of [`KNOWN_AGENT_TYPES`].
    pub agent_type: String,
    /// Instructions for the agent; must not be empty or whitespace-only.
    pub assignment: String,
}
389
390#[derive(Debug, Clone, PartialEq, Eq)]
392pub enum ScheduleError {
393 EmptyTaskList,
394 InvalidAgentType(String),
395 DuplicateTaskId(String),
396}
397
398impl std::fmt::Display for ScheduleError {
399 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 match self {
401 ScheduleError::EmptyTaskList => write!(f, "task list must not be empty"),
402 ScheduleError::InvalidAgentType(agent) => write!(
403 f,
404 "unknown agent type '{agent}'. Valid types: {}",
405 KNOWN_AGENT_TYPES.join(", ")
406 ),
407 ScheduleError::DuplicateTaskId(id) => write!(f, "duplicate task id '{id}'"),
408 }
409 }
410}
411
412pub fn validate_task_schedule(tasks: &[ScheduledTask]) -> Result<(), ScheduleError> {
415 if tasks.is_empty() {
416 return Err(ScheduleError::EmptyTaskList);
417 }
418
419 let mut seen_ids = HashSet::with_capacity(tasks.len());
420 for task in tasks {
421 if !KNOWN_AGENT_TYPES.contains(&task.agent_type.as_str()) {
422 return Err(ScheduleError::InvalidAgentType(task.agent_type.clone()));
423 }
424 if !seen_ids.insert(task.id.clone()) {
425 return Err(ScheduleError::DuplicateTaskId(task.id.clone()));
426 }
427 if task.assignment.trim().is_empty() {
428 return Err(ScheduleError::InvalidAgentType(
429 "assignment must be non-empty".into(),
430 ));
431 }
432 }
433 Ok(())
434}
435
436#[derive(Debug, Clone, PartialEq, Eq)]
438pub struct ScheduledTaskResult {
439 pub id: String,
440 pub output: Value,
441}
442
443#[async_trait]
445pub trait TaskRunner: Send + Sync {
446 async fn run(&self, task: &ScheduledTask) -> crate::Result<Value>;
447}
448
449pub struct TaskScheduleCoordinator<R> {
451 runner: Arc<R>,
452}
453
454impl<R: TaskRunner> TaskScheduleCoordinator<R> {
455 pub fn new(runner: Arc<R>) -> Self {
456 Self { runner }
457 }
458
459 pub async fn schedule(
461 &self,
462 tasks: &[ScheduledTask],
463 ) -> Result<Vec<ScheduledTaskResult>, ScheduleError> {
464 validate_task_schedule(tasks)?;
465
466 let mut results = Vec::with_capacity(tasks.len());
467 for task in tasks {
468 let output = self.runner.run(task).await.map_err(|e| {
469 ScheduleError::InvalidAgentType(format!("task '{}' failed: {e}", task.id))
470 })?;
471 results.push(ScheduledTaskResult {
472 id: task.id.clone(),
473 output,
474 });
475 }
476 Ok(results)
477 }
478}
479
480
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // Plain text reply with no tools registered: one turn, `Stop`, and a
    // history of [user, assistant].
    #[tokio::test]
    async fn execute_with_empty_registry_returns_model_response() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_text("Hello, world!"));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "Say hello")
            .await
            .expect("coordinator should not error");

        assert_eq!(result.content, "Hello, world!");
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 1);
        assert!(result.tool_calls.is_empty());
        assert_eq!(result.message_history.len(), 2);
    }

    // Pins the default ToolCallingConfig values.
    #[test]
    fn tool_calling_config_defaults_are_sensible() {
        use std::time::Duration;
        let cfg = ToolCallingConfig::default();
        assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
        assert!(
            cfg.parallel_execution,
            "parallel_execution should default to true"
        );
        assert_eq!(
            cfg.tool_timeout,
            Duration::from_secs(30),
            "tool_timeout default changed"
        );
        assert!(!cfg.stop_on_error, "stop_on_error should default to false");
    }

    // The mock is primed with more tool-call replies (15) than the iteration
    // cap (3), so the loop can only end by exhausting `max_iterations`.
    #[tokio::test]
    async fn coordinator_result_captures_finish_reason_max_iterations() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }
            fn description(&self) -> &str {
                "does nothing"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        let responses: Vec<MockResponse> = (0..15)
            .map(|_| MockResponse::tool_call("noop", serde_json::json!({})))
            .collect();
        let backend = Arc::new(MockBackend::new(responses));

        let mut registry = ToolRegistry::new();
        registry.register(std::sync::Arc::new(NoOpTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            max_iterations: 3,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "loop forever")
            .await
            .expect("coordinator should not hard-error");

        assert_eq!(
            result.finish_reason,
            FinishReason::MaxIterations,
            "expected MaxIterations, got {:?}",
            result.finish_reason
        );
        assert_eq!(result.iterations, 3);
        // One successful noop call per iteration.
        assert_eq!(result.tool_calls.len(), 3);
        assert!(result.tool_calls.iter().all(|tc| tc.success));
    }

    // A call to an unregistered tool ends the run with `UnknownTool` as a
    // normal result, and nothing is executed.
    #[tokio::test]
    async fn test_unknown_tool_validation_returns_unknown_tool_finish_reason() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_tool_call(
            "call_ghost",
            "definitely_not_registered",
            serde_json::json!({}),
            "should not reach this",
        ));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "use a ghost tool")
            .await
            .expect("unknown tool should surface as a coordinator result, not a hard error");

        assert_eq!(
            result.finish_reason,
            FinishReason::UnknownTool("definitely_not_registered".into())
        );
        assert!(
            result.tool_calls.is_empty(),
            "unknown tool must not be executed"
        );
        assert_eq!(result.iterations, 1);
    }

    // With stop_on_error, a failing tool ends the run after the first turn,
    // and the tool's error text ends up in `FinishReason::Error`.
    #[tokio::test]
    async fn test_stop_on_error_halts_on_failed_tool_execution() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn name(&self) -> &str {
                "fail_me"
            }
            fn description(&self) -> &str {
                "always fails"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Err(crate::PawanError::Tool("intentional failure".into()))
            }
        }

        let backend = Arc::new(MockBackend::new(vec![
            MockResponse::tool_call("fail_me", serde_json::json!({})),
            MockResponse::text("unreachable"),
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(FailingTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            stop_on_error: true,
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "trigger failure")
            .await
            .expect("stop_on_error should return Ok with Error finish reason");

        match &result.finish_reason {
            FinishReason::Error(msg) => {
                assert!(
                    msg.contains("intentional failure"),
                    "error message should propagate from tool, got: {}",
                    msg
                );
            }
            other => panic!("expected FinishReason::Error, got {:?}", other),
        }
        assert_eq!(result.iterations, 1);
    }

    // The tool sleeps for 2s against a 50ms timeout. The timeout becomes a
    // failed record, and the loop continues to the backend's text reply.
    #[tokio::test]
    async fn test_tool_timeout_records_failed_tool_call() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;
        use std::time::Duration;

        struct SlowTool;

        #[async_trait]
        impl Tool for SlowTool {
            fn name(&self) -> &str {
                "slow_tool"
            }
            fn description(&self) -> &str {
                "sleeps longer than the coordinator timeout"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                tokio::time::sleep(Duration::from_secs(2)).await;
                Ok(serde_json::json!({"ok": true}))
            }
        }

        let backend = Arc::new(MockBackend::new(vec![
            MockResponse::tool_call("slow_tool", serde_json::json!({})),
            MockResponse::text("done after timeout"),
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            tool_timeout: Duration::from_millis(50),
            parallel_execution: false,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "run slow tool")
            .await
            .expect("timeout should be absorbed into a failed tool record");

        assert_eq!(result.tool_calls.len(), 1);
        let record = &result.tool_calls[0];
        assert!(!record.success, "timed-out tool must be marked unsuccessful");
        assert_eq!(
            record.result.get("error").and_then(|v| v.as_str()),
            Some("tool execution timed out")
        );
        // stop_on_error is false by default, so the run finishes normally.
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 2);
    }

    // `execute` with a system prompt puts [system, user] ahead of the reply.
    #[tokio::test]
    async fn test_execute_with_system_prompt_prepends_system_message() {
        use crate::agent::backend::mock::MockBackend;

        let backend = Arc::new(MockBackend::with_text("acknowledged"));
        let registry = Arc::new(ToolRegistry::new());
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(Some("be concise"), "hello")
            .await
            .expect("execute should succeed");

        assert_eq!(result.message_history.len(), 3);
        assert_eq!(result.message_history[0].role, Role::System);
        assert_eq!(result.message_history[0].content, "be concise");
        assert_eq!(result.message_history[1].role, Role::User);
        assert_eq!(result.message_history[1].content, "hello");
        assert_eq!(result.message_history[2].role, Role::Assistant);
    }

    // The totals equal the second reply's usage. This assumes the first reply,
    // `MockResponse::tool_call`, reports no usage.
    #[tokio::test]
    async fn test_token_usage_captured_from_backend_response() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }
            fn description(&self) -> &str {
                "does nothing"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        let backend = Arc::new(MockBackend::new(vec![
            MockResponse::tool_call("noop", serde_json::json!({})),
            MockResponse::TextWithUsage {
                text: "done".into(),
                usage: TokenUsage {
                    prompt_tokens: 20,
                    completion_tokens: 8,
                    total_tokens: 28,
                    reasoning_tokens: 3,
                    action_tokens: 5,
                },
            },
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(NoOpTool));
        let registry = Arc::new(registry);
        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());

        let result = coordinator
            .execute(None, "count tokens")
            .await
            .expect("execute should succeed");

        assert_eq!(result.total_usage.prompt_tokens, 20);
        assert_eq!(result.total_usage.completion_tokens, 8);
        assert_eq!(result.total_usage.total_tokens, 28);
        assert_eq!(result.total_usage.reasoning_tokens, 3);
        assert_eq!(result.total_usage.action_tokens, 5);
        assert_eq!(result.iterations, 2);
    }

    // Two tool calls in one turn with parallel_execution enabled: both run and
    // succeed. Order is not asserted.
    #[tokio::test]
    async fn test_parallel_execution_dispatches_multiple_tools_in_one_turn() {
        use crate::agent::backend::mock::MockBackend;
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        struct EchoTool {
            suffix: &'static str,
        }

        #[async_trait]
        impl Tool for EchoTool {
            fn name(&self) -> &str {
                self.suffix
            }
            fn description(&self) -> &str {
                "echoes a suffix"
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }
            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({ "tool": self.suffix }))
            }
        }

        let backend = Arc::new(MockBackend::with_multiple_tool_calls(vec![
            ("call_a", "echo_a", serde_json::json!({})),
            ("call_b", "echo_b", serde_json::json!({})),
        ]));

        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool { suffix: "echo_a" }));
        registry.register(Arc::new(EchoTool { suffix: "echo_b" }));
        let registry = Arc::new(registry);

        let config = ToolCallingConfig {
            parallel_execution: true,
            ..ToolCallingConfig::default()
        };
        let coordinator = ToolCoordinator::new(backend, registry, config);

        let result = coordinator
            .execute(None, "run both")
            .await
            .expect("parallel tool execution should succeed");

        assert_eq!(result.tool_calls.len(), 2);
        assert!(result.tool_calls.iter().all(|r| r.success));
        let names: Vec<&str> = result.tool_calls.iter().map(|r| r.name.as_str()).collect();
        assert!(names.contains(&"echo_a"));
        assert!(names.contains(&"echo_b"));
        assert_eq!(result.finish_reason, FinishReason::Stop);
        assert_eq!(result.iterations, 2);
    }

    // ---- TaskScheduleCoordinator fixtures ----

    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::Mutex;

    // Runner that records the id of every task it is asked to run and echoes
    // the task's fields back as JSON.
    struct MockTaskRunner {
        dispatched: Mutex<Vec<String>>,
    }

    impl MockTaskRunner {
        fn new() -> Self {
            Self {
                dispatched: Mutex::new(Vec::new()),
            }
        }

        // Ids of the tasks run so far, in dispatch order.
        fn dispatched_ids(&self) -> Vec<String> {
            self.dispatched.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl TaskRunner for MockTaskRunner {
        async fn run(&self, task: &ScheduledTask) -> crate::Result<Value> {
            self.dispatched.lock().unwrap().push(task.id.clone());
            Ok(json!({
                "id": task.id,
                "agent": task.agent_type,
                "assignment": task.assignment,
            }))
        }
    }

    // Validation failures must be reported before any task is dispatched.
    #[tokio::test]
    async fn schedule_empty_task_list_rejects_without_dispatch() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let err = coordinator
            .schedule(&[])
            .await
            .expect_err("empty task list should fail validation");

        assert_eq!(err, ScheduleError::EmptyTaskList);
        assert!(runner.dispatched_ids().is_empty());
    }

    #[tokio::test]
    async fn schedule_invalid_agent_type_rejects_without_dispatch() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let tasks = [ScheduledTask {
            id: "AuthProbe".into(),
            agent_type: "not_a_real_agent".into(),
            assignment: "probe auth".into(),
        }];

        let err = coordinator
            .schedule(&tasks)
            .await
            .expect_err("invalid agent type should fail validation");

        assert_eq!(
            err,
            ScheduleError::InvalidAgentType("not_a_real_agent".into())
        );
        assert!(runner.dispatched_ids().is_empty());
    }

    #[tokio::test]
    async fn schedule_duplicate_task_ids_rejects_without_dispatch() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let tasks = [
            ScheduledTask {
                id: "DupId".into(),
                agent_type: "explore".into(),
                assignment: "first".into(),
            },
            ScheduledTask {
                id: "DupId".into(),
                agent_type: "plan".into(),
                assignment: "second".into(),
            },
        ];

        let err = coordinator
            .schedule(&tasks)
            .await
            .expect_err("duplicate ids should fail validation");

        assert_eq!(err, ScheduleError::DuplicateTaskId("DupId".into()));
        assert!(runner.dispatched_ids().is_empty());
    }

    // A valid schedule runs every task in order, and results keep that order.
    #[tokio::test]
    async fn schedule_valid_tasks_dispatches_via_mock_runner() {
        let runner = Arc::new(MockTaskRunner::new());
        let coordinator = TaskScheduleCoordinator::new(runner.clone());

        let tasks = [
            ScheduledTask {
                id: "Alpha".into(),
                agent_type: "explore".into(),
                assignment: "scan src/".into(),
            },
            ScheduledTask {
                id: "Beta".into(),
                agent_type: "plan".into(),
                assignment: "draft refactor".into(),
            },
        ];

        let results = coordinator
            .schedule(&tasks)
            .await
            .expect("valid schedule should succeed");

        assert_eq!(runner.dispatched_ids(), vec!["Alpha", "Beta"]);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "Alpha");
        assert_eq!(results[0].output["agent"], "explore");
        assert_eq!(results[1].id, "Beta");
        assert_eq!(results[1].output["agent"], "plan");
    }
}