// agent_core_runtime/controller/stateless/executor.rs
use tokio_util::sync::CancellationToken;

use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;

use crate::controller::session::LLMProvider;

use super::types::{
    RequestOptions, StatelessConfig, StatelessError, StatelessResult, StreamCallback,
    DEFAULT_MAX_TOKENS,
};

18pub struct StatelessExecutor {
21 client: LLMClient,
22 config: StatelessConfig,
23}
24
25impl StatelessExecutor {
26 pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
28 config.validate()?;
29
30 let client = match config.provider {
31 LLMProvider::Anthropic => {
32 let provider =
33 AnthropicProvider::new(config.api_key.clone(), config.model.clone());
34 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
35 op: "init_client".to_string(),
36 message: format!("failed to initialize LLM client: {}", e),
37 })?
38 }
39 LLMProvider::OpenAI => {
40 let provider = if let (Some(resource), Some(deployment)) =
42 (&config.azure_resource, &config.azure_deployment)
43 {
44 let api_version = config
45 .azure_api_version
46 .clone()
47 .unwrap_or_else(|| "2024-10-21".to_string());
48 OpenAIProvider::azure(
49 config.api_key.clone(),
50 resource.clone(),
51 deployment.clone(),
52 api_version,
53 )
54 } else if let Some(base_url) = &config.base_url {
55 OpenAIProvider::with_base_url(
56 config.api_key.clone(),
57 config.model.clone(),
58 base_url.clone(),
59 )
60 } else {
61 OpenAIProvider::new(config.api_key.clone(), config.model.clone())
62 };
63 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
64 op: "init_client".to_string(),
65 message: format!("failed to initialize LLM client: {}", e),
66 })?
67 }
68 LLMProvider::Google => {
69 let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
70 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
71 op: "init_client".to_string(),
72 message: format!("failed to initialize LLM client: {}", e),
73 })?
74 }
75 LLMProvider::Cohere => {
76 let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
77 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
78 op: "init_client".to_string(),
79 message: format!("failed to initialize LLM client: {}", e),
80 })?
81 }
82 LLMProvider::Bedrock => {
83 let region = config.bedrock_region.clone().ok_or_else(|| {
84 StatelessError::ExecutionFailed {
85 op: "init_client".to_string(),
86 message: "Bedrock requires bedrock_region".to_string(),
87 }
88 })?;
89 let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
90 StatelessError::ExecutionFailed {
91 op: "init_client".to_string(),
92 message: "Bedrock requires bedrock_access_key_id".to_string(),
93 }
94 })?;
95 let secret_access_key = config.bedrock_secret_access_key.clone().ok_or_else(|| {
96 StatelessError::ExecutionFailed {
97 op: "init_client".to_string(),
98 message: "Bedrock requires bedrock_secret_access_key".to_string(),
99 }
100 })?;
101
102 let credentials = match &config.bedrock_session_token {
103 Some(token) => {
104 BedrockCredentials::with_session_token(access_key_id, secret_access_key, token.clone())
105 }
106 None => BedrockCredentials::new(access_key_id, secret_access_key),
107 };
108
109 let provider = BedrockProvider::new(credentials, region, config.model.clone());
110 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
111 op: "init_client".to_string(),
112 message: format!("failed to initialize LLM client: {}", e),
113 })?
114 }
115 };
116
117 Ok(Self { client, config })
118 }
119
120 pub async fn execute(
123 &self,
124 input: &str,
125 options: Option<RequestOptions>,
126 ) -> Result<StatelessResult, StatelessError> {
127 if input.is_empty() {
128 return Err(StatelessError::EmptyInput);
129 }
130
131 let msg_opts = self.build_message_options(options.as_ref());
132 let mut messages = Vec::new();
133
134 let system_prompt = options
136 .as_ref()
137 .and_then(|o| o.system_prompt.as_ref())
138 .or(self.config.system_prompt.as_ref());
139
140 if let Some(prompt) = system_prompt {
141 messages.push(LLMMessage::system(prompt));
142 }
143
144 messages.push(LLMMessage::user(input));
146
147 let response = self
149 .client
150 .send_message(&messages, &msg_opts)
151 .await
152 .map_err(|e| StatelessError::ExecutionFailed {
153 op: "send_message".to_string(),
154 message: e.to_string(),
155 })?;
156
157 let text = self.extract_text(&response);
159
160 Ok(StatelessResult {
161 text,
162 input_tokens: 0, output_tokens: 0, model: self.config.model.clone(),
165 stop_reason: None,
166 })
167 }
168
169 pub async fn execute_stream(
173 &self,
174 input: &str,
175 mut callback: StreamCallback,
176 options: Option<RequestOptions>,
177 cancel_token: Option<CancellationToken>,
178 ) -> Result<StatelessResult, StatelessError> {
179 use futures::StreamExt;
180
181 if input.is_empty() {
182 return Err(StatelessError::EmptyInput);
183 }
184
185 let msg_opts = self.build_message_options(options.as_ref());
186 let mut messages = Vec::new();
187
188 let system_prompt = options
190 .as_ref()
191 .and_then(|o| o.system_prompt.as_ref())
192 .or(self.config.system_prompt.as_ref());
193
194 if let Some(prompt) = system_prompt {
195 messages.push(LLMMessage::system(prompt));
196 }
197
198 messages.push(LLMMessage::user(input));
200
201 let mut stream = self
203 .client
204 .send_message_stream(&messages, &msg_opts)
205 .await
206 .map_err(|e| StatelessError::ExecutionFailed {
207 op: "create_stream".to_string(),
208 message: e.to_string(),
209 })?;
210
211 let mut result = StatelessResult {
213 model: self.config.model.clone(),
214 ..Default::default()
215 };
216 let mut text_builder = String::new();
217 let cancel = cancel_token.unwrap_or_else(CancellationToken::new);
218
219 loop {
220 tokio::select! {
221 _ = cancel.cancelled() => {
222 return Err(StatelessError::Cancelled);
223 }
224 event = stream.next() => {
225 match event {
226 Some(Ok(stream_event)) => {
227 match stream_event {
228 StreamEvent::MessageStart { model, .. } => {
229 result.model = model;
230 }
231 StreamEvent::TextDelta { text, .. } => {
232 text_builder.push_str(&text);
233 if callback(&text).is_err() {
235 return Err(StatelessError::StreamInterrupted);
236 }
237 }
238 StreamEvent::MessageDelta { stop_reason, usage } => {
239 if let Some(usage) = usage {
240 result.input_tokens = usage.input_tokens as i64;
241 result.output_tokens = usage.output_tokens as i64;
242 }
243 result.stop_reason = stop_reason;
244 }
245 StreamEvent::MessageStop => {
246 break;
247 }
248 _ => {}
250 }
251 }
252 Some(Err(e)) => {
253 return Err(StatelessError::ExecutionFailed {
254 op: "streaming".to_string(),
255 message: e.to_string(),
256 });
257 }
258 None => {
259 break;
261 }
262 }
263 }
264 }
265 }
266
267 result.text = text_builder;
268 Ok(result)
269 }
270
271 fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
273 let max_tokens = opts
274 .and_then(|o| o.max_tokens)
275 .unwrap_or(if self.config.max_tokens > 0 {
276 self.config.max_tokens
277 } else {
278 DEFAULT_MAX_TOKENS
279 });
280
281 let temperature = opts
282 .and_then(|o| o.temperature)
283 .or(self.config.temperature);
284
285 MessageOptions {
286 max_tokens: Some(max_tokens),
287 temperature,
288 ..Default::default()
289 }
290 }
291
292 fn extract_text(&self, message: &LLMMessage) -> String {
294 use crate::client::models::Content;
295
296 let mut text = String::new();
297 for block in &message.content {
298 if let Content::Text(t) = block {
299 text.push_str(&t);
300 }
301 }
302 text
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully spelled-out Anthropic config with the given key and
    /// model, `max_tokens = 4096`, and every optional setting unset.
    fn anthropic_config(api_key: &str, model: &str) -> StatelessConfig {
        StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: api_key.to_string(),
            model: model.to_string(),
            base_url: None,
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
            azure_resource: None,
            azure_deployment: None,
            azure_api_version: None,
            bedrock_region: None,
            bedrock_access_key_id: None,
            bedrock_secret_access_key: None,
            bedrock_session_token: None,
        }
    }

    #[test]
    fn test_config_validation() {
        // An empty API key must be rejected.
        assert!(anthropic_config("", "claude-3").validate().is_err());

        // An empty model name must be rejected.
        assert!(anthropic_config("test-key", "").validate().is_err());

        // A config from the convenience constructor must validate.
        assert!(StatelessConfig::anthropic("test-key", "claude-3")
            .validate()
            .is_ok());
    }

    #[test]
    fn test_request_options_builder() {
        let opts = RequestOptions::new()
            .with_model("gpt-4")
            .with_max_tokens(2048)
            .with_system_prompt("Be helpful")
            .with_temperature(0.7);

        assert_eq!(opts.model.as_deref(), Some("gpt-4"));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system_prompt.as_deref(), Some("Be helpful"));
        assert_eq!(opts.temperature, Some(0.7));
    }

    #[test]
    fn test_config_builder() {
        let cfg = StatelessConfig::anthropic("key", "model")
            .with_max_tokens(8192)
            .with_system_prompt("You are helpful")
            .with_temperature(0.5);

        assert_eq!(cfg.api_key, "key");
        assert_eq!(cfg.model, "model");
        assert_eq!(cfg.max_tokens, 8192);
        assert_eq!(cfg.system_prompt.as_deref(), Some("You are helpful"));
        assert_eq!(cfg.temperature, Some(0.5));
    }
}