robomotion 0.1.3

Official Rust SDK for building Robomotion RPA packages
Documentation
//! Tool interceptor for automatic AI tool request handling.

use crate::message::Context;
use crate::runtime::{is_tool_request, tool_response, MessageHandler, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Wraps a [`MessageHandler`] so that AI tool requests are answered automatically.
///
/// Ordinary messages are forwarded to the wrapped handler unchanged. Tool
/// requests are forwarded as well, but afterwards a tool response is emitted on
/// the handler's behalf if it does not appear to have sent one itself.
pub struct ToolInterceptor {
    /// Whether tool-request interception is enabled for this wrapper.
    has_tool: bool,
    /// The wrapped handler, guarded by an async read/write lock.
    original_handler: Arc<RwLock<Box<dyn MessageHandler>>>,
}

impl ToolInterceptor {
    /// Create a new tool interceptor for a handler.
    pub fn new(handler: Box<dyn MessageHandler>, has_tool: bool) -> Self {
        Self {
            original_handler: Arc::new(RwLock::new(handler)),
            has_tool,
        }
    }

    /// Check if the wrapped handler supports tools.
    pub fn has_tool(&self) -> bool {
        self.has_tool
    }

    /// Handle a tool request.
    async fn handle_tool_request(&self, ctx: &mut Context) -> Result<()> {
        // Call the original handler
        let result = {
            let mut handler = self.original_handler.write().await;
            handler.on_message(ctx).await
        };

        // If response not already sent, send default response
        if !has_tool_response_been_sent(ctx) {
            if let Err(ref e) = result {
                tool_response(ctx, "error", None, &e.to_string()).await?;
            } else {
                // Collect output variables
                let output_data = self.collect_output_variables(ctx);
                tool_response(ctx, "success", Some(output_data), "").await?;
            }
        }

        result
    }

    /// Collect output variables from the context.
    fn collect_output_variables(
        &self,
        _ctx: &Context,
    ) -> std::collections::HashMap<String, serde_json::Value> {
        let mut output_data = std::collections::HashMap::new();

        // Try to get common output fields
        // In practice, this would use reflection or stored field info
        // For now, just return a status
        output_data.insert("status".to_string(), serde_json::json!("completed"));

        output_data
    }
}

/// Lifecycle hooks delegate to the wrapped handler. `on_message` may divert
/// tool requests through the interceptor's automatic-response path.
#[async_trait]
impl MessageHandler for ToolInterceptor {
    /// Forwards creation to the wrapped handler while holding its write lock.
    async fn on_create(&mut self) -> Result<()> {
        self.original_handler.write().await.on_create().await
    }

    /// Routes a message: tool requests go through the interception path when
    /// tools are enabled; everything else is forwarded to the wrapped handler.
    async fn on_message(&mut self, ctx: &mut Context) -> Result<()> {
        // `is_tool_request` is only consulted when tools are enabled.
        let intercept = self.has_tool && is_tool_request(ctx);

        if intercept {
            self.handle_tool_request(ctx).await
        } else {
            self.original_handler.write().await.on_message(ctx).await
        }
    }

    /// Forwards shutdown to the wrapped handler while holding its write lock.
    async fn on_close(&mut self) -> Result<()> {
        self.original_handler.write().await.on_close().await
    }
}

/// Heuristically reports whether a tool response has already been sent for `ctx`.
///
/// A context that is empty, or whose raw payload can no longer be read, is
/// treated as already answered. NOTE(review): this presumes that sending a
/// tool response clears the context; confirm against `tool_response`.
fn has_tool_response_been_sent(ctx: &Context) -> bool {
    if ctx.is_empty() {
        return true;
    }
    ctx.get_raw().is_err()
}

/// Returns `handler` wrapped in a [`ToolInterceptor`] when `has_tool` is set.
///
/// When `has_tool` is `false` the handler is returned as-is, with no extra
/// indirection.
pub fn wrap_with_tool_interceptor(
    handler: Box<dyn MessageHandler>,
    has_tool: bool,
) -> Box<dyn MessageHandler> {
    if !has_tool {
        return handler;
    }
    Box::new(ToolInterceptor::new(handler, true))
}