swink-agent 0.8.0

Core scaffolding for running LLM-powered agentic loops
Documentation
//! Pre-process phase: pre-dispatch policies, approval gate, argument rewriting.

use std::collections::HashMap;
use std::sync::Arc;

use futures::FutureExt;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::error;

use crate::agent_options::ApproveToolFn;
use crate::policy::{PreDispatchVerdict, ToolDispatchContext, run_pre_dispatch_policies};
use crate::tool::{AgentTool, AgentToolResult, ApprovalMode, ToolApproval, ToolApprovalRequest};
use crate::types::{AgentMessage, ContentBlock};

use super::shared::{emit_batch_stop_results, emit_error_result, panic_payload_message};
use super::{
    AgentEvent, AgentLoopConfig, PreparedToolCall, ToolCallInfo, ToolExecOutcome, collect, emit,
    order_results_by_tool_calls,
};

// ─── Pre-process result ─────────────────────────────────────────────────────

/// Result of the pre-processing phase for a tool batch.
///
/// Returned in the `Ok` arm of `preprocess_tool_calls` once every call in the
/// batch has been either prepared for dispatch or answered with an error
/// result.
pub(super) struct PreprocessResult {
    /// Tool calls that passed all gates and are ready for dispatch.
    ///
    /// Each entry carries the call's index in the original batch together
    /// with the arguments to dispatch with (after any policy or approval
    /// rewrite).
    pub prepared: Vec<PreparedToolCall>,
    /// Messages injected by `PreDispatch` policies (Inject verdict).
    pub injected_messages: Vec<AgentMessage>,
}

/// Result of the batch-wide pre-dispatch policy pass for a single tool call.
///
/// One value is recorded per tool call that did not produce a `Stop`
/// verdict; the values are consumed by the second (approval) pass of
/// `preprocess_tool_calls`.
enum PreDispatchPassResult {
    /// The policies let the call through (`Continue` or `Inject` verdict).
    Ready {
        /// Index of the call in the original `tool_calls` slice.
        idx: usize,
        /// Arguments after the policies ran; policies receive them by
        /// mutable reference and may have rewritten them.
        effective_arguments: serde_json::Value,
    },
    /// The policies returned a `Skip` verdict for this call.
    Skip {
        /// Index of the call in the original `tool_calls` slice.
        idx: usize,
        /// Text reported back as the call's error result.
        error_text: String,
    },
}

/// Build the outcome returned when pre-processing is abandoned due to
/// cancellation.
///
/// Thin adapter over `collect::build_aborted_outcome`: it clones the shared
/// result and timing handles (the callee takes them by value) and forwards
/// the messages injected so far.
async fn aborted_preprocess_outcome(
    tool_calls: &[ToolCallInfo],
    results: &Arc<tokio::sync::Mutex<Vec<(usize, crate::types::ToolResultMessage)>>>,
    tool_timings: &Arc<tokio::sync::Mutex<Vec<crate::metrics::ToolExecMetrics>>>,
    injected_messages: Vec<AgentMessage>,
) -> ToolExecOutcome {
    let shared_results = Arc::clone(results);
    let shared_timings = Arc::clone(tool_timings);

    collect::build_aborted_outcome(tool_calls, shared_results, shared_timings, injected_messages)
        .await
}

/// Result of checking the approval gate for a single tool call.
enum ApprovalOutcome {
    /// The callback approved the call; arguments are left unchanged.
    Approved,
    /// Approved with modified parameters.
    ApprovedWith(serde_json::Value),
    /// The callback rejected the call or panicked; `check_approval` has
    /// already emitted the error result, so the caller just skips the call.
    Rejected,
    /// The cancellation token fired before or while the callback was awaited.
    Cancelled,
    /// An `emit` call returned `false`; the caller maps this to
    /// `ToolExecOutcome::ChannelClosed`.
    ChannelClosed,
}

// ─── Pre-process entry point ────────────────────────────────────────────────

