// swarm_engine_llm/decider.rs
use std::future::Future;
use std::pin::Pin;
/// Boxed future returned by [`LlmDecider::decide_batch`]: one `Result`
/// per request, in the same order as the input batch.
pub type BatchDecisionFuture<'a> =
    Pin<Box<dyn Future<Output = Vec<Result<DecisionResponse, LlmError>>> + Send + 'a>>;
21
// Re-export the core decision types so downstream crates can depend on
// this crate's decider API without also importing `swarm_engine_core`.
pub use swarm_engine_core::agent::{
    ActionCandidate, ActionParam, DecisionResponse, ResolvedContext, WorkerDecisionRequest,
};
pub use swarm_engine_core::types::LoraConfig;
27
/// Error produced by an [`LlmDecider`] backend.
///
/// The transient/permanent split drives retry policy: a
/// [`LlmError::Transient`] failure may be retried, while a
/// [`LlmError::Permanent`] one should not be.
#[derive(Debug, Clone, thiserror::Error)]
pub enum LlmError {
    /// Recoverable failure (e.g. a connection timeout) — safe to retry.
    #[error("LLM error (transient): {0}")]
    Transient(String),

    /// Non-recoverable failure (e.g. an invalid model) — do not retry.
    #[error("LLM error: {0}")]
    Permanent(String),
}
39
40impl LlmError {
41 pub fn transient(message: impl Into<String>) -> Self {
42 Self::Transient(message.into())
43 }
44
45 pub fn permanent(message: impl Into<String>) -> Self {
46 Self::Permanent(message.into())
47 }
48
49 pub fn is_transient(&self) -> bool {
50 matches!(self, Self::Transient(_))
51 }
52
53 pub fn message(&self) -> &str {
54 match self {
55 Self::Transient(msg) => msg,
56 Self::Permanent(msg) => msg,
57 }
58 }
59}
60
61impl From<swarm_engine_core::error::SwarmError> for LlmError {
62 fn from(err: swarm_engine_core::error::SwarmError) -> Self {
63 if err.is_transient() {
64 Self::Transient(err.message())
65 } else {
66 Self::Permanent(err.message())
67 }
68 }
69}
70
71impl From<LlmError> for swarm_engine_core::error::SwarmError {
72 fn from(err: LlmError) -> Self {
73 match err {
74 LlmError::Transient(message) => {
75 swarm_engine_core::error::SwarmError::LlmTransient { message }
76 }
77 LlmError::Permanent(message) => {
78 swarm_engine_core::error::SwarmError::LlmPermanent { message }
79 }
80 }
81 }
82}
83
/// Abstraction over an LLM backend that turns worker decision requests
/// into structured decisions.
///
/// Implementors must be shareable across threads (`Send + Sync`); all
/// async operations are exposed as boxed futures so the trait stays
/// object-safe.
pub trait LlmDecider: Send + Sync {
    /// Resolves a single [`WorkerDecisionRequest`] into a
    /// [`DecisionResponse`], or an [`LlmError`] on failure.
    fn decide(
        &self,
        request: WorkerDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>>;

    /// Sends a raw prompt (optionally with a LoRA adapter config) and
    /// returns the backend's raw text completion.
    ///
    /// Default: always fails with a permanent "not implemented" error,
    /// so backends without a raw-completion endpoint need not override.
    fn call_raw(
        &self,
        _prompt: &str,
        _lora: Option<&LoraConfig>,
    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
        Box::pin(async { Err(LlmError::permanent("call_raw not implemented")) })
    }

    /// Decides a whole batch, returning one result per request in input
    /// order.
    ///
    /// Default: awaits `decide` for each request sequentially — no
    /// concurrency. Backends with native batching should override this
    /// for throughput.
    fn decide_batch(&self, requests: Vec<WorkerDecisionRequest>) -> BatchDecisionFuture<'_> {
        Box::pin(async move {
            let mut results = Vec::with_capacity(requests.len());
            for req in requests {
                results.push(self.decide(req).await);
            }
            results
        })
    }

    /// Identifier of the underlying model.
    fn model_name(&self) -> &str;

    /// Endpoint the backend talks to; defaults to `"unknown"` for
    /// backends without a meaningful address.
    fn endpoint(&self) -> &str {
        "unknown"
    }

    /// Health probe; resolves to `true` when the backend is usable.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Advisory upper bound on concurrent in-flight requests.
    /// Default `None` means no known limit.
    fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
        Box::pin(async { None })
    }
}
141
/// Connection and request settings for an LLM decider backend.
#[derive(Debug, Clone)]
pub struct LlmDeciderConfig {
    /// Model identifier passed to the backend (default tag format
    /// suggests an Ollama-style model name).
    pub model: String,
    /// Base URL of the LLM service.
    pub endpoint: String,
    /// Per-request timeout in milliseconds.
    pub timeout_ms: u64,
    /// Maximum number of requests grouped into a single batch.
    pub max_batch_size: usize,
    /// Sampling temperature forwarded to the model.
    pub temperature: f32,
    /// Optional system prompt; `None` uses the backend's default.
    // NOTE(review): how the prompt is injected is up to the decider
    // implementation — not visible in this module.
    pub system_prompt: Option<String>,
}
158
159impl Default for LlmDeciderConfig {
160 fn default() -> Self {
161 Self {
162 model: "qwen2.5-coder:1.5b".to_string(),
163 endpoint: "http://localhost:11434".to_string(),
164 timeout_ms: 5000,
165 max_batch_size: 100,
166 temperature: 0.1,
167 system_prompt: None,
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Transient errors report themselves as retryable and render with
    /// the "(transient)" prefix.
    #[test]
    fn test_llm_error_transient() {
        let err = LlmError::transient("connection timeout");

        assert!(err.is_transient());
        assert_eq!(err.message(), "connection timeout");
        assert_eq!(err.to_string(), "LLM error (transient): connection timeout");
    }

    /// Permanent errors are not retryable and keep the raw message.
    #[test]
    fn test_llm_error_permanent() {
        let err = LlmError::permanent("invalid model");

        assert!(!err.is_transient());
        assert_eq!(err.message(), "invalid model");
    }

    /// The default config targets a local Ollama endpoint.
    #[test]
    fn test_llm_decider_config_default() {
        let config = LlmDeciderConfig::default();

        assert_eq!(config.model, "qwen2.5-coder:1.5b");
        assert_eq!(config.endpoint, "http://localhost:11434");
        assert_eq!(config.timeout_ms, 5000);
        assert_eq!(config.max_batch_size, 100);
        // f32 literal comparison via a small tolerance.
        assert!((config.temperature - 0.1).abs() < 0.001);
    }
}