Skip to main content

pawan/coordinator/
mod.rs

1//! Multi-turn tool coordinator — data types and runtime.
2//!
3//! Provides a provider-agnostic orchestration layer for agent tool-calling
4//! loops: send a prompt with tool definitions, handle tool call requests,
5//! execute tools, feed results back, repeat until the model produces a final
6//! response or hits an iteration cap.
7//!
8//! Types reused from [`crate::agent`]:
9//! - [`ToolCallRequest`] — what the model asks for
10//! - [`ToolCallRecord`]  — what actually happened
11//! - [`TokenUsage`]      — accumulated counts
12//!
13//! Types defined here:
14//! - [`ToolCallingConfig`]   — iteration / parallelism / timeout knobs
15//! - [`FinishReason`]        — why the session ended
16//! - [`Role`]              — re-exported from `crate::agent` (system/user/assistant/tool)
17//! - [`ConversationMessage`] — a single turn in the history
18//! - [`CoordinatorResult`]   — everything the caller gets back
19//! - [`ToolCoordinator`]     — the runtime that drives the LLM+tool loop
//! - [`TaskScheduleCoordinator`] — validates and dispatches multi-agent task batches
//! - [`ScheduledTask`] / [`ScheduledTaskResult`] / [`ScheduleError`] — task scheduling wire types
//! - [`TaskRunner`] — trait that executes a single [`ScheduledTask`]
22//!
23//! ## Design notes
24//!
25//! - [`ToolCallRecord`] is reused from [`crate::agent`] rather than duplicated.
26//!   Failed tool calls land in `result` as a `{"error": "..."}` JSON object
27//!   with `success: false`, matching pawan's existing agent loop — there's no
28//!   separate `error` field on the record.
29//! - [`ConversationMessage::tool_call_id`] is only populated on [`Role::Tool`]
30//!   turns and links the result back to the assistant message that requested it.
31
32pub mod types;
33pub use types::*;
34
35use crate::agent::backend::LlmBackend;
36use crate::agent::{Message, Role, TokenUsage, ToolCallRecord, ToolCallRequest, ToolResultMessage};
37use crate::tools::ToolRegistry;
38use async_trait::async_trait;
39use futures::future::join_all;
40use serde_json::Value;
41use std::collections::HashSet;
42use std::sync::Arc;
43use std::time::Instant;
44use tokio::time::timeout;
45
46// ---------------------------------------------------------------------------
47// Type bridge: ConversationMessage → agent::Message
48// ---------------------------------------------------------------------------
49
50/// Convert a [`ConversationMessage`] to the backend's [`Message`] type.
51///
52/// The coordinator tracks history in its own `ConversationMessage` type, but
53/// `LlmBackend::generate()` expects `&[agent::Message]`. This function maps
54/// the coordinator's richer type to the backend wire format:
55///
56/// - `Tool` role messages: parse `content` back to JSON and populate
57///   `Message::tool_result` with a `ToolResultMessage`.
58/// - `Assistant` messages: copy `tool_calls` directly (same type).
59/// - `System`/`User` messages: straightforward role + content copy.
60fn to_backend_message(msg: &ConversationMessage) -> Message {
61    let tool_result = if msg.role == Role::Tool {
62        msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
63            tool_call_id: id.clone(),
64            content: serde_json::from_str(&msg.content)
65                .unwrap_or(serde_json::Value::String(msg.content.clone())),
66            success: true,
67        })
68    } else {
69        None
70    };
71
72    Message {
73        role: msg.role.clone(),
74        content: msg.content.clone(),
75        tool_calls: msg.tool_calls.clone(),
76        tool_result,
77    }
78}
79
80// ---------------------------------------------------------------------------
81// ToolCoordinator runtime
82// ---------------------------------------------------------------------------
83
/// Runtime that drives the LLM + tool-calling loop.
///
/// Wraps a backend and a tool registry, sends prompts with tool definitions,
/// executes requested tools, feeds results back, and repeats until the model
/// produces a final text response or a halt condition fires.
///
/// # Example
///
/// ```rust,ignore
/// use pawan::coordinator::{ToolCoordinator, ToolCallingConfig};
/// use pawan::tools::ToolRegistry;
/// use std::sync::Arc;
///
/// let backend = Arc::new(my_backend);
/// let registry = Arc::new(ToolRegistry::new());
/// let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
///
/// let result = coordinator.execute(Some("You are helpful."), "What is 2+2?").await?;
/// println!("{}", result.content);
/// ```
pub struct ToolCoordinator {
    /// LLM backend whose `generate()` is called once per loop iteration.
    backend: Arc<dyn LlmBackend>,
    /// Tool registry: supplies the tool definitions sent to the backend and
    /// executes the tools the model requests.
    registry: Arc<ToolRegistry>,
    /// Iteration cap, parallel/sequential dispatch, per-tool timeout, and
    /// stop-on-error settings for the loop.
    config: ToolCallingConfig,
}
109
110impl ToolCoordinator {
111    /// Create a new `ToolCoordinator`.
112    pub fn new(
113        backend: Arc<dyn LlmBackend>,
114        registry: Arc<ToolRegistry>,
115        config: ToolCallingConfig,
116    ) -> Self {
117        Self {
118            backend,
119            registry,
120            config,
121        }
122    }
123
124    /// Execute a tool-calling session starting from a plain prompt.
125    ///
126    /// Builds an initial `[system?, user]` message list and drives the loop.
127    pub async fn execute(
128        &self,
129        system_prompt: Option<&str>,
130        user_prompt: &str,
131    ) -> crate::Result<CoordinatorResult> {
132        let mut messages: Vec<ConversationMessage> = Vec::new();
133        if let Some(sys) = system_prompt {
134            messages.push(ConversationMessage::system(sys));
135        }
136        messages.push(ConversationMessage::user(user_prompt));
137        self.execute_with_history(messages).await
138    }
139
140    /// Execute a tool-calling session from an existing message history.
141    ///
142    /// This is the primary loop: it calls the backend, dispatches tool calls,
143    /// appends results to history, and repeats until the model emits a final
144    /// text response or a halt condition fires.
145    pub async fn execute_with_history(
146        &self,
147        mut messages: Vec<ConversationMessage>,
148    ) -> crate::Result<CoordinatorResult> {
149        let tool_defs = self.registry.get_definitions();
150        let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
151        let mut total_usage = TokenUsage::default();
152
153        for iteration in 0..self.config.max_iterations {
154            // Convert coordinator messages to backend wire format.
155            let backend_messages: Vec<Message> = messages.iter().map(to_backend_message).collect();
156
157            // Call backend — no streaming callback needed for coordinator.
158            let response = self
159                .backend
160                .generate(&backend_messages, &tool_defs, None)
161                .await?;
162
163            // Accumulate token usage.
164            if let Some(usage) = &response.usage {
165                total_usage.prompt_tokens += usage.prompt_tokens;
166                total_usage.completion_tokens += usage.completion_tokens;
167                total_usage.total_tokens += usage.total_tokens;
168                total_usage.reasoning_tokens += usage.reasoning_tokens;
169                total_usage.action_tokens += usage.action_tokens;
170            }
171
172            // Append the assistant turn to history.
173            messages.push(ConversationMessage::assistant(
174                &response.content,
175                response.tool_calls.clone(),
176            ));
177
178            // No tool calls → model is done.
179            if response.tool_calls.is_empty() {
180                return Ok(CoordinatorResult {
181                    content: response.content,
182                    tool_calls: all_tool_calls,
183                    iterations: iteration + 1,
184                    finish_reason: FinishReason::Stop,
185                    total_usage,
186                    message_history: messages,
187                });
188            }
189
190            // Empty response with tool calls is unusual but guard it.
191            if response.content.is_empty() && response.tool_calls.is_empty() {
192                return Ok(CoordinatorResult {
193                    content: String::new(),
194                    tool_calls: all_tool_calls,
195                    iterations: iteration + 1,
196                    finish_reason: FinishReason::Stop,
197                    total_usage,
198                    message_history: messages,
199                });
200            }
201
202            // Validate all requested tools exist before executing any.
203            for tc in &response.tool_calls {
204                if !self.registry.has_tool(&tc.name) {
205                    return Ok(CoordinatorResult {
206                        content: response.content,
207                        tool_calls: all_tool_calls,
208                        iterations: iteration + 1,
209                        finish_reason: FinishReason::UnknownTool(tc.name.clone()),
210                        total_usage,
211                        message_history: messages,
212                    });
213                }
214            }
215
216            // Execute tool calls (parallel or sequential per config).
217            let records = self.execute_tool_calls(&response.tool_calls).await?;
218
219            // If stop_on_error, check if any record failed.
220            if self.config.stop_on_error {
221                if let Some(failed) = records.iter().find(|r| !r.success) {
222                    let err_msg = failed
223                        .result
224                        .get("error")
225                        .and_then(|v| v.as_str())
226                        .unwrap_or("tool error")
227                        .to_string();
228                    return Ok(CoordinatorResult {
229                        content: response.content,
230                        tool_calls: all_tool_calls,
231                        iterations: iteration + 1,
232                        finish_reason: FinishReason::Error(err_msg),
233                        total_usage,
234                        message_history: messages,
235                    });
236                }
237            }
238
239            // Append tool result messages and accumulate records.
240            for record in records {
241                messages.push(ConversationMessage::tool_result(&record.id, &record.result));
242                all_tool_calls.push(record);
243            }
244        }
245
246        // Hit max iterations.
247        Ok(CoordinatorResult {
248            content: messages
249                .last()
250                .map(|m| m.content.clone())
251                .unwrap_or_default(),
252            tool_calls: all_tool_calls,
253            iterations: self.config.max_iterations,
254            finish_reason: FinishReason::MaxIterations,
255            total_usage,
256            message_history: messages,
257        })
258    }
259
260    // -----------------------------------------------------------------------
261    // Internal helpers
262    // -----------------------------------------------------------------------
263
264    async fn execute_tool_calls(
265        &self,
266        calls: &[ToolCallRequest],
267    ) -> crate::Result<Vec<ToolCallRecord>> {
268        if self.config.parallel_execution {
269            self.execute_parallel(calls).await
270        } else {
271            self.execute_sequential(calls).await
272        }
273    }
274
275    async fn execute_parallel(
276        &self,
277        calls: &[ToolCallRequest],
278    ) -> crate::Result<Vec<ToolCallRecord>> {
279        let futures = calls.iter().map(|c| self.execute_single_tool(c));
280        let results = join_all(futures).await;
281
282        let mut records = Vec::with_capacity(results.len());
283        for (i, res) in results.into_iter().enumerate() {
284            match res {
285                Ok(record) => records.push(record),
286                Err(e) if self.config.stop_on_error => return Err(e),
287                Err(e) => {
288                    // Recover: turn the error into a failed ToolCallRecord.
289                    let call = &calls[i];
290                    records.push(ToolCallRecord {
291                        id: call.id.clone(),
292                        name: call.name.clone(),
293                        arguments: call.arguments.clone(),
294                        result: serde_json::json!({"error": e.to_string()}),
295                        success: false,
296                        duration_ms: 0,
297                    });
298                }
299            }
300        }
301        Ok(records)
302    }
303
304    async fn execute_sequential(
305        &self,
306        calls: &[ToolCallRequest],
307    ) -> crate::Result<Vec<ToolCallRecord>> {
308        let mut records = Vec::with_capacity(calls.len());
309        for call in calls {
310            match self.execute_single_tool(call).await {
311                Ok(record) => records.push(record),
312                Err(e) if self.config.stop_on_error => return Err(e),
313                Err(e) => {
314                    records.push(ToolCallRecord {
315                        id: call.id.clone(),
316                        name: call.name.clone(),
317                        arguments: call.arguments.clone(),
318                        result: serde_json::json!({"error": e.to_string()}),
319                        success: false,
320                        duration_ms: 0,
321                    });
322                }
323            }
324        }
325        Ok(records)
326    }
327
328    async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
329        let start = Instant::now();
330
331        let result = timeout(
332            self.config.tool_timeout,
333            self.registry.execute(&call.name, call.arguments.clone()),
334        )
335        .await;
336
337        let duration_ms = start.elapsed().as_millis() as u64;
338
339        match result {
340            Ok(Ok(value)) => Ok(ToolCallRecord {
341                id: call.id.clone(),
342                name: call.name.clone(),
343                arguments: call.arguments.clone(),
344                result: value,
345                success: true,
346                duration_ms,
347            }),
348            Ok(Err(e)) => Ok(ToolCallRecord {
349                id: call.id.clone(),
350                name: call.name.clone(),
351                arguments: call.arguments.clone(),
352                result: serde_json::json!({"error": e.to_string()}),
353                success: false,
354                duration_ms,
355            }),
356            Err(_elapsed) => Ok(ToolCallRecord {
357                id: call.id.clone(),
358                name: call.name.clone(),
359                arguments: call.arguments.clone(),
360                result: serde_json::json!({"error": "tool execution timed out"}),
361                success: false,
362                duration_ms,
363            }),
364        }
365    }
366}
367
368// ---------------------------------------------------------------------------
369// Multi-agent task scheduling
370// ---------------------------------------------------------------------------
371
/// Subagent types accepted by the task scheduler (aligned with [`crate::tools::task::TaskTool`]).
///
/// [`validate_task_schedule`] rejects any task whose `agent_type` is not an
/// exact match for one of these strings, and the [`ScheduleError`] display
/// text lists them as the valid choices.
pub const KNOWN_AGENT_TYPES: &[&str] = &[
    "explore",
    "plan",
    "task",
    "reviewer",
    "designer",
    "librarian",
];
381
/// One schedulable unit of work for a typed subagent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTask {
    /// Task identifier; must be unique within a batch (see [`validate_task_schedule`]).
    pub id: String,
    /// Subagent type; must be one of [`KNOWN_AGENT_TYPES`].
    pub agent_type: String,
    /// The work handed to the subagent; must not be blank.
    pub assignment: String,
}
389
/// Validation / scheduling failure before any task is dispatched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScheduleError {
    /// The batch contained no tasks.
    EmptyTaskList,
    /// The task's `agent_type` is not in [`KNOWN_AGENT_TYPES`].
    ///
    /// NOTE(review): this variant is also reused in this file for a blank
    /// `assignment` and for runner failures during dispatch. In those cases
    /// the payload is a free-form message, and the `Display` text ("unknown
    /// agent type ...") is misleading — consider dedicated variants.
    InvalidAgentType(String),
    /// Two tasks in the batch share this id.
    DuplicateTaskId(String),
}
397
398impl std::fmt::Display for ScheduleError {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        match self {
401            ScheduleError::EmptyTaskList => write!(f, "task list must not be empty"),
402            ScheduleError::InvalidAgentType(agent) => write!(
403                f,
404                "unknown agent type '{agent}'. Valid types: {}",
405                KNOWN_AGENT_TYPES.join(", ")
406            ),
407            ScheduleError::DuplicateTaskId(id) => write!(f, "duplicate task id '{id}'"),
408        }
409    }
410}
411
412/// Validate a batch before dispatching. Checks non-empty list, known agent types,
413/// and unique task ids.
414pub fn validate_task_schedule(tasks: &[ScheduledTask]) -> Result<(), ScheduleError> {
415    if tasks.is_empty() {
416        return Err(ScheduleError::EmptyTaskList);
417    }
418
419    let mut seen_ids = HashSet::with_capacity(tasks.len());
420    for task in tasks {
421        if !KNOWN_AGENT_TYPES.contains(&task.agent_type.as_str()) {
422            return Err(ScheduleError::InvalidAgentType(task.agent_type.clone()));
423        }
424        if !seen_ids.insert(task.id.clone()) {
425            return Err(ScheduleError::DuplicateTaskId(task.id.clone()));
426        }
427        if task.assignment.trim().is_empty() {
428            return Err(ScheduleError::InvalidAgentType(
429                "assignment must be non-empty".into(),
430            ));
431        }
432    }
433    Ok(())
434}
435
/// Outcome of one successfully scheduled task.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScheduledTaskResult {
    /// Id of the [`ScheduledTask`] that produced this result.
    pub id: String,
    /// JSON value returned by the [`TaskRunner`] for that task.
    pub output: Value,
}
442
/// Executes validated [`ScheduledTask`] items (used by tests and future batch dispatch).
#[async_trait]
pub trait TaskRunner: Send + Sync {
    /// Run a single task and return its output as a JSON value, or an error
    /// if the task could not be completed.
    async fn run(&self, task: &ScheduledTask) -> crate::Result<Value>;
}
448
/// Coordinates validation and sequential dispatch of multi-agent task batches.
pub struct TaskScheduleCoordinator<R> {
    /// Runner that executes each task once the batch has passed validation.
    runner: Arc<R>,
}
453
454impl<R: TaskRunner> TaskScheduleCoordinator<R> {
455    pub fn new(runner: Arc<R>) -> Self {
456        Self { runner }
457    }
458
459    /// Validate `tasks`, then run each item through the configured runner.
460    pub async fn schedule(
461        &self,
462        tasks: &[ScheduledTask],
463    ) -> Result<Vec<ScheduledTaskResult>, ScheduleError> {
464        validate_task_schedule(tasks)?;
465
466        let mut results = Vec::with_capacity(tasks.len());
467        for task in tasks {
468            let output = self.runner.run(task).await.map_err(|e| {
469                ScheduleError::InvalidAgentType(format!("task '{}' failed: {e}", task.id))
470            })?;
471            results.push(ScheduledTaskResult {
472                id: task.id.clone(),
473                output,
474            });
475        }
476        Ok(results)
477    }
478}
479
480
481#[cfg(test)]
482mod tests {
483    use super::*;
484    use std::sync::Arc;
485
486    /// No tools available — model replies with plain text on the first turn.
487    /// Verifies that the coordinator terminates cleanly and returns the model
488    /// text as `content` with `FinishReason::Stop` and zero tool calls.
489    #[tokio::test]
490    async fn execute_with_empty_registry_returns_model_response() {
491        use crate::agent::backend::mock::MockBackend;
492
493        let backend = Arc::new(MockBackend::with_text("Hello, world!"));
494        let registry = Arc::new(ToolRegistry::new());
495        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
496
497        let result = coordinator
498            .execute(None, "Say hello")
499            .await
500            .expect("coordinator should not error");
501
502        assert_eq!(result.content, "Hello, world!");
503        assert_eq!(result.finish_reason, FinishReason::Stop);
504        assert_eq!(result.iterations, 1);
505        assert!(result.tool_calls.is_empty());
506        // History: [user, assistant]
507        assert_eq!(result.message_history.len(), 2);
508    }
509
510    /// Pin the `ToolCallingConfig` defaults so regressions are caught.
511    #[test]
512    fn tool_calling_config_defaults_are_sensible() {
513        use std::time::Duration;
514        let cfg = ToolCallingConfig::default();
515        assert_eq!(cfg.max_iterations, 10, "max_iterations default changed");
516        assert!(
517            cfg.parallel_execution,
518            "parallel_execution should default to true"
519        );
520        assert_eq!(
521            cfg.tool_timeout,
522            Duration::from_secs(30),
523            "tool_timeout default changed"
524        );
525        assert!(!cfg.stop_on_error, "stop_on_error should default to false");
526    }
527
528    /// The coordinator must fire `FinishReason::MaxIterations` when the model
529    /// keeps requesting tool calls and we exhaust the iteration budget.
530    /// Uses a mock backend that always returns a tool-call response for a
531    /// registered no-op tool, driving the loop to the configured cap.
532    #[tokio::test]
533    async fn coordinator_result_captures_finish_reason_max_iterations() {
534        use crate::agent::backend::mock::{MockBackend, MockResponse};
535        use crate::tools::Tool;
536        use async_trait::async_trait;
537        use serde_json::Value;
538
539        // A trivial no-op tool that always succeeds.
540        struct NoOpTool;
541
542        #[async_trait]
543        impl Tool for NoOpTool {
544            fn name(&self) -> &str {
545                "noop"
546            }
547            fn description(&self) -> &str {
548                "does nothing"
549            }
550            fn parameters_schema(&self) -> Value {
551                serde_json::json!({"type": "object", "properties": {}})
552            }
553            async fn execute(&self, _args: Value) -> crate::Result<Value> {
554                Ok(serde_json::json!({"ok": true}))
555            }
556        }
557
558        // Build a backend that always requests the noop tool (never gives a
559        // final text response), so the loop runs until max_iterations.
560        let responses: Vec<MockResponse> = (0..15)
561            .map(|_| MockResponse::tool_call("noop", serde_json::json!({})))
562            .collect();
563        let backend = Arc::new(MockBackend::new(responses));
564
565        let mut registry = ToolRegistry::new();
566        registry.register(std::sync::Arc::new(NoOpTool));
567        let registry = Arc::new(registry);
568
569        let config = ToolCallingConfig {
570            max_iterations: 3,
571            parallel_execution: false,
572            ..ToolCallingConfig::default()
573        };
574        let coordinator = ToolCoordinator::new(backend, registry, config);
575
576        let result = coordinator
577            .execute(None, "loop forever")
578            .await
579            .expect("coordinator should not hard-error");
580
581        assert_eq!(
582            result.finish_reason,
583            FinishReason::MaxIterations,
584            "expected MaxIterations, got {:?}",
585            result.finish_reason
586        );
587        assert_eq!(result.iterations, 3);
588        // Each iteration dispatches one noop tool call.
589        assert_eq!(result.tool_calls.len(), 3);
590        assert!(result.tool_calls.iter().all(|tc| tc.success));
591    }
592
593    /// When the model requests a tool that is not registered, the coordinator must
594    /// halt immediately with `FinishReason::UnknownTool` and must not execute anything.
595    #[tokio::test]
596    async fn test_unknown_tool_validation_returns_unknown_tool_finish_reason() {
597        use crate::agent::backend::mock::MockBackend;
598
599        let backend = Arc::new(MockBackend::with_tool_call(
600            "call_ghost",
601            "definitely_not_registered",
602            serde_json::json!({}),
603            "should not reach this",
604        ));
605        let registry = Arc::new(ToolRegistry::new());
606        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
607
608        let result = coordinator
609            .execute(None, "use a ghost tool")
610            .await
611            .expect("unknown tool should surface as a coordinator result, not a hard error");
612
613        assert_eq!(
614            result.finish_reason,
615            FinishReason::UnknownTool("definitely_not_registered".into())
616        );
617        assert!(
618            result.tool_calls.is_empty(),
619            "unknown tool must not be executed"
620        );
621        assert_eq!(result.iterations, 1);
622    }
623
624    /// With `stop_on_error = true`, a tool that returns `Err` must end the session
625    /// with `FinishReason::Error` rather than continuing the loop.
626    #[tokio::test]
627    async fn test_stop_on_error_halts_on_failed_tool_execution() {
628        use crate::agent::backend::mock::{MockBackend, MockResponse};
629        use crate::tools::Tool;
630        use async_trait::async_trait;
631        use serde_json::Value;
632
633        struct FailingTool;
634
635        #[async_trait]
636        impl Tool for FailingTool {
637            fn name(&self) -> &str {
638                "fail_me"
639            }
640            fn description(&self) -> &str {
641                "always fails"
642            }
643            fn parameters_schema(&self) -> Value {
644                serde_json::json!({"type": "object", "properties": {}})
645            }
646            async fn execute(&self, _args: Value) -> crate::Result<Value> {
647                Err(crate::PawanError::Tool("intentional failure".into()))
648            }
649        }
650
651        let backend = Arc::new(MockBackend::new(vec![
652            MockResponse::tool_call("fail_me", serde_json::json!({})),
653            MockResponse::text("unreachable"),
654        ]));
655
656        let mut registry = ToolRegistry::new();
657        registry.register(Arc::new(FailingTool));
658        let registry = Arc::new(registry);
659
660        let config = ToolCallingConfig {
661            stop_on_error: true,
662            parallel_execution: false,
663            ..ToolCallingConfig::default()
664        };
665        let coordinator = ToolCoordinator::new(backend, registry, config);
666
667        let result = coordinator
668            .execute(None, "trigger failure")
669            .await
670            .expect("stop_on_error should return Ok with Error finish reason");
671
672        match &result.finish_reason {
673            FinishReason::Error(msg) => {
674                assert!(
675                    msg.contains("intentional failure"),
676                    "error message should propagate from tool, got: {}",
677                    msg
678                );
679            }
680            other => panic!("expected FinishReason::Error, got {:?}", other),
681        }
682        assert_eq!(result.iterations, 1);
683    }
684
685    /// Per-tool timeout must produce a failed record rather than hanging the session.
686    #[tokio::test]
687    async fn test_tool_timeout_records_failed_tool_call() {
688        use crate::agent::backend::mock::{MockBackend, MockResponse};
689        use crate::tools::Tool;
690        use async_trait::async_trait;
691        use serde_json::Value;
692        use std::time::Duration;
693
694        struct SlowTool;
695
696        #[async_trait]
697        impl Tool for SlowTool {
698            fn name(&self) -> &str {
699                "slow_tool"
700            }
701            fn description(&self) -> &str {
702                "sleeps longer than the coordinator timeout"
703            }
704            fn parameters_schema(&self) -> Value {
705                serde_json::json!({"type": "object", "properties": {}})
706            }
707            async fn execute(&self, _args: Value) -> crate::Result<Value> {
708                tokio::time::sleep(Duration::from_secs(2)).await;
709                Ok(serde_json::json!({"ok": true}))
710            }
711        }
712
713        let backend = Arc::new(MockBackend::new(vec![
714            MockResponse::tool_call("slow_tool", serde_json::json!({})),
715            MockResponse::text("done after timeout"),
716        ]));
717
718        let mut registry = ToolRegistry::new();
719        registry.register(Arc::new(SlowTool));
720        let registry = Arc::new(registry);
721
722        let config = ToolCallingConfig {
723            tool_timeout: Duration::from_millis(50),
724            parallel_execution: false,
725            ..ToolCallingConfig::default()
726        };
727        let coordinator = ToolCoordinator::new(backend, registry, config);
728
729        let result = coordinator
730            .execute(None, "run slow tool")
731            .await
732            .expect("timeout should be absorbed into a failed tool record");
733
734        assert_eq!(result.tool_calls.len(), 1);
735        let record = &result.tool_calls[0];
736        assert!(!record.success, "timed-out tool must be marked unsuccessful");
737        assert_eq!(
738            record.result.get("error").and_then(|v| v.as_str()),
739            Some("tool execution timed out")
740        );
741        // Loop continues after a non-fatal timeout (stop_on_error defaults to false).
742        assert_eq!(result.finish_reason, FinishReason::Stop);
743        assert_eq!(result.iterations, 2);
744    }
745
746    /// `execute` with a system prompt must prepend a system message to history.
747    #[tokio::test]
748    async fn test_execute_with_system_prompt_prepends_system_message() {
749        use crate::agent::backend::mock::MockBackend;
750
751        let backend = Arc::new(MockBackend::with_text("acknowledged"));
752        let registry = Arc::new(ToolRegistry::new());
753        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
754
755        let result = coordinator
756            .execute(Some("be concise"), "hello")
757            .await
758            .expect("execute should succeed");
759
760        assert_eq!(result.message_history.len(), 3);
761        assert_eq!(result.message_history[0].role, Role::System);
762        assert_eq!(result.message_history[0].content, "be concise");
763        assert_eq!(result.message_history[1].role, Role::User);
764        assert_eq!(result.message_history[1].content, "hello");
765        assert_eq!(result.message_history[2].role, Role::Assistant);
766    }
767
768    /// Token usage reported by the backend must be captured in `total_usage`.
769    #[tokio::test]
770    async fn test_token_usage_captured_from_backend_response() {
771        use crate::agent::backend::mock::{MockBackend, MockResponse};
772        use crate::tools::Tool;
773        use async_trait::async_trait;
774        use serde_json::Value;
775
776        struct NoOpTool;
777
778        #[async_trait]
779        impl Tool for NoOpTool {
780            fn name(&self) -> &str {
781                "noop"
782            }
783            fn description(&self) -> &str {
784                "does nothing"
785            }
786            fn parameters_schema(&self) -> Value {
787                serde_json::json!({"type": "object", "properties": {}})
788            }
789            async fn execute(&self, _args: Value) -> crate::Result<Value> {
790                Ok(serde_json::json!({"ok": true}))
791            }
792        }
793
794        let backend = Arc::new(MockBackend::new(vec![
795            MockResponse::tool_call("noop", serde_json::json!({})),
796            MockResponse::TextWithUsage {
797                text: "done".into(),
798                usage: TokenUsage {
799                    prompt_tokens: 20,
800                    completion_tokens: 8,
801                    total_tokens: 28,
802                    reasoning_tokens: 3,
803                    action_tokens: 5,
804                },
805            },
806        ]));
807
808        let mut registry = ToolRegistry::new();
809        registry.register(Arc::new(NoOpTool));
810        let registry = Arc::new(registry);
811        let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
812
813        let result = coordinator
814            .execute(None, "count tokens")
815            .await
816            .expect("execute should succeed");
817
818        assert_eq!(result.total_usage.prompt_tokens, 20);
819        assert_eq!(result.total_usage.completion_tokens, 8);
820        assert_eq!(result.total_usage.total_tokens, 28);
821        assert_eq!(result.total_usage.reasoning_tokens, 3);
822        assert_eq!(result.total_usage.action_tokens, 5);
823        assert_eq!(result.iterations, 2);
824    }
825
826    /// Parallel execution must dispatch every tool call in a single assistant turn.
827    #[tokio::test]
828    async fn test_parallel_execution_dispatches_multiple_tools_in_one_turn() {
829        use crate::agent::backend::mock::MockBackend;
830        use crate::tools::Tool;
831        use async_trait::async_trait;
832        use serde_json::Value;
833
834        struct EchoTool {
835            suffix: &'static str,
836        }
837
838        #[async_trait]
839        impl Tool for EchoTool {
840            fn name(&self) -> &str {
841                self.suffix
842            }
843            fn description(&self) -> &str {
844                "echoes a suffix"
845            }
846            fn parameters_schema(&self) -> Value {
847                serde_json::json!({"type": "object", "properties": {}})
848            }
849            async fn execute(&self, _args: Value) -> crate::Result<Value> {
850                Ok(serde_json::json!({ "tool": self.suffix }))
851            }
852        }
853
854        let backend = Arc::new(MockBackend::with_multiple_tool_calls(vec![
855            ("call_a", "echo_a", serde_json::json!({})),
856            ("call_b", "echo_b", serde_json::json!({})),
857        ]));
858
859        let mut registry = ToolRegistry::new();
860        registry.register(Arc::new(EchoTool { suffix: "echo_a" }));
861        registry.register(Arc::new(EchoTool { suffix: "echo_b" }));
862        let registry = Arc::new(registry);
863
864        let config = ToolCallingConfig {
865            parallel_execution: true,
866            ..ToolCallingConfig::default()
867        };
868        let coordinator = ToolCoordinator::new(backend, registry, config);
869
870        let result = coordinator
871            .execute(None, "run both")
872            .await
873            .expect("parallel tool execution should succeed");
874
875        assert_eq!(result.tool_calls.len(), 2);
876        assert!(result.tool_calls.iter().all(|r| r.success));
877        let names: Vec<&str> = result.tool_calls.iter().map(|r| r.name.as_str()).collect();
878        assert!(names.contains(&"echo_a"));
879        assert!(names.contains(&"echo_b"));
880        assert_eq!(result.finish_reason, FinishReason::Stop);
881        assert_eq!(result.iterations, 2);
882    }
883
884    // -----------------------------------------------------------------------
885    // Task scheduling edge cases (mock runner)
886    // -----------------------------------------------------------------------
887
888    use async_trait::async_trait;
889    use serde_json::json;
890    use std::sync::Mutex;
891
892    struct MockTaskRunner {
893        dispatched: Mutex<Vec<String>>,
894    }
895
896    impl MockTaskRunner {
897        fn new() -> Self {
898            Self {
899                dispatched: Mutex::new(Vec::new()),
900            }
901        }
902
903        fn dispatched_ids(&self) -> Vec<String> {
904            self.dispatched.lock().unwrap().clone()
905        }
906    }
907
908    #[async_trait]
909    impl TaskRunner for MockTaskRunner {
910        async fn run(&self, task: &ScheduledTask) -> crate::Result<Value> {
911            self.dispatched.lock().unwrap().push(task.id.clone());
912            Ok(json!({
913                "id": task.id,
914                "agent": task.agent_type,
915                "assignment": task.assignment,
916            }))
917        }
918    }
919
920    /// Scheduling must reject an empty task list before touching the runner.
921    #[tokio::test]
922    async fn schedule_empty_task_list_rejects_without_dispatch() {
923        let runner = Arc::new(MockTaskRunner::new());
924        let coordinator = TaskScheduleCoordinator::new(runner.clone());
925
926        let err = coordinator
927            .schedule(&[])
928            .await
929            .expect_err("empty task list should fail validation");
930
931        assert_eq!(err, ScheduleError::EmptyTaskList);
932        assert!(runner.dispatched_ids().is_empty());
933    }
934
935    /// Unknown agent types must fail validation and leave the runner idle.
936    #[tokio::test]
937    async fn schedule_invalid_agent_type_rejects_without_dispatch() {
938        let runner = Arc::new(MockTaskRunner::new());
939        let coordinator = TaskScheduleCoordinator::new(runner.clone());
940
941        let tasks = [ScheduledTask {
942            id: "AuthProbe".into(),
943            agent_type: "not_a_real_agent".into(),
944            assignment: "probe auth".into(),
945        }];
946
947        let err = coordinator
948            .schedule(&tasks)
949            .await
950            .expect_err("invalid agent type should fail validation");
951
952        assert_eq!(
953            err,
954            ScheduleError::InvalidAgentType("not_a_real_agent".into())
955        );
956        assert!(runner.dispatched_ids().is_empty());
957    }
958
959    /// Duplicate task ids must be rejected before any work is dispatched.
960    #[tokio::test]
961    async fn schedule_duplicate_task_ids_rejects_without_dispatch() {
962        let runner = Arc::new(MockTaskRunner::new());
963        let coordinator = TaskScheduleCoordinator::new(runner.clone());
964
965        let tasks = [
966            ScheduledTask {
967                id: "DupId".into(),
968                agent_type: "explore".into(),
969                assignment: "first".into(),
970            },
971            ScheduledTask {
972                id: "DupId".into(),
973                agent_type: "plan".into(),
974                assignment: "second".into(),
975            },
976        ];
977
978        let err = coordinator
979            .schedule(&tasks)
980            .await
981            .expect_err("duplicate ids should fail validation");
982
983        assert_eq!(err, ScheduleError::DuplicateTaskId("DupId".into()));
984        assert!(runner.dispatched_ids().is_empty());
985    }
986
987    /// Valid tasks must be dispatched through the mock runner in order.
988    #[tokio::test]
989    async fn schedule_valid_tasks_dispatches_via_mock_runner() {
990        let runner = Arc::new(MockTaskRunner::new());
991        let coordinator = TaskScheduleCoordinator::new(runner.clone());
992
993        let tasks = [
994            ScheduledTask {
995                id: "Alpha".into(),
996                agent_type: "explore".into(),
997                assignment: "scan src/".into(),
998            },
999            ScheduledTask {
1000                id: "Beta".into(),
1001                agent_type: "plan".into(),
1002                assignment: "draft refactor".into(),
1003            },
1004        ];
1005
1006        let results = coordinator
1007            .schedule(&tasks)
1008            .await
1009            .expect("valid schedule should succeed");
1010
1011        assert_eq!(runner.dispatched_ids(), vec!["Alpha", "Beta"]);
1012        assert_eq!(results.len(), 2);
1013        assert_eq!(results[0].id, "Alpha");
1014        assert_eq!(results[0].output["agent"], "explore");
1015        assert_eq!(results[1].id, "Beta");
1016        assert_eq!(results[1].output["agent"], "plan");
1017    }
1018}