//! systemprompt-agent 0.2.2
//!
//! Agent-to-Agent (A2A) protocol for systemprompt.io AI governance: streaming,
//! JSON-RPC models, task lifecycle, `.well-known` discovery, and governed
//! agent orchestration.
use anyhow::Result;
use serde_json::Value;
use systemprompt_identifiers::TaskId;
use systemprompt_models::ai::{
    ExecutionState, GenerateResponseParams, PlanValidationError, PlannedToolCall, TemplateValidator,
};
use systemprompt_models::{AiMessage, ExecutionStep, McpTool, PlannedTool, TrackedStep};

use super::super::plan_executor::{
    convert_to_call_tool_results, convert_to_tool_calls, execute_tools_with_templates,
    format_results_for_response,
};
use super::super::tool_executor::ContextToolExecutor;
use super::super::{ExecutionContext, ExecutionResult};
use crate::services::ExecutionTrackingService;
use crate::services::a2a_server::processing::message::StreamEvent;

/// Bundled arguments for [`handle_tool_calls`], grouping the planned tool
/// calls with the execution context and tracking handles they run under.
pub struct HandleToolCallsParams<'a> {
    /// Model-provided explanation of why these tool calls were planned.
    pub reasoning: String,
    /// The tool invocations to validate and execute.
    pub calls: Vec<PlannedToolCall>,
    /// Execution context (stream sender, AI service, agent runtime, request context).
    pub context: &'a ExecutionContext,
    /// Service used to record planning, execution, and completion steps.
    pub tracking: &'a ExecutionTrackingService,
    /// Result of tracking the planning step; an `Err` here is tolerated and
    /// simply skips recording the planning completion.
    pub planning_tracked: Result<(TrackedStep, ExecutionStep), anyhow::Error>,
    /// Task the tool calls belong to.
    pub task_id: TaskId,
    /// Conversation history passed to the AI service for response synthesis.
    pub messages: Vec<AiMessage>,
    /// Tool definitions used for output-schema lookup and echoed in the result.
    pub tools: Vec<McpTool>,
}

pub async fn handle_tool_calls(params: HandleToolCallsParams<'_>) -> Result<ExecutionResult> {
    let HandleToolCallsParams {
        reasoning,
        calls,
        context,
        tracking,
        planning_tracked,
        task_id,
        messages,
        tools,
    } = params;
    tracing::info!(
        tool_count = calls.len(),
        reasoning = %reasoning,
        "Tool calls planned"
    );

    let planned_tools: Vec<PlannedTool> = calls
        .iter()
        .map(|c| PlannedTool {
            tool_name: c.tool_name.clone(),
            arguments: c.arguments.clone(),
        })
        .collect();

    if let Ok((tracked, _)) = planning_tracked {
        if let Ok(step) = tracking
            .complete_planning(tracked, Some(reasoning.clone()), Some(planned_tools))
            .await
        {
            if context
                .tx
                .send(StreamEvent::ExecutionStepUpdate { step })
                .is_err()
            {
                tracing::debug!("Stream receiver dropped");
            }
        }
    }

    let tool_output_schemas = TemplateValidator::get_tool_output_schemas(&calls, &tools);

    if let Err(validation_errors) = TemplateValidator::validate_plan(&calls, &tool_output_schemas) {
        return handle_validation_failure(validation_errors, context, messages).await;
    }

    tracing::info!("Template validation passed");

    let (tool_name, tool_arguments) = build_tool_summary(&calls);

    let (tracked, step) = tracking
        .track_tool_execution(task_id.clone(), tool_name, tool_arguments)
        .await?;

    if context
        .tx
        .send(StreamEvent::ExecutionStepUpdate { step })
        .is_err()
    {
        tracing::debug!("Stream receiver dropped");
    }

    let tool_executor = ContextToolExecutor {
        context: context.clone(),
    };

    let state =
        execute_tools_with_templates(&calls, &tools, &context.request_ctx, &tool_executor).await?;

    let execution_summary = format_results_for_response(&state);

    let has_failures = !state.failed_results().is_empty();

    record_execution_status(tracking, &tracked, &state, has_failures).await;

    tracing::info!(
        succeeded = state.successful_results().len(),
        failed = state.failed_results().len(),
        "Execution complete"
    );

    if let Ok(step) = tracking.track_completion(task_id).await {
        if context
            .tx
            .send(StreamEvent::ExecutionStepUpdate { step })
            .is_err()
        {
            tracing::debug!("Stream receiver dropped");
        }
    }

    let tool_error_message: Option<String> = if has_failures {
        Some(
            state
                .failed_results()
                .iter()
                .filter_map(|r| r.error.as_ref())
                .map(String::as_str)
                .collect::<Vec<_>>()
                .join("; "),
        )
    } else {
        None
    };

    let response = match context
        .ai_service
        .generate_response(GenerateResponseParams {
            messages,
            execution_summary: &execution_summary,
            context: &context.request_ctx,
            provider: context.agent_runtime.provider.as_deref(),
            model: context.agent_runtime.model.as_deref(),
            max_output_tokens: context.agent_runtime.max_output_tokens,
        })
        .await
    {
        Ok(response) => response,
        Err(ai_error) => {
            if let Some(tool_err) = tool_error_message {
                tracing::warn!(
                    ai_error = %ai_error,
                    tool_error = %tool_err,
                    "AI synthesis failed after tool errors - returning tool errors"
                );
                return Err(anyhow::anyhow!("Tool execution failed: {}", tool_err));
            }
            return Err(ai_error);
        },
    };

    if context
        .tx
        .send(StreamEvent::Text(response.clone()))
        .is_err()
    {
        tracing::debug!("Stream receiver dropped");
    }

    let tool_calls = convert_to_tool_calls(&calls);
    let tool_results = convert_to_call_tool_results(&state);

    Ok(ExecutionResult {
        accumulated_text: response,
        tool_calls,
        tool_results,
        tools,
        iterations: 1,
    })
}

