llm-worker 0.2.0

A library for building autonomous LLM-powered systems
Documentation
//! ツールマクロのテスト
//!
//! `#[tool_registry]` と `#[tool]` マクロの動作を確認する。

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

// マクロ展開に必要なインポート
use schemars;
use serde;

use llm_worker_macros::tool_registry;

// =============================================================================
// Test: Basic Tool Generation
// =============================================================================

/// Simple context struct used by the basic tool-generation tests.
#[derive(Clone)]
struct SimpleContext {
    /// Text that `greet` puts in front of the message and that `get_prefix` returns.
    prefix: String,
}

// Tool methods for `SimpleContext`, registered through `#[tool_registry]` / `#[tool]`.
//
// NOTE(review): the `///` doc comments inside this block are intentionally left
// in their original language. `test_basic_tool_generation` asserts that the
// `greet` tool's description contains the exact doc-comment text, so these
// comments act as runtime data, not just documentation. English glosses are
// given as plain `//` comments instead.
#[tool_registry]
impl SimpleContext {
    // English gloss: "Adds a greeting to the message. Returns the given message
    // with the prefix attached."
    /// メッセージに挨拶を追加する
    ///
    /// 指定されたメッセージにプレフィックスを付けて返します。
    #[tool]
    async fn greet(&self, message: String) -> String {
        // Result has the shape "<prefix>: <message>".
        format!("{}: {}", self.prefix, message)
    }

    // English gloss: "Adds two numbers."
    /// 二つの数を足す
    #[tool]
    async fn add(&self, a: i32, b: i32) -> i32 {
        a + b
    }

    // English gloss: "A tool that takes no arguments."
    /// 引数なしのツール
    #[tool]
    async fn get_prefix(&self) -> String {
        // Hands back an owned copy of the stored prefix.
        self.prefix.clone()
    }
}

