agent_core/controller/stateless/executor.rs

use tokio_util::sync::CancellationToken;

use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;

use crate::controller::session::LLMProvider;

use super::types::{
    RequestOptions, StatelessConfig, StatelessError, StatelessResult, StreamCallback,
    DEFAULT_MAX_TOKENS,
};

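/// A one-shot LLM request executor: each call builds its own message list,
/// so no conversation state is carried between calls.
///
/// A minimal usage sketch (marked `ignore`: the import path is an assumption
/// based on this file's location, and the snippet needs an async context):
///
/// ```ignore
/// use agent_core::controller::stateless::{StatelessConfig, StatelessExecutor};
///
/// let config = StatelessConfig::anthropic("api-key", "claude-3");
/// let executor = StatelessExecutor::new(config)?;
/// let result = executor.execute("Summarize this in one line.", None).await?;
/// println!("{}", result.text);
/// ```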
pub struct StatelessExecutor {
    client: LLMClient,
    config: StatelessConfig,
}

impl StatelessExecutor {
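    /// Validates the configuration, then constructs the provider-specific
    /// client; client construction failures are wrapped in
    /// `StatelessError::ExecutionFailed`.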
    pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
        config.validate()?;

        let client = match config.provider {
            LLMProvider::Anthropic => {
                let provider =
                    AnthropicProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::OpenAI => {
                let provider = OpenAIProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
        };

        Ok(Self { client, config })
    }

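    /// Sends `input` as a single user message and returns the complete
    /// response. A system prompt supplied in `options` overrides the
    /// config-level one; empty input is rejected up front.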
    pub async fn execute(
        &self,
        input: &str,
        options: Option<RequestOptions>,
    ) -> Result<StatelessResult, StatelessError> {
        if input.is_empty() {
            return Err(StatelessError::EmptyInput);
        }

        let msg_opts = self.build_message_options(options.as_ref());
        let mut messages = Vec::new();

        // A per-request system prompt takes precedence over the config-level one.
        let system_prompt = options
            .as_ref()
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());

        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }

        messages.push(LLMMessage::user(input));

        let response = self
            .client
            .send_message(&messages, &msg_opts)
            .await
            .map_err(|e| StatelessError::ExecutionFailed {
                op: "send_message".to_string(),
                message: e.to_string(),
            })?;

        let text = self.extract_text(&response);

        Ok(StatelessResult {
            text,
            // Usage is not surfaced on the non-streaming path, so token
            // counts are reported as zero here.
            input_tokens: 0,
            output_tokens: 0,
            model: self.config.model.clone(),
            stop_reason: None,
        })
    }

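    /// Streams a single request, invoking `callback` once per text delta.
    ///
    /// When `cancel_token` is `None`, a fresh token that never fires is
    /// substituted, so the stream runs to completion. A callback error
    /// aborts the stream with `StatelessError::StreamInterrupted`;
    /// cancellation yields `StatelessError::Cancelled`.
    ///
    /// A hedged sketch, assuming `StreamCallback` is a boxed `FnMut(&str)`
    /// returning a `Result` (only its call site is visible in this file):
    ///
    /// ```ignore
    /// let token = CancellationToken::new();
    /// let result = executor
    ///     .execute_stream(
    ///         "Tell me a story",
    ///         Box::new(|chunk| {
    ///             print!("{chunk}");
    ///             Ok(())
    ///         }),
    ///         None,
    ///         Some(token.clone()),
    ///     )
    ///     .await?;
    /// ```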
    pub async fn execute_stream(
        &self,
        input: &str,
        mut callback: StreamCallback,
        options: Option<RequestOptions>,
        cancel_token: Option<CancellationToken>,
    ) -> Result<StatelessResult, StatelessError> {
        use futures::StreamExt;

        if input.is_empty() {
            return Err(StatelessError::EmptyInput);
        }

        let msg_opts = self.build_message_options(options.as_ref());
        let mut messages = Vec::new();

        // A per-request system prompt takes precedence over the config-level one.
        let system_prompt = options
            .as_ref()
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());

        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }

        messages.push(LLMMessage::user(input));

        let mut stream = self
            .client
            .send_message_stream(&messages, &msg_opts)
            .await
            .map_err(|e| StatelessError::ExecutionFailed {
                op: "create_stream".to_string(),
                message: e.to_string(),
            })?;

        let mut result = StatelessResult {
            model: self.config.model.clone(),
            ..Default::default()
        };
        let mut text_builder = String::new();
        // With no caller-supplied token, use a fresh one that never fires.
        let cancel = cancel_token.unwrap_or_else(CancellationToken::new);

        loop {
            tokio::select! {
                // Cancellation wins over pending stream events.
                _ = cancel.cancelled() => {
                    return Err(StatelessError::Cancelled);
                }
                event = stream.next() => {
                    match event {
                        Some(Ok(stream_event)) => {
                            match stream_event {
                                StreamEvent::MessageStart { model, .. } => {
                                    result.model = model;
                                }
                                StreamEvent::TextDelta { text, .. } => {
                                    text_builder.push_str(&text);
                                    // Stop streaming if the callback reports an error.
                                    if callback(&text).is_err() {
                                        return Err(StatelessError::StreamInterrupted);
                                    }
                                }
                                StreamEvent::MessageDelta { stop_reason, usage } => {
                                    if let Some(usage) = usage {
                                        result.input_tokens = usage.input_tokens as i64;
                                        result.output_tokens = usage.output_tokens as i64;
                                    }
                                    result.stop_reason = stop_reason;
                                }
                                StreamEvent::MessageStop => {
                                    break;
                                }
                                // Ignore other event kinds.
                                _ => {}
                            }
                        }
                        Some(Err(e)) => {
                            return Err(StatelessError::ExecutionFailed {
                                op: "streaming".to_string(),
                                message: e.to_string(),
                            });
                        }
                        // Stream ended without an explicit MessageStop.
                        None => {
                            break;
                        }
                    }
                }
            }
        }

        result.text = text_builder;
        Ok(result)
    }

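    /// Merges per-request options with the executor config into the final
    /// `MessageOptions`; request-level values take precedence.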
    fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
        // Request-level max_tokens wins; otherwise use the config value,
        // falling back to DEFAULT_MAX_TOKENS when the config leaves it unset.
        let max_tokens = opts
            .and_then(|o| o.max_tokens)
            .unwrap_or(if self.config.max_tokens > 0 {
                self.config.max_tokens
            } else {
                DEFAULT_MAX_TOKENS
            });

        let temperature = opts
            .and_then(|o| o.temperature)
            .or(self.config.temperature);

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature,
            ..Default::default()
        }
    }

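    /// Concatenates all text content blocks of `message`, skipping any
    /// non-text blocks.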
    fn extract_text(&self, message: &LLMMessage) -> String {
        use crate::client::models::Content;

        let mut text = String::new();
        for block in &message.content {
            if let Content::Text(t) = block {
                text.push_str(t);
            }
        }
        text
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_validation() {
        // An empty API key must fail validation.
        let config = StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: "".to_string(),
            model: "claude-3".to_string(),
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
        };
        assert!(config.validate().is_err());

        // An empty model name must fail validation.
        let config = StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: "test-key".to_string(),
            model: "".to_string(),
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
        };
        assert!(config.validate().is_err());

        // The convenience constructor produces a valid config.
        let config = StatelessConfig::anthropic("test-key", "claude-3");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_request_options_builder() {
        let opts = RequestOptions::new()
            .with_model("gpt-4")
            .with_max_tokens(2048)
            .with_system_prompt("Be helpful")
            .with_temperature(0.7);

        assert_eq!(opts.model, Some("gpt-4".to_string()));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system_prompt, Some("Be helpful".to_string()));
        assert_eq!(opts.temperature, Some(0.7));
    }

    #[test]
    fn test_config_builder() {
        let config = StatelessConfig::anthropic("key", "model")
            .with_max_tokens(8192)
            .with_system_prompt("You are helpful")
            .with_temperature(0.5);

        assert_eq!(config.api_key, "key");
        assert_eq!(config.model, "model");
        assert_eq!(config.max_tokens, 8192);
        assert_eq!(config.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(config.temperature, Some(0.5));
    }
}