use std::sync::Arc;
use swarm_engine_core::actions::ActionDef;
use swarm_engine_core::agent::{
BatchDecisionRequest, BatchInvokeError, BatchInvokeResult, BatchInvoker,
};
use swarm_engine_core::exploration::{DependencyGraph, SelectResult};
use swarm_engine_core::extensions::Extensions;
use swarm_engine_core::types::LoraConfig;
use crate::batch_processor::{BatchProcessError, BatchProcessor};
pub struct LlmBatchInvoker<P: BatchProcessor> {
processor: Arc<P>,
runtime: tokio::runtime::Handle,
}
impl<P> LlmBatchInvoker<P>
where
    P: BatchProcessor,
{
    /// Creates an invoker that takes sole ownership of `processor`.
    ///
    /// Futures produced by the processor are executed on `runtime`.
    pub fn new(processor: P, runtime: tokio::runtime::Handle) -> Self {
        Self::from_arc(Arc::new(processor), runtime)
    }

    /// Creates an invoker around a processor that is already shared via [`Arc`].
    ///
    /// Use this when the same processor instance must also be reachable from
    /// elsewhere; futures are executed on `runtime`.
    pub fn from_arc(processor: Arc<P>, runtime: tokio::runtime::Handle) -> Self {
        Self { runtime, processor }
    }
}
/// Blocking [`BatchInvoker`] implementation that drives the wrapped processor's
/// futures to completion on the stored Tokio runtime handle.
///
/// # Panics
///
/// Every method except `name` calls `tokio::runtime::Handle::block_on`. That
/// function panics when called from within an asynchronous execution context.
/// These methods must therefore be called from a plain (non-async) thread.
impl<P: BatchProcessor + 'static> BatchInvoker for LlmBatchInvoker<P> {
    /// Runs a batch of worker decision requests through the processor.
    ///
    /// If `extensions` contains a [`LoraConfig`], it is copied into every request
    /// whose own `lora` is `None`. Per-request settings always take precedence.
    ///
    /// Processor errors are mapped per worker:
    /// - transient errors become `BatchInvokeError::Transient`;
    /// - all other errors become `BatchInvokeError::Permanent`.
    ///
    /// Both variants carry the error's message text.
    fn invoke(&self, mut request: BatchDecisionRequest, extensions: &Extensions) -> BatchInvokeResult {
        if let Some(lora) = extensions.get::<LoraConfig>() {
            for req in request.requests.iter_mut().filter(|req| req.lora.is_none()) {
                req.lora = Some(lora.clone());
            }
        }

        // `block_on` does not require a `'static` future, so the processor can be
        // borrowed directly instead of cloning the `Arc` into an `async move` block.
        let results = self.runtime.block_on(self.processor.process(request));

        results
            .into_iter()
            .map(|(worker_id, result)| {
                let mapped = result.map_err(|err: BatchProcessError| {
                    let message = err.message().to_string();
                    if err.is_transient() {
                        BatchInvokeError::Transient(message)
                    } else {
                        BatchInvokeError::Permanent(message)
                    }
                });
                (worker_id, mapped)
            })
            .collect()
    }

    /// Asks the processor to plan a dependency graph for `task` over `actions`,
    /// optionally guided by `hint`, and blocks until it answers.
    fn plan_dependencies(
        &self,
        task: &str,
        actions: &[ActionDef],
        hint: Option<&SelectResult>,
    ) -> Option<DependencyGraph> {
        // The future is awaited before this function returns, so the borrowed
        // arguments outlive it; no need to deep-copy `task`, `actions` or `hint`.
        self.runtime
            .block_on(self.processor.plan_dependencies(task, actions, hint))
    }

    /// Returns the wrapped processor's name.
    fn name(&self) -> &str {
        self.processor.name()
    }

    /// Reports the wrapped processor's health, blocking until the check completes.
    fn is_healthy(&self) -> bool {
        self.runtime.block_on(self.processor.is_healthy())
    }
}
use crate::batch_processor::LlmBatchProcessor;
use crate::decider::LlmDecider;
/// Convenience constructor: wraps `decider` in an `LlmBatchProcessor` and returns
/// an [`LlmBatchInvoker`] that executes its futures on `runtime`.
pub fn create_llm_invoker<D>(
    decider: D,
    runtime: tokio::runtime::Handle,
) -> LlmBatchInvoker<LlmBatchProcessor<D>>
where
    D: LlmDecider + 'static,
{
    LlmBatchInvoker::new(LlmBatchProcessor::new(decider), runtime)
}
#[cfg(test)]
mod tests {
// Unit tests for `LlmBatchInvoker`, driven by an in-memory `MockBatchProcessor`
// so no real LLM backend is involved. Each test builds its own Tokio runtime and
// calls the invoker from the (non-async) test thread.
// NOTE(review): the LoRA-injection path in `invoke` (a `LoraConfig` present in
// `Extensions`) is not exercised by any test here.
use super::*;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use swarm_engine_core::agent::{
ContextTarget, DecisionResponse, GlobalContext, ManagerId, ResolvedContext,
WorkerDecisionRequest,
};
use swarm_engine_core::extensions::Extensions;
use swarm_engine_core::types::WorkerId;
use crate::batch_processor::BatchProcessResult;
use swarm_engine_core::agent::ActionCandidate;
/// Builds a `ResolvedContext` targeting `worker_id` whose action candidates are
/// `action_names`, each with a generated "<name> action" description, no params
/// and no example.
fn create_test_context(worker_id: WorkerId, action_names: Vec<&str>) -> ResolvedContext {
let global = GlobalContext::new(1);
let candidates: Vec<ActionCandidate> = action_names
.into_iter()
.map(|name| ActionCandidate {
name: name.to_string(),
description: format!("{} action", name),
params: vec![],
example: None,
})
.collect();
ResolvedContext::new(global, ContextTarget::Worker(worker_id)).with_candidates(candidates)
}
/// Stub processor that answers every request successfully with the same tool.
struct MockBatchProcessor {
// Tool name placed in every `DecisionResponse` this mock returns.
response_tool: String,
}
impl MockBatchProcessor {
/// Creates a mock that always responds with `tool`.
fn new(tool: impl Into<String>) -> Self {
Self {
response_tool: tool.into(),
}
}
}
impl BatchProcessor for MockBatchProcessor {
/// Produces one `Ok` response per request with:
/// - `tool` set to `response_tool`;
/// - target `target_{worker_id}`;
/// - confidence 0.9 and a fixed "Mock response" reasoning.
fn process(
&self,
request: BatchDecisionRequest,
) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
// Clone up front so the async block does not borrow `self.response_tool`.
let tool = self.response_tool.clone();
Box::pin(async move {
request
.requests
.iter()
.map(|req| {
let response = DecisionResponse {
tool: tool.clone(),
target: format!("target_{}", req.worker_id.0),
args: HashMap::new(),
reasoning: Some("Mock response".to_string()),
confidence: 0.9,
prompt: None,
raw_response: None,
};
(req.worker_id, Ok(response))
})
.collect()
})
}
/// Always reports healthy.
fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
Box::pin(async { true })
}
/// Fixed identifier checked by `test_llm_batch_invoker_name`.
fn name(&self) -> &str {
"MockBatchProcessor"
}
}
/// Two requests and empty extensions: every result must succeed with the mock's
/// tool and a per-worker target string.
#[test]
fn test_llm_batch_invoker_basic() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let processor = MockBatchProcessor::new("TestAction");
let invoker = LlmBatchInvoker::new(processor, runtime.handle().clone());
let request = BatchDecisionRequest {
manager_id: ManagerId(0),
requests: vec![
WorkerDecisionRequest {
worker_id: WorkerId(0),
query: "What to do?".to_string(),
context: create_test_context(WorkerId(0), vec!["A", "B"]),
lora: None,
},
WorkerDecisionRequest {
worker_id: WorkerId(1),
query: "What to do?".to_string(),
context: create_test_context(WorkerId(1), vec!["A", "B"]),
lora: None,
},
],
};
let results = invoker.invoke(request, &Extensions::new());
assert_eq!(results.len(), 2);
for (worker_id, result) in results {
let response = result.expect("Should succeed");
assert_eq!(response.tool, "TestAction");
assert_eq!(response.target, format!("target_{}", worker_id.0));
}
}
/// `name()` must be forwarded from the wrapped processor.
#[test]
fn test_llm_batch_invoker_name() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let processor = MockBatchProcessor::new("Test");
let invoker = LlmBatchInvoker::new(processor, runtime.handle().clone());
assert_eq!(invoker.name(), "MockBatchProcessor");
}
/// `is_healthy()` must reflect the processor's health check (always true for the mock).
#[test]
fn test_llm_batch_invoker_is_healthy() {
let runtime = tokio::runtime::Runtime::new().unwrap();
let processor = MockBatchProcessor::new("Test");
let invoker = LlmBatchInvoker::new(processor, runtime.handle().clone());
assert!(invoker.is_healthy());
}
}