Skip to main content

koda_core/
tool_dispatch.rs

1//! Tool execution dispatch — sequential, parallel, and split-batch.
2//!
3//! Routes tool calls from the inference loop to execution, handling
4//! approval flow, parallelization, and result recording.
5//!
6//! ## Dispatch flow
7//!
8//! ```text
9//! Model emits tool calls
10//!   → Classify each call's effect (ReadOnly / LocalMutation / Destructive)
11//!   → Split into read-only batch + mutation batch
//!   → Read-only tools: execute in parallel (futures_util::future::join_all)
13//!   → Mutation tools: execute sequentially with approval
14//!   → Record results in DB + inject into conversation
15//! ```
16//!
17//! ## Related modules
18//!
19//! - [`crate::tools`] — tool definitions and `ToolRegistry::execute()`
20//! - [`crate::approval`] — approval mode and effect classification
21//! - `sub_agent_dispatch.rs` — `InvokeAgent` handling (needs provider access)
22//! - `approval_flow.rs` — interactive approval UI flow
23//!
24//! ## Design (DESIGN.md)
25//!
26//! - **Tool Dispatch: Match Statement (P2)**: Tools are dispatched via a
27//!   `match` in `ToolRegistry::execute()`, not a `HashMap<String, Box<dyn Tool>>`.
28//!   Rust's exhaustive matching catches missing handlers at compile time.
29
30use crate::approval::{self, ApprovalMode, ToolApproval};
31use crate::approval_flow::{handle_ask_user, request_approval};
32use crate::config::KodaConfig;
33use crate::db::{Database, Role};
34use crate::engine::{ApprovalDecision, EngineCommand, EngineEvent};
35use crate::file_tracker::FileTracker;
36use crate::persistence::Persistence;
37use crate::preview;
38use crate::providers::ToolCall;
39use crate::sub_agent_cache::SubAgentCache;
40use crate::sub_agent_dispatch;
41use crate::tools;
42
43use anyhow::Result;
44use std::path::{Path, PathBuf};
45use tokio::sync::mpsc;
46use tokio_util::sync::CancellationToken;
47
48/// Post-execution recording: emit result event, persist to DB, track progress
49/// and file lifecycle. Called after every successful tool execution regardless
50/// of execution strategy (parallel, split-batch, or sequential).
51#[allow(clippy::too_many_arguments)]
52pub(crate) async fn record_tool_result(
53    tc: &ToolCall,
54    result: &str,
55    success: bool,
56    full_output: Option<&str>,
57    db: &Database,
58    session_id: &str,
59    max_result_chars: usize,
60    project_root: &Path,
61    file_tracker: &mut FileTracker,
62    sink: &dyn crate::engine::EngineSink,
63) -> Result<()> {
64    sink.emit(EngineEvent::ToolCallResult {
65        id: tc.id.clone(),
66        name: tc.function_name.clone(),
67        output: result.to_string(),
68    });
69
70    // If we have separate full output (Bash smart summary), use the dedicated
71    // two-column insert so the model sees the summary while RecallContext can
72    // search the full output.
73    if let Some(full) = full_output {
74        db.insert_tool_message_with_full(session_id, result, &tc.id, full)
75            .await?;
76    } else {
77        let stored = truncate_for_history(result, max_result_chars);
78        db.insert_message(
79            session_id,
80            &Role::Tool,
81            Some(&stored),
82            None,
83            Some(&tc.id),
84            None,
85        )
86        .await?;
87    }
88    crate::progress::track_progress(db, session_id, &tc.function_name, &tc.arguments, result).await;
89    let parsed_args: serde_json::Value = serde_json::from_str(&tc.arguments).unwrap_or_default();
90    track_file_lifecycle(
91        &tc.function_name,
92        &parsed_args,
93        project_root,
94        file_tracker,
95        success,
96    )
97    .await;
98    Ok(())
99}
100
/// Truncate a tool result for storage in conversation history.
///
/// `max_chars` is set by `OutputCaps::tool_result_chars` and is applied as a
/// budget on the UTF-8 byte length of `output`. Results that fit are returned
/// unchanged. Longer results are cut at the last char boundary at or before
/// `max_chars` bytes and followed by a notice telling the model how many
/// characters were dropped.
///
/// The count in the notice is the number of characters actually removed, not
/// the number of bytes, so it stays accurate for multi-byte text.
fn truncate_for_history(output: &str, max_chars: usize) -> String {
    if output.len() <= max_chars {
        return output.to_string();
    }
    // `max_chars` may fall inside a multi-byte code point; back off to the
    // nearest boundary so the split below cannot panic.
    let mut end = max_chars;
    while end > 0 && !output.is_char_boundary(end) {
        end -= 1;
    }
    let (kept, dropped) = output.split_at(end);
    format!(
        "{kept}\n\n[...truncated {} chars. Re-read the file if you need the full content.]",
        dropped.chars().count()
    )
}
118
119/// Resolve the file path from a tool call's arguments.
120///
121/// Used by the file lifecycle tracker to record which paths
122/// Koda creates or deletes (#465). Only relevant for Write and Delete.
123fn resolve_tool_path(
124    tool_name: &str,
125    args: &serde_json::Value,
126    project_root: &Path,
127) -> Option<PathBuf> {
128    if !matches!(tool_name, "Write" | "Delete") {
129        return None;
130    }
131    crate::file_tracker::resolve_file_path_from_args(args, project_root)
132}
133
134/// Update file lifecycle tracker after a tool execution (#465).
135///
136/// - Write → track as owned (Koda created it)
137/// - Delete → untrack (file no longer exists)
138///
139/// Only tracks when `success` is true, using the structured boolean
140/// from `ToolResult` rather than fragile string-prefix matching (#476).
141async fn track_file_lifecycle(
142    tool_name: &str,
143    args: &serde_json::Value,
144    project_root: &Path,
145    file_tracker: &mut FileTracker,
146    success: bool,
147) {
148    if !success {
149        return;
150    }
151    if let Some(path) = resolve_tool_path(tool_name, args, project_root) {
152        match tool_name {
153            "Write" => file_tracker.track_created(path).await,
154            "Delete" => file_tracker.untrack(&path).await,
155            _ => {}
156        }
157    }
158}
159
160pub(crate) fn can_parallelize(
161    tool_calls: &[ToolCall],
162    mode: ApprovalMode,
163    project_root: &Path,
164) -> bool {
165    let all_approved = !tool_calls.iter().any(|tc| {
166        let args: serde_json::Value = serde_json::from_str(&tc.arguments).unwrap_or_default();
167        matches!(
168            approval::check_tool(&tc.function_name, &args, mode, Some(project_root)),
169            ToolApproval::NeedsConfirmation | ToolApproval::Blocked
170        )
171    });
172
173    if !all_approved {
174        return false;
175    }
176
177    let mut seen = std::collections::HashSet::new();
178    let has_conflict = tool_calls.iter().any(|tc| {
179        if !crate::tools::is_mutating_tool(&tc.function_name) {
180            return false;
181        }
182        let args: serde_json::Value = serde_json::from_str(&tc.arguments).unwrap_or_default();
183        if let Some(path) = crate::undo::extract_file_path(&tc.function_name, &args) {
184            // If the path is already in the set, we have a conflict
185            !seen.insert(path)
186        } else {
187            false
188        }
189    });
190
191    !has_conflict
192}
193
/// Execute a single tool call and return
/// `(tool_call_id, result_output, success, full_output)`.
///
/// - `InvokeAgent` is routed to `sub_agent_dispatch::execute_sub_agent`; an
///   `Err` from it is turned into a failed result whose output starts with
///   `"Error invoking sub-agent: "`. Sub-agent results never carry a
///   `full_output`.
/// - Every other tool goes through `tools.execute`. Only `Bash` is handed the
///   `(sink, tool_call_id)` pair as its streaming argument.
///
/// This function performs no approval check or validation itself; `mode` is
/// only forwarded to the sub-agent. It never returns an error — failures are
/// reported through the `success` flag.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn execute_one_tool(
    tc: &ToolCall,
    project_root: &Path,
    config: &KodaConfig,
    db: &Database,
    // NOTE: despite the leading underscore this parameter is used — it is
    // forwarded to the sub-agent below.
    _session_id: &str,
    tools: &crate::tools::ToolRegistry,
    mode: ApprovalMode,
    sink: &dyn crate::engine::EngineSink,
    cancel: CancellationToken,
    sub_agent_cache: &SubAgentCache,
    bg_agents: &std::sync::Arc<crate::bg_agent::BgAgentRegistry>,
) -> (String, String, bool, Option<String>) {
    let (result, success, full_output) = if tc.function_name == "InvokeAgent" {
        // Sub-agents inherit the parent's approval mode.
        match sub_agent_dispatch::execute_sub_agent(
            project_root,
            config,
            db,
            &tc.arguments,
            mode,
            sink,
            cancel.clone(),
            // Sub-agents get a fresh command channel (they auto-approve in all modes)
            // NOTE(review): the sender half belongs to the same temporary
            // tuple as the receiver borrowed here, so it presumably stays
            // alive until this `match` finishes. Rebinding the channel to
            // named locals could change when the sender is dropped — confirm
            // before restructuring.
            &mut mpsc::channel(1).1,
            Some(tools.file_read_cache()),
            sub_agent_cache,
            _session_id,
            bg_agents,
        )
        .await
        {
            Ok(output) => (output, true, None),
            Err(e) => (format!("Error invoking sub-agent: {e}"), false, None),
        }
    } else {
        // Invalidate sub-agent cache on file mutations
        // (done before the tool runs, whether or not it later succeeds).
        if crate::tools::is_mutating_tool(&tc.function_name) {
            sub_agent_cache.invalidate();
        }
        // Only Bash receives the sink + call id for streaming.
        let streaming = if tc.function_name == "Bash" {
            Some((sink, tc.id.as_str()))
        } else {
            None
        };
        let r = tools
            .execute(&tc.function_name, &tc.arguments, streaming)
            .await;
        (r.output, r.success, r.full_output)
    };

    (tc.id.clone(), result, success, full_output)
}
249
250/// Run multiple tool calls concurrently and store results.
251#[allow(clippy::too_many_arguments)]
252pub(crate) async fn execute_tools_parallel(
253    tool_calls: &[ToolCall],
254    project_root: &Path,
255    config: &KodaConfig,
256    db: &Database,
257    session_id: &str,
258    tools: &crate::tools::ToolRegistry,
259    mode: ApprovalMode,
260    sink: &dyn crate::engine::EngineSink,
261    cancel: CancellationToken,
262    sub_agent_cache: &SubAgentCache,
263    file_tracker: &mut FileTracker,
264    bg_agents: &std::sync::Arc<crate::bg_agent::BgAgentRegistry>,
265) -> Result<()> {
266    let count = tool_calls.len();
267    sink.emit(EngineEvent::Info {
268        message: format!("Running {count} tools in parallel..."),
269    });
270
271    // Launch all tool calls concurrently
272    let futures: Vec<_> = tool_calls
273        .iter()
274        .map(|tc| {
275            execute_one_tool(
276                tc,
277                project_root,
278                config,
279                db,
280                session_id,
281                tools,
282                mode,
283                sink,
284                cancel.clone(),
285                sub_agent_cache,
286                bg_agents,
287            )
288        })
289        .collect();
290    let results = futures_util::future::join_all(futures).await;
291
292    // Emit banner + result together so each tool's output is visually grouped
293    for (i, (tc_id, result, success, full_output)) in results.into_iter().enumerate() {
294        sink.emit(EngineEvent::ToolCallStart {
295            id: tc_id.clone(),
296            name: tool_calls[i].function_name.clone(),
297            args: serde_json::from_str(&tool_calls[i].arguments).unwrap_or_default(),
298            is_sub_agent: false,
299        });
300        record_tool_result(
301            &tool_calls[i],
302            &result,
303            success,
304            full_output.as_deref(),
305            db,
306            session_id,
307            tools.caps.tool_result_chars,
308            project_root,
309            file_tracker,
310            sink,
311        )
312        .await?;
313    }
314    Ok(())
315}
316
317/// Split a mixed batch: run parallelizable tools concurrently, then
318/// execute remaining tools sequentially.
319///
320/// This is the key optimization for mixed batches like
321/// `[InvokeAgent, InvokeAgent, Write]` — the two sub-agents run in
322/// parallel while the Write waits for confirmation.
323#[allow(clippy::too_many_arguments)]
324pub(crate) async fn execute_tools_split_batch(
325    tool_calls: &[ToolCall],
326    project_root: &Path,
327    config: &KodaConfig,
328    db: &Database,
329    session_id: &str,
330    tools: &crate::tools::ToolRegistry,
331    mode: ApprovalMode,
332    sink: &dyn crate::engine::EngineSink,
333    cancel: CancellationToken,
334    cmd_rx: &mut mpsc::Receiver<EngineCommand>,
335    sub_agent_cache: &SubAgentCache,
336    file_tracker: &mut FileTracker,
337    bg_agents: &std::sync::Arc<crate::bg_agent::BgAgentRegistry>,
338) -> Result<()> {
339    // Partition into parallelizable vs sequential
340    let (parallel, sequential): (Vec<_>, Vec<_>) = tool_calls.iter().partition(|tc| {
341        let args: serde_json::Value = serde_json::from_str(&tc.arguments).unwrap_or_default();
342        matches!(
343            approval::check_tool(&tc.function_name, &args, mode, Some(project_root),),
344            ToolApproval::AutoApprove
345        )
346    });
347
348    // Run parallelizable tools concurrently (if more than one)
349    if parallel.len() > 1 {
350        sink.emit(EngineEvent::Info {
351            message: format!("Running {} tools in parallel...", parallel.len()),
352        });
353
354        let futures: Vec<_> = parallel
355            .iter()
356            .map(|tc| {
357                execute_one_tool(
358                    tc,
359                    project_root,
360                    config,
361                    db,
362                    session_id,
363                    tools,
364                    mode,
365                    sink,
366                    cancel.clone(),
367                    sub_agent_cache,
368                    bg_agents,
369                )
370            })
371            .collect();
372        let results = futures_util::future::join_all(futures).await;
373
374        for (j, (tc_id, result, success, full_output)) in results.into_iter().enumerate() {
375            sink.emit(EngineEvent::ToolCallStart {
376                id: tc_id.clone(),
377                name: parallel[j].function_name.clone(),
378                args: serde_json::from_str(&parallel[j].arguments).unwrap_or_default(),
379                is_sub_agent: false,
380            });
381            record_tool_result(
382                parallel[j],
383                &result,
384                success,
385                full_output.as_deref(),
386                db,
387                session_id,
388                tools.caps.tool_result_chars,
389                project_root,
390                file_tracker,
391                sink,
392            )
393            .await?;
394        }
395    } else {
396        // 0–1 parallelizable tools — just run sequentially
397        for tc in &parallel {
398            let calls = std::slice::from_ref(*tc);
399            execute_tools_sequential(
400                calls,
401                project_root,
402                config,
403                db,
404                session_id,
405                tools,
406                mode,
407                sink,
408                cancel.clone(),
409                cmd_rx,
410                sub_agent_cache,
411                file_tracker,
412                bg_agents,
413            )
414            .await?;
415        }
416    }
417
418    // Run non-parallelizable tools sequentially
419    if !sequential.is_empty() {
420        let seq_calls: Vec<ToolCall> = sequential.into_iter().cloned().collect();
421        execute_tools_sequential(
422            &seq_calls,
423            project_root,
424            config,
425            db,
426            session_id,
427            tools,
428            mode,
429            sink,
430            cancel.clone(),
431            cmd_rx,
432            sub_agent_cache,
433            file_tracker,
434            bg_agents,
435        )
436        .await?;
437    }
438
439    Ok(())
440}
441
/// Run tool calls one at a time (when confirmation is needed, or single call).
///
/// For each call, in order:
///
/// 1. Stop with `Ok(())` (after a `Warn` event) if `cancel` has fired.
/// 2. Emit `ToolCallStart`.
/// 3. `AskUser` is answered interactively through `handle_ask_user` and
///    recorded as a successful result; it never reaches `execute_one_tool`.
/// 4. Pre-flight validation: a validation error is recorded as a failed
///    result and the call is skipped.
/// 5. Approval via `approval::check_tool_with_tracker`:
///    - `AutoApprove` — run immediately;
///    - `Blocked` — emit `ActionBlocked`, store a refusal tool message, skip;
///    - `NeedsConfirmation` — ask through `request_approval`; a rejection
///      (with or without feedback) is stored as a tool message and the call
///      is skipped.
/// 6. Execute through `execute_one_tool` and record with `record_tool_result`.
///
/// If the user interaction in step 3 or 5 is cancelled (the helper returns
/// `None`), the function returns `Ok(())` at once: remaining calls are not
/// run and no tool message is stored for the call that was in flight.
///
/// The blocked and rejected paths write to the database directly and do not
/// emit a `ToolCallResult` event.
///
/// # Errors
///
/// Propagates errors from the database inserts and from `record_tool_result`.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn execute_tools_sequential(
    tool_calls: &[ToolCall],
    project_root: &Path,
    config: &KodaConfig,
    db: &Database,
    session_id: &str,
    tools: &crate::tools::ToolRegistry,
    mode: ApprovalMode,
    sink: &dyn crate::engine::EngineSink,
    cancel: CancellationToken,
    cmd_rx: &mut mpsc::Receiver<EngineCommand>,
    sub_agent_cache: &SubAgentCache,
    file_tracker: &mut FileTracker,
    bg_agents: &std::sync::Arc<crate::bg_agent::BgAgentRegistry>,
) -> Result<()> {
    for tc in tool_calls {
        // Check for interrupt before each tool
        if cancel.is_cancelled() {
            sink.emit(EngineEvent::Warn {
                message: "Interrupted".into(),
            });
            return Ok(());
        }

        // Malformed argument JSON falls back to the default JSON value here;
        // the raw string is still what `execute_one_tool` receives later.
        let parsed_args: serde_json::Value =
            serde_json::from_str(&tc.arguments).unwrap_or_default();

        sink.emit(EngineEvent::ToolCallStart {
            id: tc.id.clone(),
            name: tc.function_name.clone(),
            args: parsed_args.clone(),
            is_sub_agent: false,
        });

        // AskUser: pause inference, show question in TUI, wait for typed answer.
        // Handled here (not in execute_one_tool) because it needs sink + cmd_rx.
        if tc.function_name == "AskUser" {
            let answer = handle_ask_user(sink, cmd_rx, &cancel, &parsed_args).await;
            let result = match answer {
                Some(text) if !text.trim().is_empty() => text,
                // Blank or whitespace-only answer: record a placeholder.
                Some(_) => "User did not provide an answer.".into(),
                None => return Ok(()), // cancelled
            };
            record_tool_result(
                tc,
                &result,
                true,
                None, // AskUser has no full_output
                db,
                session_id,
                tools.caps.tool_result_chars,
                project_root,
                file_tracker,
                sink,
            )
            .await?;
            continue;
        }

        // Pre-flight validation: catch errors before bothering the user
        // with an approval prompt that will inevitably fail.
        if let Some(error) = {
            let cache = tools.file_read_cache();
            tools::validate::validate_tool_call(
                &tc.function_name,
                &parsed_args,
                project_root,
                Some(&cache),
            )
            .await
        } {
            // Recorded as a failed result (success = false), then skipped.
            record_tool_result(
                tc,
                &format!("Validation error: {error}"),
                false,
                None,
                db,
                session_id,
                tools.caps.tool_result_chars,
                project_root,
                file_tracker,
                sink,
            )
            .await?;
            continue;
        }

        // Check approval for this tool call (with file ownership awareness, #465)
        let approval = approval::check_tool_with_tracker(
            &tc.function_name,
            &parsed_args,
            mode,
            Some(project_root),
            Some(file_tracker),
        );

        match approval {
            ToolApproval::AutoApprove => {
                // Execute without asking
            }
            ToolApproval::Blocked => {
                // Plan mode: emit ActionBlocked event, let the client render it
                // NOTE(review): this comment says "plan mode" while the stored
                // message below says "[safe mode]" — confirm which name is current.
                let detail = tools::describe_action(&tc.function_name, &parsed_args);
                let diff_preview =
                    preview::compute(&tc.function_name, &parsed_args, project_root).await;
                sink.emit(EngineEvent::ActionBlocked {
                    tool_name: tc.function_name.clone(),
                    detail: detail.clone(),
                    preview: diff_preview,
                });
                // Stored as the tool's reply so the model sees why nothing ran.
                db.insert_message(
                    session_id,
                    &Role::Tool,
                    Some("[safe mode] Action blocked. You are in read-only mode. DO NOT retry this command. Describe what you would do instead. The user must press Shift+Tab to switch to auto or strict mode."),
                    None,
                    Some(&tc.id),
                    None,
                )
                .await?;
                continue;
            }
            ToolApproval::NeedsConfirmation => {
                let detail = tools::describe_action(&tc.function_name, &parsed_args);
                let diff_preview =
                    preview::compute(&tc.function_name, &parsed_args, project_root).await;
                let effect = crate::approval::resolve_tool_effect(&tc.function_name, &parsed_args);

                match request_approval(
                    sink,
                    cmd_rx,
                    &cancel,
                    &tc.function_name,
                    &detail,
                    diff_preview,
                    effect,
                )
                .await
                {
                    // Approved: fall through to execution below.
                    Some(ApprovalDecision::Approve) => {}
                    Some(ApprovalDecision::Reject) => {
                        db.insert_message(
                            session_id,
                            &Role::Tool,
                            Some("User rejected this action."),
                            None,
                            Some(&tc.id),
                            None,
                        )
                        .await?;
                        continue;
                    }
                    Some(ApprovalDecision::RejectWithFeedback { feedback }) => {
                        // The user's feedback is passed back to the model verbatim.
                        let result = format!("User rejected this action with feedback: {feedback}");
                        db.insert_message(
                            session_id,
                            &Role::Tool,
                            Some(&result),
                            None,
                            Some(&tc.id),
                            None,
                        )
                        .await?;
                        continue;
                    }
                    None => {
                        // Cancelled
                        return Ok(());
                    }
                }
            }
        }

        // Reached only for auto-approved or explicitly approved calls.
        let (_, result, success, full_output) = execute_one_tool(
            tc,
            project_root,
            config,
            db,
            session_id,
            tools,
            mode,
            sink,
            cancel.clone(),
            sub_agent_cache,
            bg_agents,
        )
        .await;
        record_tool_result(
            tc,
            &result,
            success,
            full_output.as_deref(),
            db,
            session_id,
            tools.caps.tool_result_chars,
            project_root,
            file_tracker,
            sink,
        )
        .await?;
    }
    Ok(())
}
646
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::ToolCall;

    /// Assemble a `ToolCall` from an id, a tool name and a raw JSON argument string.
    fn tool_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            function_name: name.to_string(),
            arguments: arguments.to_string(),
            thought_signature: None,
        }
    }

    /// Shorthand for a call to `name` with id `t1` and an empty argument object.
    fn make_tool_call(name: &str) -> ToolCall {
        tool_call("t1", name, "{}")
    }

    /// Ask `can_parallelize` about `calls` under `mode`, using the fixed test project root.
    fn parallelizable(calls: &[ToolCall], mode: ApprovalMode) -> bool {
        can_parallelize(calls, mode, Path::new("/test/project"))
    }

    #[test]
    fn test_can_parallelize_read_only() {
        let calls = [make_tool_call("Read"), make_tool_call("Grep")];
        assert!(parallelizable(&calls, ApprovalMode::Confirm));
    }

    #[test]
    fn test_cannot_parallelize_writes() {
        let calls = [make_tool_call("Read"), make_tool_call("Write")];
        assert!(!parallelizable(&calls, ApprovalMode::Confirm));
    }

    #[test]
    fn test_cannot_parallelize_bash() {
        // Dangerous bash command should prevent parallelization
        let calls = [
            make_tool_call("Read"),
            tool_call("t2", "Bash", r#"{"command": "rm -rf /tmp/test"}"#),
        ];
        assert!(!parallelizable(&calls, ApprovalMode::Confirm));
    }

    #[test]
    fn test_can_parallelize_agents() {
        let calls = [make_tool_call("InvokeAgent"), make_tool_call("InvokeAgent")];
        assert!(parallelizable(&calls, ApprovalMode::Confirm));
    }

    #[test]
    fn test_cannot_parallelize_same_file_edits() {
        let calls = [
            tool_call("t1", "Edit", r#"{"file_path": "src/main.rs"}"#),
            tool_call("t2", "Edit", r#"{"file_path": "src/main.rs"}"#),
        ];
        // Auto mode would normally allow parallelization
        assert!(!parallelizable(&calls, ApprovalMode::Auto));
    }

    #[test]
    fn test_can_parallelize_different_file_edits() {
        let calls = [
            tool_call("t1", "Edit", r#"{"file_path": "src/main.rs"}"#),
            tool_call("t2", "Edit", r#"{"file_path": "src/lib.rs"}"#),
        ];
        assert!(parallelizable(&calls, ApprovalMode::Auto));
    }

    #[test]
    fn test_is_mutating_tool() {
        for name in ["Write", "Edit", "Delete", "Bash", "MemoryWrite"] {
            assert!(crate::tools::is_mutating_tool(name));
        }
        // InvokeAgent is ReadOnly (sub-agents inherit parent's approval mode)
        for name in ["Read", "List", "InvokeAgent"] {
            assert!(!crate::tools::is_mutating_tool(name));
        }
    }

    #[test]
    fn test_mixed_batch_not_fully_parallelizable() {
        let calls = [make_tool_call("InvokeAgent"), make_tool_call("Write")];
        assert!(!parallelizable(&calls, ApprovalMode::Confirm));
    }

    #[test]
    fn test_mixed_batch_fully_parallelizable_in_auto() {
        let calls = [make_tool_call("InvokeAgent"), make_tool_call("Write")];
        assert!(parallelizable(&calls, ApprovalMode::Auto));
    }
}