Skip to main content

pawan/coordinator/
mod.rs

1//! Multi-turn tool coordinator — data types and runtime.
2//!
3//! Provides a provider-agnostic orchestration layer for agent tool-calling
4//! loops: send a prompt with tool definitions, handle tool call requests,
5//! execute tools, feed results back, repeat until the model produces a final
6//! response or hits an iteration cap.
7//!
8//! Types reused from [`crate::agent`]:
9//! - [`ToolCallRequest`] — what the model asks for
10//! - [`ToolCallRecord`]  — what actually happened
11//! - [`TokenUsage`]      — accumulated counts
12//!
13//! Types defined here:
14//! - [`ToolCallingConfig`]   — iteration / parallelism / timeout knobs
15//! - [`FinishReason`]        — why the session ended
16//! - [`Role`]              — re-exported from `crate::agent` (system/user/assistant/tool)
17//! - [`ConversationMessage`] — a single turn in the history
18//! - [`CoordinatorResult`]   — everything the caller gets back
19//! - [`ToolCoordinator`]     — the runtime that drives the LLM+tool loop
20//!
21//! ## Design notes
22//!
23//! - [`ToolCallRecord`] is reused from [`crate::agent`] rather than duplicated.
24//!   Failed tool calls land in `result` as a `{"error": "..."}` JSON object
25//!   with `success: false`, matching pawan's existing agent loop — there's no
26//!   separate `error` field on the record.
27//! - [`ConversationMessage::tool_call_id`] is only populated on [`Role::Tool`]
28//!   turns and links the result back to the assistant message that requested it.
29
30pub mod types;
31pub use types::*;
32
33use crate::agent::backend::LlmBackend;
34use crate::agent::{Message, Role, ToolCallRecord, ToolCallRequest, ToolResultMessage, TokenUsage};
35use crate::tools::ToolRegistry;
36use futures::future::join_all;
37use std::sync::Arc;
38use std::time::Instant;
39use tokio::time::timeout;
40
41// ---------------------------------------------------------------------------
42// Type bridge: ConversationMessage → agent::Message
43// ---------------------------------------------------------------------------
44
45/// Convert a [`ConversationMessage`] to the backend's [`Message`] type.
46///
47/// The coordinator tracks history in its own `ConversationMessage` type, but
48/// `LlmBackend::generate()` expects `&[agent::Message]`. This function maps
49/// the coordinator's richer type to the backend wire format:
50///
51/// - `Tool` role messages: parse `content` back to JSON and populate
52///   `Message::tool_result` with a `ToolResultMessage`.
53/// - `Assistant` messages: copy `tool_calls` directly (same type).
54/// - `System`/`User` messages: straightforward role + content copy.
55fn to_backend_message(msg: &ConversationMessage) -> Message {
56    let tool_result = if msg.role == Role::Tool {
57        msg.tool_call_id.as_ref().map(|id| ToolResultMessage {
58            tool_call_id: id.clone(),
59            content: serde_json::from_str(&msg.content).unwrap_or(serde_json::Value::String(msg.content.clone())),
60            success: true,
61        })
62    } else {
63        None
64    };
65
66    Message {
67        role: msg.role.clone(),
68        content: msg.content.clone(),
69        tool_calls: msg.tool_calls.clone(),
70        tool_result,
71    }
72}
73
74// ---------------------------------------------------------------------------
75// ToolCoordinator runtime
76// ---------------------------------------------------------------------------
77
78/// Runtime that drives the LLM + tool-calling loop.
79///
80/// Wraps a backend and a tool registry, sends prompts with tool definitions,
81/// executes requested tools, feeds results back, and repeats until the model
82/// produces a final text response or a halt condition fires.
83///
84/// # Example
85///
86/// ```rust,ignore
87/// use pawan::coordinator::{ToolCoordinator, ToolCallingConfig};
88/// use pawan::tools::ToolRegistry;
89/// use std::sync::Arc;
90///
91/// let backend = Arc::new(my_backend);
92/// let registry = Arc::new(ToolRegistry::new());
93/// let coordinator = ToolCoordinator::new(backend, registry, ToolCallingConfig::default());
94///
95/// let result = coordinator.execute(Some("You are helpful."), "What is 2+2?").await?;
96/// println!("{}", result.content);
97/// ```
98pub struct ToolCoordinator {
99    backend: Arc<dyn LlmBackend>,
100    registry: Arc<ToolRegistry>,
101    config: ToolCallingConfig,
102}
103
104impl ToolCoordinator {
105    /// Create a new `ToolCoordinator`.
106    pub fn new(
107        backend: Arc<dyn LlmBackend>,
108        registry: Arc<ToolRegistry>,
109        config: ToolCallingConfig,
110    ) -> Self {
111        Self { backend, registry, config }
112    }
113
114    /// Execute a tool-calling session starting from a plain prompt.
115    ///
116    /// Builds an initial `[system?, user]` message list and drives the loop.
117    pub async fn execute(
118        &self,
119        system_prompt: Option<&str>,
120        user_prompt: &str,
121    ) -> crate::Result<CoordinatorResult> {
122        let mut messages: Vec<ConversationMessage> = Vec::new();
123        if let Some(sys) = system_prompt {
124            messages.push(ConversationMessage::system(sys));
125        }
126        messages.push(ConversationMessage::user(user_prompt));
127        self.execute_with_history(messages).await
128    }
129
130    /// Execute a tool-calling session from an existing message history.
131    ///
132    /// This is the primary loop: it calls the backend, dispatches tool calls,
133    /// appends results to history, and repeats until the model emits a final
134    /// text response or a halt condition fires.
135    pub async fn execute_with_history(
136        &self,
137        mut messages: Vec<ConversationMessage>,
138    ) -> crate::Result<CoordinatorResult> {
139        let tool_defs = self.registry.get_definitions();
140        let mut all_tool_calls: Vec<ToolCallRecord> = Vec::new();
141        let mut total_usage = TokenUsage::default();
142
143        for iteration in 0..self.config.max_iterations {
144            // Convert coordinator messages to backend wire format.
145            let backend_messages: Vec<Message> =
146                messages.iter().map(to_backend_message).collect();
147
148            // Call backend — no streaming callback needed for coordinator.
149            let response = self
150                .backend
151                .generate(&backend_messages, &tool_defs, None)
152                .await?;
153
154            // Accumulate token usage.
155            if let Some(usage) = &response.usage {
156                total_usage.prompt_tokens += usage.prompt_tokens;
157                total_usage.completion_tokens += usage.completion_tokens;
158                total_usage.total_tokens += usage.total_tokens;
159                total_usage.reasoning_tokens += usage.reasoning_tokens;
160                total_usage.action_tokens += usage.action_tokens;
161            }
162
163            // Append the assistant turn to history.
164            messages.push(ConversationMessage::assistant(
165                &response.content,
166                response.tool_calls.clone(),
167            ));
168
169            // No tool calls → model is done.
170            if response.tool_calls.is_empty() {
171                return Ok(CoordinatorResult {
172                    content: response.content,
173                    tool_calls: all_tool_calls,
174                    iterations: iteration + 1,
175                    finish_reason: FinishReason::Stop,
176                    total_usage,
177                    message_history: messages,
178                });
179            }
180
181            // Empty response with tool calls is unusual but guard it.
182            if response.content.is_empty() && response.tool_calls.is_empty() {
183                return Ok(CoordinatorResult {
184                    content: String::new(),
185                    tool_calls: all_tool_calls,
186                    iterations: iteration + 1,
187                    finish_reason: FinishReason::Stop,
188                    total_usage,
189                    message_history: messages,
190                });
191            }
192
193            // Validate all requested tools exist before executing any.
194            for tc in &response.tool_calls {
195                if !self.registry.has_tool(&tc.name) {
196                    return Ok(CoordinatorResult {
197                        content: response.content,
198                        tool_calls: all_tool_calls,
199                        iterations: iteration + 1,
200                        finish_reason: FinishReason::UnknownTool(tc.name.clone()),
201                        total_usage,
202                        message_history: messages,
203                    });
204                }
205            }
206
207            // Execute tool calls (parallel or sequential per config).
208            let records = self.execute_tool_calls(&response.tool_calls).await?;
209
210            // If stop_on_error, check if any record failed.
211            if self.config.stop_on_error {
212                if let Some(failed) = records.iter().find(|r| !r.success) {
213                    let err_msg = failed
214                        .result
215                        .get("error")
216                        .and_then(|v| v.as_str())
217                        .unwrap_or("tool error")
218                        .to_string();
219                    return Ok(CoordinatorResult {
220                        content: response.content,
221                        tool_calls: all_tool_calls,
222                        iterations: iteration + 1,
223                        finish_reason: FinishReason::Error(err_msg),
224                        total_usage,
225                        message_history: messages,
226                    });
227                }
228            }
229
230            // Append tool result messages and accumulate records.
231            for record in records {
232                messages.push(ConversationMessage::tool_result(&record.id, &record.result));
233                all_tool_calls.push(record);
234            }
235        }
236
237        // Hit max iterations.
238        Ok(CoordinatorResult {
239            content: messages
240                .last()
241                .map(|m| m.content.clone())
242                .unwrap_or_default(),
243            tool_calls: all_tool_calls,
244            iterations: self.config.max_iterations,
245            finish_reason: FinishReason::MaxIterations,
246            total_usage,
247            message_history: messages,
248        })
249    }
250
251    // -----------------------------------------------------------------------
252    // Internal helpers
253    // -----------------------------------------------------------------------
254
255    async fn execute_tool_calls(
256        &self,
257        calls: &[ToolCallRequest],
258    ) -> crate::Result<Vec<ToolCallRecord>> {
259        if self.config.parallel_execution {
260            self.execute_parallel(calls).await
261        } else {
262            self.execute_sequential(calls).await
263        }
264    }
265
266    async fn execute_parallel(
267        &self,
268        calls: &[ToolCallRequest],
269    ) -> crate::Result<Vec<ToolCallRecord>> {
270        let futures = calls.iter().map(|c| self.execute_single_tool(c));
271        let results = join_all(futures).await;
272
273        let mut records = Vec::with_capacity(results.len());
274        for (i, res) in results.into_iter().enumerate() {
275            match res {
276                Ok(record) => records.push(record),
277                Err(e) if self.config.stop_on_error => return Err(e),
278                Err(e) => {
279                    // Recover: turn the error into a failed ToolCallRecord.
280                    let call = &calls[i];
281                    records.push(ToolCallRecord {
282                        id: call.id.clone(),
283                        name: call.name.clone(),
284                        arguments: call.arguments.clone(),
285                        result: serde_json::json!({"error": e.to_string()}),
286                        success: false,
287                        duration_ms: 0,
288                    });
289                }
290            }
291        }
292        Ok(records)
293    }
294
295    async fn execute_sequential(
296        &self,
297        calls: &[ToolCallRequest],
298    ) -> crate::Result<Vec<ToolCallRecord>> {
299        let mut records = Vec::with_capacity(calls.len());
300        for call in calls {
301            match self.execute_single_tool(call).await {
302                Ok(record) => records.push(record),
303                Err(e) if self.config.stop_on_error => return Err(e),
304                Err(e) => {
305                    records.push(ToolCallRecord {
306                        id: call.id.clone(),
307                        name: call.name.clone(),
308                        arguments: call.arguments.clone(),
309                        result: serde_json::json!({"error": e.to_string()}),
310                        success: false,
311                        duration_ms: 0,
312                    });
313                }
314            }
315        }
316        Ok(records)
317    }
318
319    async fn execute_single_tool(&self, call: &ToolCallRequest) -> crate::Result<ToolCallRecord> {
320        let start = Instant::now();
321
322        let result = timeout(
323            self.config.tool_timeout,
324            self.registry.execute(&call.name, call.arguments.clone()),
325        )
326        .await;
327
328        let duration_ms = start.elapsed().as_millis() as u64;
329
330        match result {
331            Ok(Ok(value)) => Ok(ToolCallRecord {
332                id: call.id.clone(),
333                name: call.name.clone(),
334                arguments: call.arguments.clone(),
335                result: value,
336                success: true,
337                duration_ms,
338            }),
339            Ok(Err(e)) => Ok(ToolCallRecord {
340                id: call.id.clone(),
341                name: call.name.clone(),
342                arguments: call.arguments.clone(),
343                result: serde_json::json!({"error": e.to_string()}),
344                success: false,
345                duration_ms,
346            }),
347            Err(_elapsed) => Ok(ToolCallRecord {
348                id: call.id.clone(),
349                name: call.name.clone(),
350                arguments: call.arguments.clone(),
351                result: serde_json::json!({"error": "tool execution timed out"}),
352                success: false,
353                duration_ms,
354            }),
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// With no tools registered, a single plain-text reply must end the
    /// session on the first turn. Expected result: `Stop`, one iteration, no
    /// tool calls, and the model text returned as `content`.
    #[tokio::test]
    async fn execute_with_empty_registry_returns_model_response() {
        use crate::agent::backend::mock::MockBackend;

        let coordinator = ToolCoordinator::new(
            Arc::new(MockBackend::with_text("Hello, world!")),
            Arc::new(ToolRegistry::new()),
            ToolCallingConfig::default(),
        );

        let outcome = coordinator
            .execute(None, "Say hello")
            .await
            .expect("coordinator should not error");

        assert_eq!(outcome.content, "Hello, world!");
        assert_eq!(outcome.finish_reason, FinishReason::Stop);
        assert_eq!(outcome.iterations, 1);
        assert!(outcome.tool_calls.is_empty());
        // No system prompt was supplied, so the history is [user, assistant].
        assert_eq!(outcome.message_history.len(), 2);
    }

    /// Lock in the `ToolCallingConfig` defaults so an accidental change is caught.
    #[test]
    fn tool_calling_config_defaults_are_sensible() {
        use std::time::Duration;

        let defaults = ToolCallingConfig::default();

        assert_eq!(defaults.max_iterations, 10, "max_iterations default changed");
        assert!(defaults.parallel_execution, "parallel_execution should default to true");
        assert_eq!(
            defaults.tool_timeout,
            Duration::from_secs(30),
            "tool_timeout default changed"
        );
        assert!(!defaults.stop_on_error, "stop_on_error should default to false");
    }

    /// A model that never stops requesting a registered no-op tool must drive
    /// the loop into `FinishReason::MaxIterations`. One successful call should
    /// be recorded per iteration.
    #[tokio::test]
    async fn coordinator_result_captures_finish_reason_max_iterations() {
        use crate::agent::backend::mock::{MockBackend, MockResponse};
        use crate::tools::Tool;
        use async_trait::async_trait;
        use serde_json::Value;

        /// Tool that ignores its arguments and always succeeds.
        struct NoOpTool;

        #[async_trait]
        impl Tool for NoOpTool {
            fn name(&self) -> &str {
                "noop"
            }

            fn description(&self) -> &str {
                "does nothing"
            }

            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {}})
            }

            async fn execute(&self, _args: Value) -> crate::Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        // Script more tool-call turns (15) than the cap (3), so the model
        // never reaches a final answer before the budget runs out.
        let scripted: Vec<MockResponse> =
            std::iter::repeat_with(|| MockResponse::tool_call("noop", serde_json::json!({})))
                .take(15)
                .collect();

        let mut tools = ToolRegistry::new();
        tools.register(Arc::new(NoOpTool));

        let coordinator = ToolCoordinator::new(
            Arc::new(MockBackend::new(scripted)),
            Arc::new(tools),
            ToolCallingConfig {
                max_iterations: 3,
                parallel_execution: false,
                ..ToolCallingConfig::default()
            },
        );

        let outcome = coordinator
            .execute(None, "loop forever")
            .await
            .expect("coordinator should not hard-error");

        assert_eq!(
            outcome.finish_reason,
            FinishReason::MaxIterations,
            "expected MaxIterations, got {:?}",
            outcome.finish_reason
        );
        assert_eq!(outcome.iterations, 3);
        // Exactly one noop dispatch per iteration, all successful.
        assert_eq!(outcome.tool_calls.len(), 3);
        assert!(outcome.tool_calls.iter().all(|record| record.success));
    }
}