llm-worker 0.2.1

A library for building autonomous LLM-powered systems
Documentation
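//! Tests for parallel tool execution and tool-call hooks.
//!
//! Verifies that the `Worker` runs multiple tool calls concurrently, that a
//! pre-tool-call hook can skip a call, and that a post-tool-call hook can
//! rewrite a tool's result.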

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};

use async_trait::async_trait;
use llm_worker::Worker;
use llm_worker::hook::{
    Hook, HookError, PostToolCall, PostToolCallContext, PostToolCallResult, PreToolCall,
    PreToolCallResult, ToolCallContext,
};
use llm_worker::llm_client::event::{Event, ResponseStatus, StatusEvent};
use llm_worker::tool::{Tool, ToolDefinition, ToolError, ToolMeta};

mod common;
use common::MockLlmClient;

// =============================================================================
// Parallel Execution Test Tools
// =============================================================================

/// Test tool that sleeps for a fixed delay before answering.
///
/// Clones share a single invocation counter, so a handle kept by a test sees
/// calls made through any copy registered with a worker.
#[derive(Clone)]
struct SlowTool {
    // Name advertised via `ToolMeta`.
    name: String,
    // Sleep duration per call, in milliseconds.
    delay_ms: u64,
    // Number of `execute` calls so far; shared by every clone through the `Arc`.
    call_count: Arc<AtomicUsize>,
}

impl SlowTool {
    /// Creates a tool named `name` that sleeps `delay_ms` milliseconds per call.
    fn new(name: impl Into<String>, delay_ms: u64) -> Self {
        let name = name.into();
        let call_count = Arc::new(AtomicUsize::new(0));
        Self {
            name,
            delay_ms,
            call_count,
        }
    }

    /// Returns how many times `execute` has been entered so far.
    fn call_count(&self) -> usize {
        let counter: &AtomicUsize = &self.call_count;
        counter.load(Ordering::SeqCst)
    }

    /// Builds a `ToolDefinition` for this tool.
    ///
    /// Each invocation of the returned factory yields metadata with an empty
    /// object input schema, plus a fresh clone of the tool that shares its call
    /// counter with `self`.
    fn definition(&self) -> ToolDefinition {
        let template = self.clone();
        let factory = move || {
            let schema = serde_json::json!({
                "type": "object",
                "properties": {}
            });
            let meta = ToolMeta::new(&template.name)
                .description("A tool that waits before responding")
                .input_schema(schema);
            let handle: Arc<dyn Tool> = Arc::new(template.clone());
            (meta, handle)
        };
        Arc::new(factory)
    }
}

#[async_trait]
impl Tool for SlowTool {
    /// Records the call, sleeps for `delay_ms`, then reports how long it waited.
    /// The JSON input is ignored.
    async fn execute(&self, _input_json: &str) -> Result<String, ToolError> {
        let delay_ms = self.delay_ms;
        self.call_count.fetch_add(1, Ordering::SeqCst);
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        Ok(format!("Completed after {}ms", delay_ms))
    }
}

// =============================================================================
// Tests
// =============================================================================

/// Verifies that several tool calls from a single response run in parallel.
///
/// Each tool sleeps for 100ms. Running the three calls sequentially would take
/// at least 300ms; running them in parallel should take roughly 100ms.
#[tokio::test]
async fn test_parallel_tool_execution() {
    // Sleep time of each tool, in milliseconds.
    const TOOL_DELAY_MS: u64 = 100;
    // Upper bound for the whole run: 2.5x a single delay (250ms). This lies
    // between the ~100ms parallel case and the >=300ms sequential case, leaving
    // headroom for scheduler jitter.
    let max_elapsed = Duration::from_millis(TOOL_DELAY_MS * 5 / 2);

    // One response containing three tool calls.
    let events = vec![
        Event::tool_use_start(0, "call_1", "slow_tool_1"),
        Event::tool_input_delta(0, r#"{}"#),
        Event::tool_use_stop(0),
        Event::tool_use_start(1, "call_2", "slow_tool_2"),
        Event::tool_input_delta(1, r#"{}"#),
        Event::tool_use_stop(1),
        Event::tool_use_start(2, "call_3", "slow_tool_3"),
        Event::tool_input_delta(2, r#"{}"#),
        Event::tool_use_stop(2),
        Event::Status(StatusEvent {
            status: ResponseStatus::Completed,
        }),
    ];

    let client = MockLlmClient::new(events);
    let mut worker = Worker::new(client);

    let tool1 = SlowTool::new("slow_tool_1", TOOL_DELAY_MS);
    let tool2 = SlowTool::new("slow_tool_2", TOOL_DELAY_MS);
    let tool3 = SlowTool::new("slow_tool_3", TOOL_DELAY_MS);

    // `definition` only borrows each tool, and the copies it hands to the worker
    // share the call counter, so these handles can be inspected afterwards.
    worker.register_tool(tool1.definition()).unwrap();
    worker.register_tool(tool2.definition()).unwrap();
    worker.register_tool(tool3.definition()).unwrap();

    let start = Instant::now();
    // NOTE(review): the run result is deliberately ignored; presumably the
    // single-response mock cannot serve a follow-up request. Confirm in `common`.
    let _result = worker.run("Run all tools").await;
    let elapsed = start.elapsed();
    // Print before asserting so the timing is visible even when a check fails.
    println!("Parallel execution completed in {:?}", elapsed);

    // Every tool must have been invoked exactly once.
    assert_eq!(tool1.call_count(), 1, "Tool 1 should be called once");
    assert_eq!(tool2.call_count(), 1, "Tool 2 should be called once");
    assert_eq!(tool3.call_count(), 1, "Tool 3 should be called once");

    assert!(
        elapsed < max_elapsed,
        "Parallel execution should complete in ~100ms, but took {:?}",
        elapsed
    );
}

/// Hook: verifies that a tool skipped by a `PreToolCall` hook is never executed,
/// while the other tool in the same response still runs.
#[tokio::test]
async fn test_before_tool_call_skip() {
    // Hook that skips every call to "blocked_tool" and lets all others through.
    struct BlockingHook;

    #[async_trait]
    impl Hook<PreToolCall> for BlockingHook {
        async fn call(&self, ctx: &mut ToolCallContext) -> Result<PreToolCallResult, HookError> {
            let decision = if ctx.call.name == "blocked_tool" {
                PreToolCallResult::Skip
            } else {
                PreToolCallResult::Continue
            };
            Ok(decision)
        }
    }

    // One response requesting both tools.
    let events = vec![
        Event::tool_use_start(0, "call_1", "allowed_tool"),
        Event::tool_input_delta(0, r#"{}"#),
        Event::tool_use_stop(0),
        Event::tool_use_start(1, "call_2", "blocked_tool"),
        Event::tool_input_delta(1, r#"{}"#),
        Event::tool_use_stop(1),
        Event::Status(StatusEvent {
            status: ResponseStatus::Completed,
        }),
    ];

    let client = MockLlmClient::new(events);
    let mut worker = Worker::new(client);

    let allowed_tool = SlowTool::new("allowed_tool", 10);
    let blocked_tool = SlowTool::new("blocked_tool", 10);

    // `definition` only borrows; the registered copies share the call counters
    // with these handles, so no separate clones are needed to observe calls.
    worker.register_tool(allowed_tool.definition()).unwrap();
    worker.register_tool(blocked_tool.definition()).unwrap();

    worker.add_pre_tool_call_hook(BlockingHook);

    // NOTE(review): the run result is deliberately ignored; presumably the
    // single-response mock cannot serve a follow-up request. Confirm in `common`.
    let _result = worker.run("Test hook").await;

    // The allowed tool runs; the blocked one must never be entered.
    assert_eq!(
        allowed_tool.call_count(),
        1,
        "Allowed tool should be called"
    );
    assert_eq!(
        blocked_tool.call_count(),
        0,
        "Blocked tool should not be called"
    );
}

/// Hook: verifies that a `PostToolCall` hook can rewrite a tool's result.
///
/// The mock client serves two scripted responses: first a `test_tool` call,
/// then a plain text reply for the follow-up request.
#[tokio::test]
async fn test_post_tool_call_modification() {
    // One scripted response per request the worker is expected to make.
    let client = MockLlmClient::with_responses(vec![
        // Request 1: the model asks for `test_tool`.
        vec![
            Event::tool_use_start(0, "call_1", "test_tool"),
            Event::tool_input_delta(0, r#"{}"#),
            Event::tool_use_stop(0),
            Event::Status(StatusEvent {
                status: ResponseStatus::Completed,
            }),
        ],
        // Request 2: after the tool result is sent back, the model answers with text.
        vec![
            Event::text_block_start(0),
            Event::text_delta(0, "Done!"),
            Event::text_block_stop(0, None),
            Event::Status(StatusEvent {
                status: ResponseStatus::Completed,
            }),
        ],
    ]);

    let mut worker = Worker::new(client);

    // Tool that always returns the fixed string "Original Result".
    #[derive(Clone)]
    struct SimpleTool;

    #[async_trait]
    impl Tool for SimpleTool {
        async fn execute(&self, _: &str) -> Result<String, ToolError> {
            Ok("Original Result".to_string())
        }
    }

    // Definition registering `SimpleTool` under the name "test_tool".
    fn simple_tool_definition() -> ToolDefinition {
        Arc::new(|| {
            let meta = ToolMeta::new("test_tool")
                .description("Test")
                .input_schema(serde_json::json!({}));
            (meta, Arc::new(SimpleTool) as Arc<dyn Tool>)
        })
    }

    worker.register_tool(simple_tool_definition()).unwrap();

    // Hook that prefixes each tool result with "[Modified] " and records the
    // rewritten text so the test can inspect it.
    struct ModifyingHook {
        modified_content: Arc<std::sync::Mutex<Option<String>>>,
    }

    #[async_trait]
    impl Hook<PostToolCall> for ModifyingHook {
        async fn call(
            &self,
            ctx: &mut PostToolCallContext,
        ) -> Result<PostToolCallResult, HookError> {
            ctx.result.content = format!("[Modified] {}", ctx.result.content);
            // The guard is a temporary dropped at the end of this statement, so
            // the std mutex is never held across an `.await`.
            *self.modified_content.lock().unwrap() = Some(ctx.result.content.clone());
            Ok(PostToolCallResult::Continue)
        }
    }

    let modified_content = Arc::new(std::sync::Mutex::new(None));
    worker.add_post_tool_call_hook(ModifyingHook {
        modified_content: modified_content.clone(),
    });

    let result = worker.run("Test modification").await;

    assert!(result.is_ok(), "Worker should complete: {:?}", result);

    // The hook must have run, and it must have rewritten the tool's actual output
    // rather than some unrelated content.
    let content = modified_content
        .lock()
        .unwrap()
        .clone()
        .expect("Hook should have been called");
    assert!(content.contains("[Modified]"), "Result should be modified");
    assert!(
        content.contains("Original Result"),
        "Modified result should still include the tool output, got {:?}",
        content
    );
}