use std::future::Future;
use std::pin::Pin;
/// Boxed future returned by [`LlmDecider::decide_batch`]: resolves to one
/// `Result` per request in the batch, in the same order as the input.
pub type BatchDecisionFuture<'a> =
    Pin<Box<dyn Future<Output = Vec<Result<DecisionResponse, LlmError>>> + Send + 'a>>;
pub use swarm_engine_core::agent::{
ActionCandidate, ActionParam, DecisionResponse, ResolvedContext, WorkerDecisionRequest,
};
pub use swarm_engine_core::types::LoraConfig;
/// Error produced by LLM decider operations.
///
/// Splits failures into `Transient` (the name and the `From<SwarmError>`
/// mapping indicate these are the retryable class) and `Permanent`
/// (retrying will not help). Callers branch via [`LlmError::is_transient`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum LlmError {
    /// Failure that may succeed if the operation is retried.
    #[error("LLM error (transient): {0}")]
    Transient(String),
    /// Failure that will not be resolved by retrying.
    #[error("LLM error: {0}")]
    Permanent(String),
}
impl LlmError {
    /// Convenience constructor for [`LlmError::Transient`].
    pub fn transient(message: impl Into<String>) -> Self {
        Self::Transient(message.into())
    }

    /// Convenience constructor for [`LlmError::Permanent`].
    pub fn permanent(message: impl Into<String>) -> Self {
        Self::Permanent(message.into())
    }

    /// Returns `true` when this is the retryable [`LlmError::Transient`] variant.
    pub fn is_transient(&self) -> bool {
        match self {
            Self::Transient(_) => true,
            Self::Permanent(_) => false,
        }
    }

    /// Borrows the underlying message text, whichever variant this is.
    pub fn message(&self) -> &str {
        match self {
            Self::Transient(msg) | Self::Permanent(msg) => msg,
        }
    }
}
impl From<swarm_engine_core::error::SwarmError> for LlmError {
    /// Maps a core `SwarmError` onto the matching `LlmError` variant,
    /// carrying over both its transience classification and its message.
    fn from(err: swarm_engine_core::error::SwarmError) -> Self {
        // Pick the variant constructor first, then extract the message,
        // preserving the original call order on `err`.
        let wrap = if err.is_transient() {
            Self::Transient as fn(String) -> Self
        } else {
            Self::Permanent
        };
        wrap(err.message())
    }
}
impl From<LlmError> for swarm_engine_core::error::SwarmError {
fn from(err: LlmError) -> Self {
match err {
LlmError::Transient(message) => {
swarm_engine_core::error::SwarmError::LlmTransient { message }
}
LlmError::Permanent(message) => {
swarm_engine_core::error::SwarmError::LlmPermanent { message }
}
}
}
}
/// Abstraction over an LLM backend that turns [`WorkerDecisionRequest`]s
/// into [`DecisionResponse`]s.
///
/// All async operations return boxed futures (no `async fn`), so the trait
/// stays object-safe and usable behind `dyn LlmDecider`.
pub trait LlmDecider: Send + Sync {
    /// Asks the backing model to decide on a single request.
    fn decide(
        &self,
        request: WorkerDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>>;

    /// Sends a raw prompt (optionally with a LoRA adapter) and resolves to
    /// the raw completion text. The default implementation reports a
    /// permanent "not implemented" error.
    fn call_raw(
        &self,
        _prompt: &str,
        _lora: Option<&LoraConfig>,
    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
        Box::pin(async { Err(LlmError::permanent("call_raw not implemented")) })
    }

    /// Decides a whole batch of requests.
    ///
    /// The default implementation awaits each request sequentially (no
    /// concurrency) and returns results in the same order as `requests`;
    /// implementors may override this to batch or parallelize.
    fn decide_batch(&self, requests: Vec<WorkerDecisionRequest>) -> BatchDecisionFuture<'_> {
        Box::pin(async move {
            let mut results = Vec::with_capacity(requests.len());
            for req in requests {
                results.push(self.decide(req).await);
            }
            results
        })
    }

    /// Identifier of the underlying model.
    fn model_name(&self) -> &str;

    /// Endpoint this decider talks to; defaults to "unknown" for
    /// implementations without a meaningful address.
    fn endpoint(&self) -> &str {
        "unknown"
    }

    /// Probes whether the backend is currently usable.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Upper bound on concurrent in-flight requests the backend supports,
    /// if known. Defaults to `None` (no declared limit).
    fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
        Box::pin(async { None })
    }
}
/// Configuration for constructing an LLM decider.
#[derive(Debug, Clone)]
pub struct LlmDeciderConfig {
    /// Model identifier sent to the backend (default is an Ollama-style tag).
    pub model: String,
    /// Base URL of the LLM service.
    pub endpoint: String,
    /// Request timeout in milliseconds.
    pub timeout_ms: u64,
    /// Maximum number of requests grouped into one batch.
    pub max_batch_size: usize,
    /// Sampling temperature passed to the model.
    pub temperature: f32,
    /// Optional system prompt; `None` means no override.
    /// NOTE(review): how implementors apply this is not visible here — confirm.
    pub system_prompt: Option<String>,
}
impl Default for LlmDeciderConfig {
fn default() -> Self {
Self {
model: "qwen2.5-coder:1.5b".to_string(),
endpoint: "http://localhost:11434".to_string(),
timeout_ms: 5000,
max_batch_size: 100,
temperature: 0.1,
system_prompt: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Transient errors expose their variant, message, and Display text.
    #[test]
    fn test_llm_error_transient() {
        let error = LlmError::transient("connection timeout");
        assert!(error.is_transient());
        assert_eq!(error.message(), "connection timeout");
        assert_eq!(
            error.to_string(),
            "LLM error (transient): connection timeout"
        );
    }

    /// Permanent errors are not transient but still carry their message.
    #[test]
    fn test_llm_error_permanent() {
        let error = LlmError::permanent("invalid model");
        assert!(!error.is_transient());
        assert_eq!(error.message(), "invalid model");
    }

    /// The default configuration targets the documented local endpoint,
    /// model, timeout, batch size, and temperature.
    #[test]
    fn test_llm_decider_config_default() {
        let cfg = LlmDeciderConfig::default();
        assert_eq!(cfg.model, "qwen2.5-coder:1.5b");
        assert_eq!(cfg.endpoint, "http://localhost:11434");
        assert_eq!(cfg.timeout_ms, 5000);
        assert_eq!(cfg.max_batch_size, 100);
        assert!((cfg.temperature - 0.1).abs() < 0.001);
    }
}