/// Run pre-dispatch policies and the approval gate for every tool call.
///
/// Returns `Ok(PreprocessResult)` when pre-processing completes (even if some
/// calls were skipped/rejected). Returns `Err(ToolExecOutcome)` for early
/// exits (policy Stop, channel closed).
#[allow(clippy::too_many_lines)]
pub(super) async fn preprocess_tool_calls(
    config: &Arc<AgentLoopConfig>,
    tool_calls: &[ToolCallInfo],
    cancellation_token: &CancellationToken,
    tool_map: &HashMap<&str, &Arc<dyn AgentTool>>,
    results: &Arc<tokio::sync::Mutex<Vec<(usize, crate::types::ToolResultMessage)>>>,
    tool_timings: &Arc<tokio::sync::Mutex<Vec<crate::metrics::ToolExecMetrics>>>,
    tx: &mpsc::Sender<AgentEvent>,
) -> Result<PreprocessResult, ToolExecOutcome> {
    let mut prepared: Vec<PreparedToolCall> = Vec::new();
    let mut injected_messages: Vec<AgentMessage> = Vec::new();
    let mut pre_dispatch_results: Vec<PreDispatchPassResult> = Vec::with_capacity(tool_calls.len());
    let mut batch_stop_reason: Option<String> = None;

    let state_snapshot = {
        let guard = config
            .session_state
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.clone()
    };

    for (idx, tc) in tool_calls.iter().enumerate() {
        if cancellation_token.is_cancelled() {
            return Err(aborted_preprocess_outcome(
                tool_calls,
                results,
                tool_timings,
                injected_messages,
            )
            .await);
        }

        // ── PreDispatch policies ──
        let mut effective_arguments = tc.arguments.clone();
        let mut dispatch_ctx = ToolDispatchContext {
            tool_name: &tc.name,
            tool_call_id: &tc.id,
            arguments: &mut effective_arguments,
            execution_root: None,
            state: &state_snapshot,
        };
        match run_pre_dispatch_policies(&config.pre_dispatch_policies, &mut dispatch_ctx) {
            PreDispatchVerdict::Continue => {
                pre_dispatch_results.push(PreDispatchPassResult::Ready {
                    idx,
                    effective_arguments,
                });
            }
            PreDispatchVerdict::Inject(msgs) => {
                injected_messages.extend(msgs);
                pre_dispatch_results.push(PreDispatchPassResult::Ready {
                    idx,
                    effective_arguments,
                });
            }
            PreDispatchVerdict::Stop(reason) => {
                if batch_stop_reason.is_none() {
                    batch_stop_reason = Some(reason);
                }
            }
            PreDispatchVerdict::Skip(error_text) => {
                pre_dispatch_results.push(PreDispatchPassResult::Skip { idx, error_text });
            }
        }
    }

    if let Some(reason) = batch_stop_reason {
        emit_batch_stop_results(tool_calls, &reason, results, tx).await;
        let all_results = std::mem::take(&mut *results.lock().await);
        let ordered = order_results_by_tool_calls(tool_calls, &all_results);
        let collected_timings = std::mem::take(&mut *tool_timings.lock().await);
        return Err(ToolExecOutcome::Completed {
            results: ordered,
            tool_metrics: collected_timings,
            transfer_signal: None,
            injected_messages,
        });
    }

    // A later `Stop` must abort the entire batch before any approval side
    // effects are emitted, so approval runs only after the whole batch clears
    // pre-dispatch.
    for pre_dispatch_result in pre_dispatch_results {
        if cancellation_token.is_cancelled() {
            return Err(aborted_preprocess_outcome(
                tool_calls,
                results,
                tool_timings,
                injected_messages,
            )
            .await);
        }

        let (idx, mut effective_arguments, skipped_error) = match pre_dispatch_result {
            PreDispatchPassResult::Ready {
                idx,
                effective_arguments,
            } => (idx, effective_arguments, None),
            PreDispatchPassResult::Skip { idx, error_text } => {
                (idx, serde_json::Value::Null, Some(error_text))
            }
        };
        let tc = &tool_calls[idx];

        if let Some(error_text) = skipped_error {
            let error_result = AgentToolResult {
                content: vec![ContentBlock::Text { text: error_text }],
                details: serde_json::Value::Null,
                is_error: true,
                transfer_signal: None,
            };
            emit_error_result(&tc.name, &tc.id, error_result, idx, results, tx).await;
            continue;
        }

        if let Some(ref approve_fn) = config.approve_tool
            && config.approval_mode != ApprovalMode::Bypassed
        {
            let requires_approval = tool_map
                .get(tc.name.as_str())
                .is_some_and(|t| t.requires_approval());

            let should_call_approval = match config.approval_mode {
                ApprovalMode::Smart => requires_approval,
                ApprovalMode::Enabled => true,
                ApprovalMode::Bypassed => unreachable!(),
            };

            if should_call_approval {
                match check_approval(
                    approve_fn,
                    tc,
                    &effective_arguments,
                    idx,
                    cancellation_token,
                    requires_approval,
                    tool_map,
                    results,
                    tx,
                )
                .await
                {
                    ApprovalOutcome::Approved => {}
                    ApprovalOutcome::ApprovedWith(new_params) => {
                        effective_arguments = new_params;
                    }
                    ApprovalOutcome::Rejected => continue,
                    ApprovalOutcome::Cancelled => {
                        return Err(aborted_preprocess_outcome(
                            tool_calls,
                            results,
                            tool_timings,
                            injected_messages,
                        )
                        .await);
                    }
                    ApprovalOutcome::ChannelClosed => return Err(ToolExecOutcome::ChannelClosed),
                }
            }
        }

        prepared.push(PreparedToolCall {
            idx,
            effective_arguments,
        });
    }

    Ok(PreprocessResult {
        prepared,
        injected_messages,
    })
}

