// swarm-engine-llm 0.1.6
//
// LLM integration backends for SwarmEngine
// Documentation
//! BatchInvoker 実装 - Core の BatchInvoker trait を LLM で実装
//!
//! `BatchProcessor`(async)を `BatchInvoker`(sync)にアダプトする。
//!
//! # 使用例
//!
//! ```ignore
//! use swarm_engine_llm::{LlmBatchInvoker, LlmBatchProcessor, OllamaDecider, OllamaConfig};
//! use swarm_engine_core::orchestrator::OrchestratorBuilder;
//!
//! let decider = OllamaDecider::new(OllamaConfig::default());
//! let processor = LlmBatchProcessor::new(decider);
//! let invoker = LlmBatchInvoker::new(processor, runtime.handle().clone());
//!
//! let orchestrator = OrchestratorBuilder::new()
//!     .add_worker(worker)
//!     .batch_invoker(invoker)
//!     .build(runtime.handle().clone());
//! ```

use std::sync::Arc;

use swarm_engine_core::actions::ActionDef;
use swarm_engine_core::agent::{
    BatchDecisionRequest, BatchInvokeError, BatchInvokeResult, BatchInvoker,
};
use swarm_engine_core::exploration::{DependencyGraph, SelectResult};
use swarm_engine_core::extensions::Extensions;
use swarm_engine_core::types::LoraConfig;

use crate::batch_processor::{BatchProcessError, BatchProcessor};

/// LLM BatchInvoker - adapts a `BatchProcessor` to the `BatchInvoker` trait.
///
/// Presents the async [`BatchProcessor`] through the sync [`BatchInvoker`]
/// interface by blocking on the supplied tokio runtime handle.
pub struct LlmBatchInvoker<P: BatchProcessor> {
    /// Handle used to drive processor futures to completion synchronously.
    runtime: tokio::runtime::Handle,
    /// Shared processor; `Arc` so it can be moved into `block_on` futures.
    processor: Arc<P>,
}

impl<P: BatchProcessor> LlmBatchInvoker<P> {
    /// Create a new `LlmBatchInvoker`, taking ownership of `processor`.
    pub fn new(processor: P, runtime: tokio::runtime::Handle) -> Self {
        // Delegate to `from_arc` so there is a single construction path.
        Self::from_arc(Arc::new(processor), runtime)
    }

    /// Create an invoker from an already `Arc`-shared processor.
    pub fn from_arc(processor: Arc<P>, runtime: tokio::runtime::Handle) -> Self {
        Self { processor, runtime }
    }
}

impl<P: BatchProcessor + 'static> BatchInvoker for LlmBatchInvoker<P> {
    /// Run a batch decision synchronously by blocking on the async processor.
    fn invoke(&self, request: BatchDecisionRequest, extensions: &Extensions) -> BatchInvokeResult {
        let processor = Arc::clone(&self.processor);

        // If a LoRA configuration was registered in the extension map, apply
        // it to every per-worker request that does not already carry one.
        let mut request = request;
        if let Some(lora) = extensions.get::<LoraConfig>() {
            for req in request.requests.iter_mut() {
                if req.lora.is_none() {
                    req.lora = Some(lora.clone());
                }
            }
        }

        // Drive the async batch to completion on the configured runtime and
        // translate processor errors into invoker errors on the way out.
        self.runtime.block_on(async move {
            processor
                .process(request)
                .await
                .into_iter()
                .map(|(worker_id, result)| {
                    let mapped = result.map_err(|e: BatchProcessError| match e.is_transient() {
                        true => BatchInvokeError::Transient(e.message().to_string()),
                        false => BatchInvokeError::Permanent(e.message().to_string()),
                    });
                    (worker_id, mapped)
                })
                .collect()
        })
    }

    /// Delegate dependency planning to the processor, blocking for the result.
    fn plan_dependencies(
        &self,
        task: &str,
        actions: &[ActionDef],
        hint: Option<&SelectResult>,
    ) -> Option<DependencyGraph> {
        let processor = Arc::clone(&self.processor);
        // Copy the borrowed arguments into owned values so they can move
        // into the future executed by `block_on`.
        let owned_task = task.to_string();
        let owned_actions = actions.to_vec();
        let owned_hint = hint.cloned();

        self.runtime.block_on(async move {
            processor
                .plan_dependencies(&owned_task, &owned_actions, owned_hint.as_ref())
                .await
        })
    }

    fn name(&self) -> &str {
        self.processor.name()
    }

    /// Health check; blocks on the processor's async probe.
    fn is_healthy(&self) -> bool {
        let processor = Arc::clone(&self.processor);
        self.runtime
            .block_on(async move { processor.is_healthy().await })
    }
}

