Skip to main content

pawan/coordinator/
mod.rs

1//! Multi-turn tool coordinator — data types and runtime.
2//!
3//! Provides a provider-agnostic orchestration layer for agent tool-calling
4//! loops: send a prompt with tool definitions, handle tool call requests,
5//! execute tools, feed results back, repeat until the model produces a final
6//! response or hits an iteration cap.
7//!
8//! Types reused from [`crate::agent`]:
9//! - [`ToolCallRequest`] — what the model asks for
10//! - [`ToolCallRecord`]  — what actually happened
11//! - [`TokenUsage`]      — accumulated counts
12//!
13//! Types defined here:
14//! - [`ToolCallingConfig`]   — iteration / parallelism / timeout knobs
15//! - [`FinishReason`]        — why the session ended
16//! - [`Role`]              — re-exported from `crate::agent` (system/user/assistant/tool)
17//! - [`ConversationMessage`] — a single turn in the history
18//! - [`CoordinatorResult`]   — everything the caller gets back
19//! - [`ToolCoordinator`]     — the runtime that drives the LLM+tool loop
20//! - [`TaskScheduleCoordinator`] — validates and dispatches multi-agent task batches
21//! - [`ScheduledTask`] / [`ScheduleError`] — task scheduling wire types
22//!
23//! ## Design notes
24//!
25//! - [`ToolCallRecord`] is reused from [`crate::agent`] rather than duplicated.
26//!   Failed tool calls land in `result` as a `{"error": "..."}` JSON object
27//!   with `success: false`, matching pawan's existing agent loop — there's no
28//!   separate `error` field on the record.
29//! - [`ConversationMessage::tool_call_id`] is only populated on [`Role::Tool`]
30//!   turns and links the result back to the assistant message that requested it.
31
32pub mod types;
33pub use types::*;
34
35use crate::agent::backend::LlmBackend;
36use crate::agent::{Message, Role, TokenUsage, ToolCallRecord, ToolCallRequest, ToolResultMessage};
37use crate::tools::ToolRegistry;
38use async_trait::async_trait;
39use futures::future::join_all;
40use serde_json::Value;
41use std::collections::HashSet;
42use std::sync::Arc;
43use std::time::Instant;
44use tokio::time::timeout;
45
46// ---------------------------------------------------------------------------
47// Type bridge: ConversationMessage → agent::Message
48// ---------------------------------------------------------------------------
49
50/// Convert a [`ConversationMessage`] to the backend's [`Message`] type.
51///
52/// The coordinator tracks history in its own `ConversationMessage` type, but
53/// `LlmBackend::generate()` expects `&[agent::Message]`. This function maps
54/// the coordinator's richer type to the backend wire format:
55///
56/// - `Tool` role messages: parse `content` back to JSON and populate
57///   `Message::tool_result` with a `ToolResultMessage`.
58/// - `Assistant` messages: copy `tool_calls` directly (same type).
59/// - `System`/`User` messages: straightforward role + content copy.
60fn to_backend_message(msg: &ConversationMessage) -> Message {
61    let tool_result = if msg.role == Role::Tool {
62        msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
63            tool_call_id: id.clone(),
64            content: serde_json::from_str(&msg.content)
65                .unwrap_or(serde_json::Value::String(msg.content.clone())),
66            success: true,
67        })
68    } else {
69        None
70    };
71
72    Message {
73        role: msg.role.clone(),
74        content: msg.content.clone(),
75        tool_calls: msg.tool_calls.clone(),
76        tool_result,
77    }
78}
79
80// ---------------------------------------------------------------------------
81// ToolCoordinator runtime
82// ---------------------------------------------------------------------------
83
84/// Runtime that drives the LLM + tool-calling loop.
85///
86/// Wraps a backend and a tool registry, sends prompts with tool definitions,
87/// executes requested tools, feeds results back, and repeats until the model
88/// produces a final text response or a halt condition fires.
89///
90/// # Example
91///
92/// ```rust,ignore
93/// use pawan::coordinator::{ToolCoordinator, ToolCallingConfig};
94/// use pawan::tools::ToolRegistry;
95/// use std::sync::Arc;
96///
97/// let backend = Arc::new(my_backend);
98/// let registry = Arc::new(ToolRegistry::new());
99/// let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
100///
101/// let result = coordinator.execute(Some("You are helpful."), "What is 2+2?").await?;
102/// println!("{}", result.content);
103/// ```
104pub struct ToolCoordinator {
105    backend: Arc<dyn LlmBackend>,
106    registry: Arc<ToolRegistry>,
107    config: ToolCallingConfig,
108}
109
110impl ToolCoordinator {
111    /// Create a new `ToolCoordinator`.
112    pub fn new(
113        backend: Arc<dyn LlmBackend>,
114        registry: Arc<ToolRegistry>,
115        config: ToolCallingConfig,
116    ) -> Self {
117        Self {
118            backend,
119            registry,
120            config,
121        }
122    }
123
124    /// Execute a tool-calling session starting from a plain prompt.
125    ///
126    /// Builds an initial `[system?, user]` message list and drives the loop.
127    pub async fn execute(
128        &self,
129        system_prompt: Option<&str>,
130        user_prompt: &str,
131    ) -> crate::Result<CoordinatorResult> {
132        let mut messages: Vec<ConversationMessage> = Vec::new();
133        if let Some(sys) = system_prompt {
134            messages.push(ConversationMessage::system(sys));
135        }
136        messages.push(ConversationMessage::user(user_prompt));
137        self.execute_with_history(messages).await
138    }
139
140    /// Execute a tool-calling session from an existing message history.
141    ///
142    /// This is the primary loop: it calls the backend, dispatches tool calls,
143    /// appends results to history, and repeats until the model emits a final
144    /// text response or a halt condition fires.
145    pub async fn execute_with_history(
146        &self,
147        mut messages: Vec<ConversationMessage>,
148    ) -> crate::Result<CoordinatorResult> {
149        let tool_defs = self.registry.get_definitions();
150        let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
151        let mut total_usage = TokenUsage::default();
152
153        for iteration in 0..self.config.max_iterations {
154            // Convert coordinator messages to backend wire format.
155            let backend_messages: Vec<Message> = messages.iter().map(to_backend_message).collect();
156
157            // Call backend — no streaming callback needed for coordinator.
158            let response = self
159                .backend
160                .generate(&backend_messages, &tool_defs, None)
161                .await?;
162
163            // Accumulate token usage.
164            if let Some(usage) = &response.usage {
165                total_usage.prompt_tokens += usage.prompt_tokens;
166                total_usage.completion_tokens += usage.completion_tokens;
167                total_usage.total_tokens += usage.total_tokens;
168                total_usage.reasoning_tokens += usage.reasoning_tokens;
169                total_usage.action_tokens += usage.action_tokens;
170            }
171
172            // Append the assistant turn to history.
173            messages.push(ConversationMessage::assistant(
174                &response.content,
175                response.tool_calls.clone(),
176            ));
177
178            // No tool calls → model is done.
179            if response.tool_calls.is_empty() {
180                return Ok(CoordinatorResult {
181                    content: response.content,
182                    tool_calls: all_tool_calls,
183                    iterations: iteration + 1,
184                    finish_reason: FinishReason::Stop,
185                    total_usage,
186                    message_history: messages,
187                });
188            }
189
190            // Empty response with tool calls is unusual but guard it.
191            if response.content.is_empty() && response.tool_calls.is_empty() {
192                return Ok(CoordinatorResult {
193                    content: String::new(),
194                    tool_calls: all_tool_calls,
195                    iterations: iteration + 1,
196                    finish_reason: FinishReason::Stop,
197                    total_usage,
198                    message_history: messages,
199                });
200            }
201
202            // Validate all requested tools exist before executing any.
203            for tc in &response.tool_calls {
204                if !self.registry.has_tool(&tc.name) {
205                    return Ok(CoordinatorResult {
206                        content: response.content,
207                        tool_calls: all_tool_calls,
208                        iterations: iteration + 1,
209                        finish_reason: FinishReason::UnknownTool(tc.name.clone()),
210                        total_usage,
211                        message_history: messages,
212                    });
213                }
214            }
215
216            // Execute tool calls (parallel or sequential per config).
217            let records = self.execute_tool_calls(&response.tool_calls).await?;
218
219            // If stop_on_error, check if any record failed.
220            if self.config.stop_on_error {
221                if let Some(failed) = records.iter().find(|r| !r.success) {
222                    let err_msg = failed
223                        .result
224                        .get("error")
225                        .and_then(|v| v.as_str())
226                        .unwrap_or("tool error")
227                        .to_string();
228                    return Ok(CoordinatorResult {
229                        content: response.content,
230                        tool_calls: all_tool_calls,
231                        iterations: iteration + 1,
232                        finish_reason: FinishReason::Error(err_msg),
233                        total_usage,
234                        message_history: messages,
235                    });
236                }
237            }
238
239            // Append tool result messages and accumulate records.
240            for record in records {
241                messages.push(ConversationMessage::tool_result(&record.id, &record.result));
242                all_tool_calls.push(record);
243            }
244        }
245
246        // Hit max iterations.
247        Ok(CoordinatorResult {
248            content: messages
249                .last()
250                .map(|m| m.content.clone())
251                .unwrap_or_default(),
252            tool_calls: all_tool_calls,
253            iterations: self.config.max_iterations,
254            finish_reason: FinishReason::MaxIterations,
255            total_usage,
256            message_history: messages,
257        })
258    }
259
260    // -----------------------------------------------------------------------
261    // Internal helpers
262    // -----------------------------------------------------------------------
263
264    async fn execute_tool_calls(
265        &self,
266        calls: &[ToolCallRequest],
267    ) -> crate::Result<Vec<ToolCallRecord>> {
268        if self.config.parallel_execution {
269            self.execute_parallel(calls).await
270        } else {
271            self.execute_sequential(calls).await
272        }
273    }
274
275    async fn execute_parallel(
276        &self,
277        calls: &[ToolCallRequest],
278    ) -> crate::Result<Vec<ToolCallRecord>> {
279        let futures = calls.iter().map(|c| self.execute_single_tool(c));
280        let results = join_all(futures).await;
281
282        let mut records = Vec::with_capacity(results.len());
283        for (i, res) in results.into_iter().enumerate() {
284            match res {
285                Ok(record) => records.push(record),
286                Err(e) if self.config.stop_on_error => return Err(e),
287                Err(e) => {
288                    // Recover: turn the error into a failed ToolCallRecord.
289                    let call = &calls[i];
290                    records.push(ToolCallRecord {
291                        id: call.id.clone(),
292                        name: call.name.clone(),
293                        arguments: call.arguments.clone(),
294                        result: serde_json::json!({"error": e.to_string()}),
295                        success: false,
296                        duration_ms: 0,
297                    });
298                }
299            }
300        }
301        Ok(records)
302    }
303
304    async fn execute_sequential(
305        &self,
306        calls: &[ToolCallRequest],
307    ) -> crate::Result<Vec<ToolCallRecord>> {
308        let mut records = Vec::with_capacity(calls.len());
309        for call in calls {
310            match self.execute_single_tool(call).await {
311                Ok(record) => records.push(record),
312                Err(e) if self.config.stop_on_error => return Err(e),
313                Err(e) => {
314                    records.push(ToolCallRecord {
315                        id: call.id.clone(),
316                        name: call.name.clone(),
317                        arguments: call.arguments.clone(),
318                        result: serde_json::json!({"error": e.to_string()}),
319                        success: false,
320                        duration_ms: 0,
321                    });
322                }
323            }
324        }
325        Ok(records)
326    }
327
328    async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
329        let start = Instant::now();
330
331        let result = timeout(
332            self.config.tool_timeout,
333            self.registry.execute(&call.name, call.arguments.clone()),
334        )
335        .await;
336
337        let duration_ms = start.elapsed().as_millis() as u64;
338
339        match result {
340            Ok(Ok(value)) => Ok(ToolCallRecord {
341                id: call.id.clone(),
342                name: call.name.clone(),
343                arguments: call.arguments.clone(),
344                result: value,
345                success: true,
346                duration_ms,
347            }),
348            Ok(Err(e)) => Ok(ToolCallRecord {
349                id: call.id.clone(),
350                name: call.name.clone(),
351                arguments: call.arguments.clone(),
352                result: serde_json::json!({"error": e.to_string()}),
353                success: false,
354                duration_ms,
355            }),
356            Err(_elapsed) => Ok(ToolCallRecord {
357                id: call.id.clone(),
358                name: call.name.clone(),
359                arguments: call.arguments.clone(),
360                result: serde_json::json!({"error": "tool execution timed out"}),
361                success: false,
362                duration_ms,
363            }),
364        }
365    }
366}
367
368// ---------------------------------------------------------------------------
369// Multi-agent task scheduling
370// ---------------------------------------------------------------------------
371
/// Agent type names accepted by [`validate_task_schedule`].
///
/// Kept aligned with the subagent kinds offered by [`crate::tools::task::TaskTool`].
pub const KNOWN_AGENT_TYPES: &[&str] =
    &["explore", "plan", "task", "reviewer", "designer", "librarian"];