/// Emit an [`AgentEvent::ToolApprovalResolved`] event for `tc`.
///
/// Returns whatever `emit` returns; callers treat `false` as a closed event
/// channel.
async fn emit_approval_resolved(
    tx: &mpsc::Sender<AgentEvent>,
    tc: &ToolCallInfo,
    approved: bool,
) -> bool {
    let resolved_event = AgentEvent::ToolApprovalResolved {
        id: tc.id.clone(),
        name: tc.name.clone(),
        approved,
    };

    emit(tx, resolved_event).await
}

// ─── Approval helper ────────────────────────────────────────────────────────

/// Run the approval gate for a single tool call.
///
/// # Canonical event order
///
/// The full per-tool-call event sequence is:
///
/// 1. [`AgentEvent::ToolApprovalRequested`] — emitted here, before the callback fires.
/// 2. [`AgentEvent::ToolApprovalResolved`] — emitted here, after the callback resolves.
/// 3. [`AgentEvent::ToolExecutionStart`] — emitted later by `dispatch_single_tool`.
/// 4. [`AgentEvent::ToolExecutionEnd`] — emitted after the tool's `execute()` returns.
///
/// Approval always precedes execution: a tool must be approved before it is
/// dispatched, so `ToolExecutionStart` cannot be observed until after
/// `ToolApprovalResolved`. Consumers (TUI, eval, tests) may rely on this order.
///
/// # Outcomes
///
/// - `Approved` / `ApprovedWith` — the callback approved the call, optionally
///   supplying replacement arguments.
/// - `Rejected` — the callback rejected the call or panicked; in both cases an
///   error result for the call has already been emitted.
/// - `Cancelled` — the cancellation token fired before the request was
///   emitted or while the callback was pending.
/// - `ChannelClosed` — one of the approval events could not be emitted.
///
/// # Panic handling
///
/// The callback is invoked *and* awaited inside a single `catch_unwind`
/// guarded future. A panic raised while the callback is building its future
/// is therefore handled exactly like a panic raised while that future is
/// polled: it is logged, reported as a rejection, and never unwinds into the
/// agent loop.
#[allow(clippy::too_many_arguments)]
async fn check_approval(
    approve_fn: &ApproveToolFn,
    tc: &ToolCallInfo,
    effective_arguments: &serde_json::Value,
    idx: usize,
    cancellation_token: &CancellationToken,
    requires_approval: bool,
    tool_map: &HashMap<&str, &Arc<dyn AgentTool>>,
    results: &Arc<tokio::sync::Mutex<Vec<(usize, crate::types::ToolResultMessage)>>>,
    tx: &mpsc::Sender<AgentEvent>,
) -> ApprovalOutcome {
    // Bail out before emitting anything if the run is already cancelled.
    if cancellation_token.is_cancelled() {
        return ApprovalOutcome::Cancelled;
    }

    if !emit(
        tx,
        AgentEvent::ToolApprovalRequested {
            id: tc.id.clone(),
            name: tc.name.clone(),
            arguments: effective_arguments.clone(),
        },
    )
    .await
    {
        return ApprovalOutcome::ChannelClosed;
    }

    // Resolve approval context with panic safety.
    let approval_context = tool_map.get(tc.name.as_str()).and_then(|tool| {
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            tool.approval_context(effective_arguments)
        }))
        .unwrap_or_else(|_| {
            tracing::warn!(tool_name = %tc.name, "approval_context() panicked — using None");
            None
        })
    });

    let request = ToolApprovalRequest {
        tool_call_id: tc.id.clone(),
        tool_name: tc.name.clone(),
        arguments: effective_arguments.clone(),
        requires_approval,
        context: approval_context,
    };

    // `select!` evaluates its branch expressions before polling anything, so
    // calling `approve_fn(request)` directly in the branch would run the
    // callback outside the panic guard. Deferring the call into an async
    // block makes it execute on first poll, inside `catch_unwind`. Because
    // the select is biased towards cancellation, the callback is also not
    // invoked at all when the token has already fired.
    let guarded_callback =
        std::panic::AssertUnwindSafe(async move { approve_fn(request).await }).catch_unwind();

    let decision = match tokio::select! {
        biased;
        () = cancellation_token.cancelled() => {
            // Close the Requested/Resolved pair even when cancelled.
            if !emit_approval_resolved(tx, tc, false).await {
                return ApprovalOutcome::ChannelClosed;
            }

            return ApprovalOutcome::Cancelled;
        }
        decision = guarded_callback => decision
    } {
        Ok(decision) => decision,
        Err(panic_value) => {
            // A panicking callback is treated as a rejection of this call.
            let panic_message = panic_payload_message(panic_value.as_ref());
            error!(
                tool_call_id = %tc.id,
                tool_name = %tc.name,
                "approval callback panicked: {panic_message}"
            );

            if !emit_approval_resolved(tx, tc, false).await {
                return ApprovalOutcome::ChannelClosed;
            }

            emit_error_result(
                &tc.name,
                &tc.id,
                AgentToolResult::error(format!(
                    "Tool call '{}' was rejected because the approval callback panicked: \
                     {panic_message}",
                    tc.name
                )),
                idx,
                results,
                tx,
            )
            .await;
            return ApprovalOutcome::Rejected;
        }
    };

    // Both `Approved` and `ApprovedWith` count as approved for the event.
    let approved = !matches!(decision, ToolApproval::Rejected);

    if !emit_approval_resolved(tx, tc, approved).await {
        return ApprovalOutcome::ChannelClosed;
    }

    match decision {
        ToolApproval::Approved => ApprovalOutcome::Approved,
        ToolApproval::ApprovedWith(new_params) => ApprovalOutcome::ApprovedWith(new_params),
        ToolApproval::Rejected => {
            let rejection_result = AgentToolResult::error(format!(
                "Tool call '{}' was rejected by the approval gate.",
                tc.name
            ));
            emit_error_result(&tc.name, &tc.id, rejection_result, idx, results, tx).await;
            ApprovalOutcome::Rejected
        }
    }
}