// ============================================================================
// ショートカット関数
// ============================================================================

use crate::batch_processor::LlmBatchProcessor;
use crate::decider::LlmDecider;

/// Convenience constructor: build a `BatchInvoker` directly from an `LlmDecider`.
///
/// Wraps the decider in an [`LlmBatchProcessor`] and adapts it with
/// [`LlmBatchInvoker`] in one call.
///
/// # Example
///
/// ```ignore
/// use swarm_engine_llm::{create_llm_invoker, OllamaDecider, OllamaConfig};
///
/// let decider = OllamaDecider::new(OllamaConfig::default());
/// let invoker = create_llm_invoker(decider, runtime.handle().clone());
/// ```
pub fn create_llm_invoker<D: LlmDecider + 'static>(
    decider: D,
    runtime: tokio::runtime::Handle,
) -> LlmBatchInvoker<LlmBatchProcessor<D>> {
    LlmBatchInvoker::new(LlmBatchProcessor::new(decider), runtime)
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::future::Future;
    use std::pin::Pin;

    use swarm_engine_core::agent::{
        ContextTarget, DecisionResponse, GlobalContext, ManagerId, ResolvedContext,
        WorkerDecisionRequest,
    };
    use swarm_engine_core::extensions::Extensions;
    use swarm_engine_core::types::WorkerId;

    use crate::batch_processor::BatchProcessResult;
    use swarm_engine_core::agent::ActionCandidate;

    /// Build a minimal `ResolvedContext` exposing the given action names as candidates.
    fn create_test_context(worker_id: WorkerId, action_names: Vec<&str>) -> ResolvedContext {
        let candidates: Vec<ActionCandidate> = action_names
            .into_iter()
            .map(|name| ActionCandidate {
                name: name.to_string(),
                description: format!("{} action", name),
                params: vec![],
                example: None,
            })
            .collect();
        ResolvedContext::new(GlobalContext::new(1), ContextTarget::Worker(worker_id))
            .with_candidates(candidates)
    }

    /// Mock `BatchProcessor` that answers every request with a fixed tool name.
    struct MockBatchProcessor {
        tool_name: String,
    }

    impl MockBatchProcessor {
        fn new(tool: impl Into<String>) -> Self {
            Self {
                tool_name: tool.into(),
            }
        }
    }

    impl BatchProcessor for MockBatchProcessor {
        fn process(
            &self,
            request: BatchDecisionRequest,
        ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
            let tool = self.tool_name.clone();
            Box::pin(async move {
                request
                    .requests
                    .iter()
                    .map(|req| {
                        (
                            req.worker_id,
                            Ok(DecisionResponse {
                                tool: tool.clone(),
                                target: format!("target_{}", req.worker_id.0),
                                args: HashMap::new(),
                                reasoning: Some("Mock response".to_string()),
                                confidence: 0.9,
                                prompt: None,
                                raw_response: None,
                            }),
                        )
                    })
                    .collect()
            })
        }

        fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
            Box::pin(async { true })
        }

        fn name(&self) -> &str {
            "MockBatchProcessor"
        }
    }

    #[test]
    fn test_llm_batch_invoker_basic() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let invoker = LlmBatchInvoker::new(
            MockBatchProcessor::new("TestAction"),
            runtime.handle().clone(),
        );

        // Two identical worker requests differing only in worker id.
        let worker_request = |id: WorkerId| WorkerDecisionRequest {
            worker_id: id,
            query: "What to do?".to_string(),
            context: create_test_context(id, vec!["A", "B"]),
            lora: None,
        };
        let request = BatchDecisionRequest {
            manager_id: ManagerId(0),
            requests: vec![worker_request(WorkerId(0)), worker_request(WorkerId(1))],
        };

        let results = invoker.invoke(request, &Extensions::new());

        assert_eq!(results.len(), 2);
        for (worker_id, result) in results {
            let response = result.expect("Should succeed");
            assert_eq!(response.tool, "TestAction");
            assert_eq!(response.target, format!("target_{}", worker_id.0));
        }
    }

    #[test]
    fn test_llm_batch_invoker_name() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let invoker =
            LlmBatchInvoker::new(MockBatchProcessor::new("Test"), runtime.handle().clone());

        // `name()` is forwarded straight from the processor.
        assert_eq!(invoker.name(), "MockBatchProcessor");
    }

    #[test]
    fn test_llm_batch_invoker_is_healthy() {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let invoker =
            LlmBatchInvoker::new(MockBatchProcessor::new("Test"), runtime.handle().clone());

        // Mock always reports healthy.
        assert!(invoker.is_healthy());
    }
}