/// Checks the metadata generated for `greet` and runs the tool once.
#[tokio::test]
async fn test_basic_tool_generation() {
    let context = SimpleContext {
        prefix: String::from("Hello"),
    };

    // The generated `greet_definition` method yields a factory; calling the
    // factory produces the metadata together with the executable tool.
    let factory = context.greet_definition();
    let (meta, tool) = factory();

    // Metadata: name, description taken from the doc comment, and input schema.
    let schema = &meta.input_schema;
    assert_eq!(meta.name, "greet");
    assert!(
        meta.description.contains("メッセージに挨拶を追加する"),
        "Description should contain doc comment: {}",
        meta.description
    );
    assert!(
        schema.get("properties").is_some(),
        "Schema should have properties"
    );

    let pretty_schema = serde_json::to_string_pretty(schema).unwrap();
    println!("Schema: {}", pretty_schema);

    // Execution: the output must mention both the prefix and the message.
    let outcome = tool.execute(r#"{"message": "World"}"#).await;
    assert!(outcome.is_ok(), "Should execute successfully");
    let text = outcome.unwrap();
    let expectations = [
        ("Hello", "Output should contain prefix"),
        ("World", "Output should contain message"),
    ];
    for (needle, failure_message) in expectations {
        assert!(text.contains(needle), "{}", failure_message);
    }
}

/// Runs the two-argument `add` tool and checks that the sum appears in the output.
#[tokio::test]
async fn test_multiple_arguments() {
    let context = SimpleContext {
        prefix: String::new(),
    };

    let definition = context.add_definition();
    let (meta, tool) = definition();
    assert_eq!(meta.name, "add");

    let outcome = tool.execute(r#"{"a": 10, "b": 20}"#).await;
    assert!(outcome.is_ok());

    let sum_text = outcome.unwrap();
    assert!(sum_text.contains("30"), "Should contain sum: {}", sum_text);
}

/// Runs the argument-less `get_prefix` tool with an empty JSON object.
#[tokio::test]
async fn test_no_arguments() {
    let context = SimpleContext {
        prefix: String::from("TestPrefix"),
    };

    let definition = context.get_prefix_definition();
    let (meta, tool) = definition();
    assert_eq!(meta.name, "get_prefix");

    // A tool without parameters is invoked with `{}`.
    let outcome = tool.execute(r#"{}"#).await;
    assert!(outcome.is_ok());

    let text = outcome.unwrap();
    let mentions_prefix = text.contains("TestPrefix");
    assert!(mentions_prefix, "Should contain prefix: {}", text);
}

/// Expects `greet` to fail when its arguments do not match its parameters.
#[tokio::test]
async fn test_invalid_arguments() {
    let context = SimpleContext {
        prefix: String::new(),
    };

    // Only the tool is needed here; the metadata half of the pair is discarded.
    let tool = context.greet_definition()().1;

    // Well-formed JSON, but it does not supply the `message` argument of `greet`.
    let outcome = tool.execute(r#"{"wrong_field": "value"}"#).await;
    assert!(outcome.is_err(), "Should fail with invalid arguments");
}

// =============================================================================
// Test: Result Return Type
// =============================================================================

/// Field-less context whose single tool method returns a `Result`.
#[derive(Clone)]
struct FallibleContext;

/// Error type returned by the fallible tool in these tests.
///
/// Wraps a plain message and provides `Debug` and `Display` only.
#[derive(Debug)]
struct MyError(String);

impl std::fmt::Display for MyError {
    /// Writes the wrapped message verbatim; width/fill flags of the caller's
    /// format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}

// Tool methods for `FallibleContext`, registered through `#[tool_registry]` / `#[tool]`.
//
// NOTE(review): the `///` doc comment below is left untranslated on purpose.
// Elsewhere in this file a test asserts that a tool's description contains its
// doc-comment text, so this text is presumably surfaced at runtime as well —
// confirm before changing it.
#[tool_registry]
impl FallibleContext {
    // English gloss: "Validates the given value."
    /// 与えられた値を検証する
    #[tool]
    async fn validate(&self, value: i32) -> Result<String, MyError> {
        // Strictly positive values are accepted; zero and negatives are rejected.
        if value > 0 {
            Ok(format!("Valid: {}", value))
        } else {
            Err(MyError("Value must be positive".to_string()))
        }
    }
}

/// A positive value makes the `Result`-returning tool succeed.
#[tokio::test]
async fn test_result_return_type_success() {
    let context = FallibleContext;
    let tool = context.validate_definition()().1;

    let outcome = tool.execute(r#"{"value": 42}"#).await;
    assert!(outcome.is_ok(), "Should succeed for positive value");

    let text = outcome.unwrap();
    let reports_valid = text.contains("Valid");
    assert!(reports_valid, "Should contain Valid: {}", text);
}

/// A negative value makes the tool fail, and the error text carries the
/// message produced by `validate`.
#[tokio::test]
async fn test_result_return_type_error() {
    let context = FallibleContext;
    let tool = context.validate_definition()().1;

    let outcome = tool.execute(r#"{"value": -1}"#).await;
    assert!(outcome.is_err(), "Should fail for negative value");

    let failure = outcome.unwrap_err();
    let rendered = failure.to_string();
    assert!(
        rendered.contains("positive"),
        "Error should mention positive: {}",
        failure
    );
}

// =============================================================================
// Test: Synchronous Methods
// =============================================================================

/// Context used to test a non-async tool method.
#[derive(Clone)]
struct SyncContext {
    /// Counter advanced by `increment`. Being an `Arc`, it is shared between
    /// clones of this context.
    counter: Arc<AtomicUsize>,
}

// Tool methods for `SyncContext`, registered through `#[tool_registry]` / `#[tool]`.
//
// NOTE(review): the `///` doc comment below is left untranslated on purpose.
// Elsewhere in this file a test asserts that a tool's description contains its
// doc-comment text, so this text is presumably surfaced at runtime as well —
// confirm before changing it.
#[tool_registry]
impl SyncContext {
    // English gloss: "Increments the counter and returns it (non-async)."
    /// カウンターをインクリメントして返す (非async)
    #[tool]
    fn increment(&self) -> usize {
        // `fetch_add` yields the value held before the addition, so `+ 1`
        // gives the value after this increment.
        self.counter.fetch_add(1, Ordering::SeqCst) + 1
    }
}

#[tokio::test]
async fn test_sync_method() {
    let ctx = SyncContext {
        counter: Arc::new(AtomicUsize::new(0)),
    };

    let (_, tool) = ctx.increment_definition()();

    // 3回実行
    let result1 = tool.execute(r#"{}"#).await;
    let result2 = tool.execute(r#"{}"#).await;
    let result3 = tool.execute(r#"{}"#).await;

    assert!(result1.is_ok());
    assert!(result2.is_ok());
    assert!(result3.is_ok());

    // カウンターは3になっているはず
    assert_eq!(ctx.counter.load(Ordering::SeqCst), 3);
}

// =============================================================================
// Test: ToolMeta Immutability
// =============================================================================

/// Building the `greet` definition twice yields identical metadata.
#[tokio::test]
async fn test_tool_meta_immutability() {
    let context = SimpleContext {
        prefix: String::from("Test"),
    };

    // Keep only the metadata half of each freshly built (meta, tool) pair.
    let first = context.greet_definition()().0;
    let second = context.greet_definition()().0;

    assert_eq!(first.name, second.name);
    assert_eq!(first.description, second.description);
    assert_eq!(first.input_schema, second.input_schema);
}