381
/// A single piece of work addressed to one typed subagent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTask {
    /// Caller-chosen identifier; validation requires it to be unique in a batch.
    pub id: String,
    /// Subagent kind; validation requires one of [`KNOWN_AGENT_TYPES`].
    pub agent_type: String,
    /// Instructions for the subagent; validation rejects blank text.
    pub assignment: String,
}
389
/// Error returned when a task batch cannot be scheduled.
///
/// NOTE(review): `InvalidAgentType` is also used in this file for a blank
/// assignment (`validate_task_schedule`) and for runner failures
/// (`TaskScheduleCoordinator::schedule`). Its payload is therefore not
/// always an agent type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScheduleError {
    /// The batch contained no tasks at all.
    EmptyTaskList,
    /// The named agent type is not listed in [`KNOWN_AGENT_TYPES`].
    InvalidAgentType(String),
    /// The same task id appeared more than once in the batch.
    DuplicateTaskId(String),
}
397
398impl std::fmt::Display for ScheduleError {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        match self {
401            ScheduleError::EmptyTaskList => write!(f, "task list must not be empty"),
402            ScheduleError::InvalidAgentType(agent) => write!(
403                f,
404                "unknown agent type '{agent}'. Valid types: {}",
405                KNOWN_AGENT_TYPES.join(", ")
406            ),
407            ScheduleError::DuplicateTaskId(id) => write!(f, "duplicate task id '{id}'"),
408        }
409    }
410}
411
412/// Validate a batch before dispatching. Checks non-empty list, known agent types,
413/// and unique task ids.
414pub fn validate_task_schedule(tasks: &[ScheduledTask]) -> Result<(), ScheduleError> {
415    if tasks.is_empty() {
416        return Err(ScheduleError::EmptyTaskList);
417    }
418
419    let mut seen_ids = HashSet::with_capacity(tasks.len());
420    for task in tasks {
421        if !KNOWN_AGENT_TYPES.contains(&task.agent_type.as_str()) {
422            return Err(ScheduleError::InvalidAgentType(task.agent_type.clone()));
423        }
424        if !seen_ids.insert(task.id.clone()) {
425            return Err(ScheduleError::DuplicateTaskId(task.id.clone()));
426        }
427        if task.assignment.trim().is_empty() {
428            return Err(ScheduleError::InvalidAgentType(
429                "assignment must be non-empty".into(),
430            ));
431        }
432    }
433    Ok(())
434}
435
436/// Outcome of one successfully scheduled task.
437#[derive(Debug, Clone, PartialEq, Eq)]
438pub struct ScheduledTaskResult {
439    pub id: String,
440    pub output: Value,
441}
442
443/// Executes validated [`ScheduledTask`] items (used by tests and future batch dispatch).
444#[async_trait]
445pub trait TaskRunner: Send + Sync {
446    async fn run(&self, task: &ScheduledTask) -> crate::Result<Value>;
447}
448
/// Validates multi-agent task batches and dispatches their tasks one at a time
/// through a shared runner `R`. Scheduling requires `R: TaskRunner`.
pub struct TaskScheduleCoordinator<R> {
    /// Reference-counted executor that every task in a batch is passed to.
    runner: Arc<R>,
}
453
454impl<R: TaskRunner> TaskScheduleCoordinator<R> {
455    pub fn new(runner: Arc<R>) -> Self {
456        Self { runner }
457    }
458
459    /// Validate `tasks`, then run each item through the configured runner.
460    pub async fn schedule(
461        &self,
462        tasks: &[ScheduledTask],
463    ) -> Result<Vec<ScheduledTaskResult>, ScheduleError> {
464        validate_task_schedule(tasks)?;
465
466        let mut results = Vec::with_capacity(tasks.len());
467        for task in tasks {
468            let output = self.runner.run(task).await.map_err(|e| {
469                ScheduleError::InvalidAgentType(format!("task '{}' failed: {e}", task.id))
470            })?;
471            results.push(ScheduledTaskResult {
472                id: task.id.clone(),
473                output,
474            });
475        }
476        Ok(results)
477    }
478}
479
480#[cfg(test)]
481mod tests {
482    use super::*;
483    use std::sync::Arc;
484
485    /// No tools available — model replies with plain text on the first turn.
486    /// Verifies that the coordinator terminates cleanly and returns the model
487    /// text as `content` with `FinishReason::Stop` and zero tool calls.
488    #[tokio::test]
489    async fn execute_with_empty_registry_returns_model_response() {
490        use crate::agent::backend::mock::MockBackend;
491
492        let backend = Arc::new(MockBackend::with_text("Hello, world!"));
493        let registry = Arc::new(ToolRegistry::new());
494        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
495
496        let result = coordinator
497            .execute(None, "Say hello")
498            .await
499            .expect("coordinator should not error");
500
501        assert_eq!(result.content, "Hello, world!");
502        assert_eq!(result.finish_reason, FinishReason::Stop);
503        assert_eq!(result.iterations, 1);
504        assert!(result.tool_calls.is_empty());
505        // History: [user, assistant]
506        assert_eq!(result.message_history.len(), 2);
507    }
508
509    /// Pin the `ToolCallingConfig` defaults so regressions are caught.
510    #[test]
511    fn tool_calling_config_defaults_are_sensible() {
512        use std::time::Duration;
513        let cfg = ToolCallingConfig::default();
514        assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
515        assert!(
516            cfg.parallel_execution,
517            "parallel_execution should default to true"
518        );
519        assert_eq!(
520            cfg.tool_timeout,
521            Duration::from_secs(30),
522            "tool_timeout default changed"
523        );
524        assert!(!cfg.stop_on_error, "stop_on_error should default to false");
525    }
526
527    /// The coordinator must fire `FinishReason::MaxIterations` when the model
528    /// keeps requesting tool calls and we exhaust the iteration budget.
529    /// Uses a mock backend that always returns a tool-call response for a
530    /// registered no-op tool, driving the loop to the configured cap.
531    #[tokio::test]
532    async fn coordinator_result_captures_finish_reason_max_iterations() {
533        use crate::agent::backend::mock::{MockBackend, MockResponse};
534        use crate::tools::Tool;
535        use async_trait::async_trait;
536        use serde_json::Value;
537
538        // A trivial no-op tool that always succeeds.
539        struct NoOpTool;
540
541        #[async_trait]
542        impl Tool for NoOpTool {
543            fn name(&self) -> &str {
544                "noop"
545            }
546            fn description(&self) -> &str {
547                "does nothing"
548            }
549            fn parameters_schema(&self) -> Value {
550                serde_json::json!({"type": "object", "properties": {}})
551            }
552            async fn execute(&self, _args: Value) -> crate::Result<Value> {
553                Ok(serde_json::json!({"ok": true}))
554            }
555        }
556
557        // Build a backend that always requests the noop tool (never gives a
558        // final text response), so the loop runs until max_iterations.
559        let responses: Vec<MockResponse> = (0..15)
560            .map(|_| MockResponse::tool_call("noop", serde_json::json!({})))
561            .collect();
562        let backend = Arc::new(MockBackend::new(responses));
563
564        let mut registry = ToolRegistry::new();
565        registry.register(std::sync::Arc::new(NoOpTool));
566        let registry = Arc::new(registry);
567
568        let config = ToolCallingConfig {
569            max_iterations: 3,
570            parallel_execution: false,
571            ..ToolCallingConfig::default()
572        };
573        let coordinator = ToolCoordinator::new(backend, registry, config);
574
575        let result = coordinator
576            .execute(None, "loop forever")
577            .await
578            .expect("coordinator should not hard-error");
579
580        assert_eq!(
581            result.finish_reason,
582            FinishReason::MaxIterations,
583            "expected MaxIterations, got {:?}",
584            result.finish_reason
585        );
586        assert_eq!(result.iterations, 3);
587        // Each iteration dispatches one noop tool call.
588        assert_eq!(result.tool_calls.len(), 3);
589        assert!(result.tool_calls.iter().all(|tc| tc.success));
590    }
591
592    /// When the model requests a tool that is not registered, the coordinator must
593    /// halt immediately with `FinishReason::UnknownTool` and must not execute anything.
594    #[tokio::test]
595    async fn test_unknown_tool_validation_returns_unknown_tool_finish_reason() {
596        use crate::agent::backend::mock::MockBackend;
597
598        let backend = Arc::new(MockBackend::with_tool_call(
599            "call_ghost",
600            "definitely_not_registered",
601            serde_json::json!({}),
602            "should not reach this",
603        ));
604        let registry = Arc::new(ToolRegistry::new());
605        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
606
607        let result = coordinator
608            .execute(None, "use a ghost tool")
609            .await
610            .expect("unknown tool should surface as a coordinator result, not a hard error");
611
612        assert_eq!(
613            result.finish_reason,
614            FinishReason::UnknownTool("definitely_not_registered".into())
615        );
616        assert!(
617            result.tool_calls.is_empty(),
618            "unknown tool must not be executed"
619        );
620        assert_eq!(result.iterations, 1);
621    }
622
623    /// With `stop_on_error = true`, a tool that returns `Err` must end the session
624    /// with `FinishReason::Error` rather than continuing the loop.
625    #[tokio::test]
626    async fn test_stop_on_error_halts_on_failed_tool_execution() {
627        use crate::agent::backend::mock::{MockBackend, MockResponse};
628        use crate::tools::Tool;
629        use async_trait::async_trait;
630        use serde_json::Value;
631
632        struct FailingTool;
633
634        #[async_trait]
635        impl Tool for FailingTool {
636            fn name(&self) -> &str {
637                "fail_me"
638            }
639            fn description(&self) -> &str {
640                "always fails"
641            }
642            fn parameters_schema(&self) -> Value {
643                serde_json::json!({"type": "object", "properties": {}})
644            }
645            async fn execute(&self, _args: Value) -> crate::Result<Value> {
646                Err(crate::PawanError::Tool("intentional failure".into()))
647            }
648        }
649
650        let backend = Arc::new(MockBackend::new(vec![
651            MockResponse::tool_call("fail_me", serde_json::json!({})),
652            MockResponse::text("unreachable"),
653        ]));
654
655        let mut registry = ToolRegistry::new();
656        registry.register(Arc::new(FailingTool));
657        let registry = Arc::new(registry);
658
659        let config = ToolCallingConfig {
660            stop_on_error: true,
661            parallel_execution: false,
662            ..ToolCallingConfig::default()
663        };
664        let coordinator = ToolCoordinator::new(backend, registry, config);
665
666        let result = coordinator
667            .execute(None, "trigger failure")
668            .await
669            .expect("stop_on_error should return Ok with Error finish reason");
670
671        match &result.finish_reason {
672            FinishReason::Error(msg) => {
673                assert!(
674                    msg.contains("intentional failure"),
675                    "error message should propagate from tool, got: {}",
676                    msg
677                );
678            }
679            other => panic!("expected FinishReason::Error, got {:?}", other),
680        }
681        assert_eq!(result.iterations, 1);
682    }
683
684    /// Per-tool timeout must produce a failed record rather than hanging the session.
685    #[tokio::test]
686    async fn test_tool_timeout_records_failed_tool_call() {
687        use crate::agent::backend::mock::{MockBackend, MockResponse};
688        use crate::tools::Tool;
689        use async_trait::async_trait;
690        use serde_json::Value;
691        use std::time::Duration;
692
693        struct SlowTool;
694
695        #[async_trait]
696        impl Tool for SlowTool {
697            fn name(&self) -> &str {
698                "slow_tool"
699            }
700            fn description(&self) -> &str {
701                "sleeps longer than the coordinator timeout"
702            }
703            fn parameters_schema(&self) -> Value {
704                serde_json::json!({"type": "object", "properties": {}})
705            }
706            async fn execute(&self, _args: Value) -> crate::Result<Value> {
707                tokio::time::sleep(Duration::from_secs(2)).await;
708                Ok(serde_json::json!({"ok": true}))
709            }
710        }
711
712        let backend = Arc::new(MockBackend::new(vec![
713            MockResponse::tool_call("slow_tool", serde_json::json!({})),
714            MockResponse::text("done after timeout"),
715        ]));
716
717        let mut registry = ToolRegistry::new();
718        registry.register(Arc::new(SlowTool));
719        let registry = Arc::new(registry);
720
721        let config = ToolCallingConfig {
722            tool_timeout: Duration::from_millis(50),
723            parallel_execution: false,
724            ..ToolCallingConfig::default()
725        };
726        let coordinator = ToolCoordinator::new(backend, registry, config);
727
728        let result = coordinator
729            .execute(None, "run slow tool")
730            .await
731            .expect("timeout should be absorbed into a failed tool record");
732
733        assert_eq!(result.tool_calls.len(), 1);
734        let record = &result.tool_calls[0];
735        assert!(
736            !record.success,
737            "timed-out tool must be marked unsuccessful"
738        );
739        assert_eq!(
740            record.result.get("error").and_then(|v| v.as_str()),
741            Some("tool execution timed out")
742        );
743        // Loop continues after a non-fatal timeout (stop_on_error defaults to false).
744        assert_eq!(result.finish_reason, FinishReason::Stop);
745        assert_eq!(result.iterations, 2);
746    }
747
748    /// `execute` with a system prompt must prepend a system message to history.
749    #[tokio::test]
750    async fn test_execute_with_system_prompt_prepends_system_message() {
751        use crate::agent::backend::mock::MockBackend;
752
753        let backend = Arc::new(MockBackend::with_text("acknowledged"));
754        let registry = Arc::new(ToolRegistry::new());
755        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
756
757        let result = coordinator
758            .execute(Some("be concise"), "hello")
759            .await
760            .expect("execute should succeed");
761
762        assert_eq!(result.message_history.len(), 3);
763        assert_eq!(result.message_history[0].role, Role::System);
764        assert_eq!(result.message_history[0].content, "be concise");
765        assert_eq!(result.message_history[1].role, Role::User);
766        assert_eq!(result.message_history[1].content, "hello");
767        assert_eq!(result.message_history[2].role, Role::Assistant);
768    }
769
770    /// Token usage reported by the backend must be captured in `total_usage`.
771    #[tokio::test]
772    async fn test_token_usage_captured_from_backend_response() {
773        use crate::agent::backend::mock::{MockBackend, MockResponse};
774        use crate::tools::Tool;
775        use async_trait::async_trait;
776        use serde_json::Value;
777
778        struct NoOpTool;
779
780        #[async_trait]
781        impl Tool for NoOpTool {
782            fn name(&self) -> &str {
783                "noop"
784            }
785            fn description(&self) -> &str {
786                "does nothing"
787            }
788            fn parameters_schema(&self) -> Value {
789                serde_json::json!({"type": "object", "properties": {}})
790            }
791            async fn execute(&self, _args: Value) -> crate::Result<Value> {
792                Ok(serde_json::json!({"ok": true}))
793            }
794        }
795
796        let backend = Arc::new(MockBackend::new(vec![
797            MockResponse::tool_call("noop", serde_json::json!({})),
798            MockResponse::TextWithUsage {
799                text: "done".into(),
800                usage: TokenUsage {
801                    prompt_tokens: 20,
802                    completion_tokens: 8,
803                    total_tokens: 28,
804                    reasoning_tokens: 3,
805                    action_tokens: 5,
806                },
807            },
808        ]));
809
810        let mut registry = ToolRegistry::new();
811        registry.register(Arc::new(NoOpTool));
812        let registry = Arc::new(registry);
813        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
814
815        let result = coordinator
816            .execute(None, "count tokens")
817            .await
818            .expect("execute should succeed");
819
820        assert_eq!(result.total_usage.prompt_tokens, 20);
821        assert_eq!(result.total_usage.completion_tokens, 8);
822        assert_eq!(result.total_usage.total_tokens, 28);
823        assert_eq!(result.total_usage.reasoning_tokens, 3);
824        assert_eq!(result.total_usage.action_tokens, 5);
825        assert_eq!(result.iterations, 2);
826    }
827
828    /// Parallel execution must dispatch every tool call in a single assistant turn.
829    #[tokio::test]
830    async fn test_parallel_execution_dispatches_multiple_tools_in_one_turn() {
831        use crate::agent::backend::mock::MockBackend;
832        use crate::tools::Tool;
833        use async_trait::async_trait;
834        use serde_json::Value;
835
836        struct EchoTool {
837            suffix: &'static str,
838        }
839
840        #[async_trait]
841        impl Tool for EchoTool {
842            fn name(&self) -> &str {
843                self.suffix
844            }
845            fn description(&self) -> &str {
846                "echoes a suffix"
847            }
848            fn parameters_schema(&self) -> Value {
849                serde_json::json!({"type": "object", "properties": {}})
850            }
851            async fn execute(&self, _args: Value) -> crate::Result<Value> {
852                Ok(serde_json::json!({ "tool": self.suffix }))
853            }
854        }
855
856        let backend = Arc::new(MockBackend::with_multiple_tool_calls(vec![
857            ("call_a", "echo_a", serde_json::json!({})),
858            ("call_b", "echo_b", serde_json::json!({})),
859        ]));
860
861        let mut registry = ToolRegistry::new();
862        registry.register(Arc::new(EchoTool { suffix: "echo_a" }));
863        registry.register(Arc::new(EchoTool { suffix: "echo_b" }));
864        let registry = Arc::new(registry);
865
866        let config = ToolCallingConfig {
867            parallel_execution: true,
868            ..ToolCallingConfig::default()
869        };
870        let coordinator = ToolCoordinator::new(backend, registry, config);
871
872        let result = coordinator
873            .execute(None, "run both")
874            .await
875            .expect("parallel tool execution should succeed");
876
877        assert_eq!(result.tool_calls.len(), 2);
878        assert!(result.tool_calls.iter().all(|r| r.success));
879        let names: Vec<&str> = result.tool_calls.iter().map(|r| r.name.as_str()).collect();
880        assert!(names.contains(&"echo_a"));
881        assert!(names.contains(&"echo_b"));
882        assert_eq!(result.finish_reason, FinishReason::Stop);
883        assert_eq!(result.iterations, 2);
884    }
885
886    // -----------------------------------------------------------------------
887    // Task scheduling edge cases (mock runner)
888    // -----------------------------------------------------------------------
889
890    use async_trait::async_trait;
891    use serde_json::json;
892    use std::sync::Mutex;
893
894    struct MockTaskRunner {
895        dispatched: Mutex<Vec<String>>,
896    }
897
898    impl MockTaskRunner {
899        fn new() -> Self {
900            Self {
901                dispatched: Mutex::new(Vec::new()),
902            }
903        }
904
905        fn dispatched_ids(&self) -> Vec<String> {
906            self.dispatched.lock().unwrap().clone()
907        }
908    }
909
910    #[async_trait]
911    impl TaskRunner for MockTaskRunner {
912        async fn run(&self, task: &ScheduledTask) -> crate::Result<Value> {
913            self.dispatched.lock().unwrap().push(task.id.clone());
914            Ok(json!({
915                "id": task.id,
916                "agent": task.agent_type,
917                "assignment": task.assignment,
918            }))
919        }
920    }
921
922    /// Scheduling must reject an empty task list before touching the runner.
923    #[tokio::test]
924    async fn schedule_empty_task_list_rejects_without_dispatch() {
925        let runner = Arc::new(MockTaskRunner::new());
926        let coordinator = TaskScheduleCoordinator::new(runner.clone());
927
928        let err = coordinator
929            .schedule(&[])
930            .await
931            .expect_err("empty task list should fail validation");
932
933        assert_eq!(err, ScheduleError::EmptyTaskList);
934        assert!(runner.dispatched_ids().is_empty());
935    }
936
937    /// Unknown agent types must fail validation and leave the runner idle.
938    #[tokio::test]
939    async fn schedule_invalid_agent_type_rejects_without_dispatch() {
940        let runner = Arc::new(MockTaskRunner::new());
941        let coordinator = TaskScheduleCoordinator::new(runner.clone());
942
943        let tasks = [ScheduledTask {
944            id: "AuthProbe".into(),
945            agent_type: "not_a_real_agent".into(),
946            assignment: "probe auth".into(),
947        }];
948
949        let err = coordinator
950            .schedule(&tasks)
951            .await
952            .expect_err("invalid agent type should fail validation");
953
954        assert_eq!(
955            err,
956            ScheduleError::InvalidAgentType("not_a_real_agent".into())
957        );
958        assert!(runner.dispatched_ids().is_empty());
959    }
960
961    /// Duplicate task ids must be rejected before any work is dispatched.
962    #[tokio::test]
963    async fn schedule_duplicate_task_ids_rejects_without_dispatch() {
964        let runner = Arc::new(MockTaskRunner::new());
965        let coordinator = TaskScheduleCoordinator::new(runner.clone());
966
967        let tasks = [
968            ScheduledTask {
969                id: "DupId".into(),
970                agent_type: "explore".into(),
971                assignment: "first".into(),
972            },
973            ScheduledTask {
974                id: "DupId".into(),
975                agent_type: "plan".into(),
976                assignment: "second".into(),
977            },
978        ];
979
980        let err = coordinator
981            .schedule(&tasks)
982            .await
983            .expect_err("duplicate ids should fail validation");
984
985        assert_eq!(err, ScheduleError::DuplicateTaskId("DupId".into()));
986        assert!(runner.dispatched_ids().is_empty());
987    }
988
989    /// Valid tasks must be dispatched through the mock runner in order.
990    #[tokio::test]
991    async fn schedule_valid_tasks_dispatches_via_mock_runner() {
992        let runner = Arc::new(MockTaskRunner::new());
993        let coordinator = TaskScheduleCoordinator::new(runner.clone());
994
995        let tasks = [
996            ScheduledTask {
997                id: "Alpha".into(),
998                agent_type: "explore".into(),
999                assignment: "scan src/".into(),
1000            },
1001            ScheduledTask {
1002                id: "Beta".into(),
1003                agent_type: "plan".into(),
1004                assignment: "draft refactor".into(),
1005            },
1006        ];
1007
1008        let results = coordinator
1009            .schedule(&tasks)
1010            .await
1011            .expect("valid schedule should succeed");
1012
1013        assert_eq!(runner.dispatched_ids(), vec!["Alpha", "Beta"]);
1014        assert_eq!(results.len(), 2);
1015        assert_eq!(results[0].id, "Alpha");
1016        assert_eq!(results[0].output["agent"], "explore");
1017        assert_eq!(results[1].id, "Beta");
1018        assert_eq!(results[1].output["agent"], "plan");
1019    }
1020}