/// Synthesizes a user-facing response when the planned tool calls fail
/// template validation, without executing any tools.
///
/// The validation errors are rendered as a bullet list and handed to the AI
/// service as the execution summary; the resulting text is streamed to the
/// client and returned with empty tool calls/results.
///
/// # Errors
/// Returns an error if the AI service fails to generate a response.
async fn handle_validation_failure(
    validation_errors: Vec<PlanValidationError>,
    context: &ExecutionContext,
    messages: Vec<AiMessage>,
) -> Result<ExecutionResult> {
    let error_messages: Vec<String> = validation_errors.iter().map(ToString::to_string).collect();

    tracing::error!(
        errors = ?error_messages,
        "Template validation failed"
    );

    let bullet_list = error_messages
        .iter()
        .map(|e| format!("- {e}"))
        .collect::<Vec<_>>()
        .join("\n");
    let validation_summary = format!("Plan validation failed:\n{bullet_list}");

    let response = context
        .ai_service
        .generate_response(GenerateResponseParams {
            messages,
            execution_summary: &validation_summary,
            context: &context.request_ctx,
            provider: context.agent_runtime.provider.as_deref(),
            model: context.agent_runtime.model.as_deref(),
            max_output_tokens: context.agent_runtime.max_output_tokens,
        })
        .await?;

    // A dropped receiver just means the client went away; not an error.
    let send_result = context.tx.send(StreamEvent::Text(response.clone()));
    if send_result.is_err() {
        tracing::debug!("Stream receiver dropped");
    }

    Ok(ExecutionResult {
        accumulated_text: response,
        tool_calls: Vec::new(),
        tool_results: Vec::new(),
        tools: Vec::new(),
        iterations: 1,
    })
}

/// Collapses the planned calls into a single `(name, arguments)` pair for
/// execution tracking: a lone call keeps its own name and arguments, while
/// any other count is summarized as `"N tools"` with a JSON array of
/// per-call `{tool, arguments}` objects.
fn build_tool_summary(calls: &[PlannedToolCall]) -> (String, Value) {
    match calls {
        [only] => (only.tool_name.clone(), only.arguments.clone()),
        many => {
            let summary: Vec<Value> = many
                .iter()
                .map(|call| {
                    serde_json::json!({
                        "tool": call.tool_name,
                        "arguments": call.arguments
                    })
                })
                .collect();
            (format!("{} tools", many.len()), serde_json::json!(summary))
        },
    }
}

/// Persists the outcome of a tool-execution step via the tracking service.
///
/// On failure, the failed results' error messages are joined with `"; "` and
/// recorded; on success, the tool outputs are recorded — a single object for
/// one result, or a `{"results": [...]}` wrapper otherwise. Tracking errors
/// are logged at warn level and never propagated.
async fn record_execution_status(
    tracking: &ExecutionTrackingService,
    tracked: &TrackedStep,
    state: &ExecutionState,
    has_failures: bool,
) {
    if has_failures {
        let errors: Vec<&str> = state
            .failed_results()
            .iter()
            .filter_map(|r| r.error.as_deref())
            .collect();
        let error_message = errors.join("; ");

        if let Err(e) = tracking.fail(tracked, error_message).await {
            tracing::warn!(error = %e, "Failed to record execution failure");
        }
        return;
    }

    let tool_result = if let [single] = &state.results[..] {
        serde_json::json!({
            "tool": single.tool_name,
            "output": single.output,
            "duration_ms": single.duration_ms
        })
    } else {
        let entries: Vec<_> = state
            .results
            .iter()
            .map(|r| {
                serde_json::json!({
                    "tool": r.tool_name,
                    "output": r.output,
                    "duration_ms": r.duration_ms
                })
            })
            .collect();
        serde_json::json!({ "results": entries })
    };

    if let Err(e) = tracking.complete(tracked.clone(), Some(tool_result)).await {
        tracing::warn!(error = %e, "Failed to record execution completion");
    }
}