agent_air_runtime/controller/stateless/
executor.rs1use tokio_util::sync::CancellationToken;
2
3use crate::client::LLMClient;
4use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
5use crate::client::providers::anthropic::AnthropicProvider;
6use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
7use crate::client::providers::cohere::CohereProvider;
8use crate::client::providers::gemini::GeminiProvider;
9use crate::client::providers::openai::OpenAIProvider;
10
11use crate::controller::session::LLMProvider;
12
13use super::types::{
14 DEFAULT_MAX_TOKENS, RequestOptions, StatelessConfig, StatelessError, StatelessResult,
15 StreamCallback,
16};
17
18pub struct StatelessExecutor {
21 client: LLMClient,
22 config: StatelessConfig,
23}
24
25impl StatelessExecutor {
26 pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
28 config.validate()?;
29
30 let client = match config.provider {
31 LLMProvider::Anthropic => {
32 let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
33 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
34 op: "init_client".to_string(),
35 message: format!("failed to initialize LLM client: {}", e),
36 })?
37 }
38 LLMProvider::OpenAI => {
39 let provider = if let (Some(resource), Some(deployment)) =
41 (&config.azure_resource, &config.azure_deployment)
42 {
43 let api_version = config
44 .azure_api_version
45 .clone()
46 .unwrap_or_else(|| "2024-10-21".to_string());
47 OpenAIProvider::azure(
48 config.api_key.clone(),
49 resource.clone(),
50 deployment.clone(),
51 api_version,
52 )
53 } else if let Some(base_url) = &config.base_url {
54 OpenAIProvider::with_base_url(
55 config.api_key.clone(),
56 config.model.clone(),
57 base_url.clone(),
58 )
59 } else {
60 OpenAIProvider::new(config.api_key.clone(), config.model.clone())
61 };
62 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
63 op: "init_client".to_string(),
64 message: format!("failed to initialize LLM client: {}", e),
65 })?
66 }
67 LLMProvider::Google => {
68 let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
69 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
70 op: "init_client".to_string(),
71 message: format!("failed to initialize LLM client: {}", e),
72 })?
73 }
74 LLMProvider::Cohere => {
75 let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
76 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
77 op: "init_client".to_string(),
78 message: format!("failed to initialize LLM client: {}", e),
79 })?
80 }
81 LLMProvider::Bedrock => {
82 let region = config.bedrock_region.clone().ok_or_else(|| {
83 StatelessError::ExecutionFailed {
84 op: "init_client".to_string(),
85 message: "Bedrock requires bedrock_region".to_string(),
86 }
87 })?;
88 let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
89 StatelessError::ExecutionFailed {
90 op: "init_client".to_string(),
91 message: "Bedrock requires bedrock_access_key_id".to_string(),
92 }
93 })?;
94 let secret_access_key =
95 config.bedrock_secret_access_key.clone().ok_or_else(|| {
96 StatelessError::ExecutionFailed {
97 op: "init_client".to_string(),
98 message: "Bedrock requires bedrock_secret_access_key".to_string(),
99 }
100 })?;
101
102 let credentials = match &config.bedrock_session_token {
103 Some(token) => BedrockCredentials::with_session_token(
104 access_key_id,
105 secret_access_key,
106 token.clone(),
107 ),
108 None => BedrockCredentials::new(access_key_id, secret_access_key),
109 };
110
111 let provider = BedrockProvider::new(credentials, region, config.model.clone());
112 LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
113 op: "init_client".to_string(),
114 message: format!("failed to initialize LLM client: {}", e),
115 })?
116 }
117 };
118
119 Ok(Self { client, config })
120 }
121
122 pub async fn execute(
125 &self,
126 input: &str,
127 options: Option<RequestOptions>,
128 ) -> Result<StatelessResult, StatelessError> {
129 if input.is_empty() {
130 return Err(StatelessError::EmptyInput);
131 }
132
133 let msg_opts = self.build_message_options(options.as_ref());
134 let mut messages = Vec::new();
135
136 let system_prompt = options
138 .as_ref()
139 .and_then(|o| o.system_prompt.as_ref())
140 .or(self.config.system_prompt.as_ref());
141
142 if let Some(prompt) = system_prompt {
143 messages.push(LLMMessage::system(prompt));
144 }
145
146 messages.push(LLMMessage::user(input));
148
149 let response = self
151 .client
152 .send_message(&messages, &msg_opts)
153 .await
154 .map_err(|e| StatelessError::ExecutionFailed {
155 op: "send_message".to_string(),
156 message: e.to_string(),
157 })?;
158
159 let text = self.extract_text(&response);
161
162 Ok(StatelessResult {
163 text,
164 input_tokens: 0, output_tokens: 0, model: self.config.model.clone(),
167 stop_reason: None,
168 })
169 }
170
171 pub async fn execute_stream(
175 &self,
176 input: &str,
177 mut callback: StreamCallback,
178 options: Option<RequestOptions>,
179 cancel_token: Option<CancellationToken>,
180 ) -> Result<StatelessResult, StatelessError> {
181 use futures::StreamExt;
182
183 if input.is_empty() {
184 return Err(StatelessError::EmptyInput);
185 }
186
187 let msg_opts = self.build_message_options(options.as_ref());
188 let mut messages = Vec::new();
189
190 let system_prompt = options
192 .as_ref()
193 .and_then(|o| o.system_prompt.as_ref())
194 .or(self.config.system_prompt.as_ref());
195
196 if let Some(prompt) = system_prompt {
197 messages.push(LLMMessage::system(prompt));
198 }
199
200 messages.push(LLMMessage::user(input));
202
203 let mut stream = self
205 .client
206 .send_message_stream(&messages, &msg_opts)
207 .await
208 .map_err(|e| StatelessError::ExecutionFailed {
209 op: "create_stream".to_string(),
210 message: e.to_string(),
211 })?;
212
213 let mut result = StatelessResult {
215 model: self.config.model.clone(),
216 ..Default::default()
217 };
218 let mut text_builder = String::new();
219 let cancel = cancel_token.unwrap_or_default();
220
221 loop {
222 tokio::select! {
223 _ = cancel.cancelled() => {
224 return Err(StatelessError::Cancelled);
225 }
226 event = stream.next() => {
227 match event {
228 Some(Ok(stream_event)) => {
229 match stream_event {
230 StreamEvent::MessageStart { model, .. } => {
231 result.model = model;
232 }
233 StreamEvent::TextDelta { text, .. } => {
234 text_builder.push_str(&text);
235 if callback(&text).is_err() {
237 return Err(StatelessError::StreamInterrupted);
238 }
239 }
240 StreamEvent::MessageDelta { stop_reason, usage } => {
241 if let Some(usage) = usage {
242 result.input_tokens = usage.input_tokens as i64;
243 result.output_tokens = usage.output_tokens as i64;
244 }
245 result.stop_reason = stop_reason;
246 }
247 StreamEvent::MessageStop => {
248 break;
249 }
250 _ => {}
252 }
253 }
254 Some(Err(e)) => {
255 return Err(StatelessError::ExecutionFailed {
256 op: "streaming".to_string(),
257 message: e.to_string(),
258 });
259 }
260 None => {
261 break;
263 }
264 }
265 }
266 }
267 }
268
269 result.text = text_builder;
270 Ok(result)
271 }
272
273 fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
275 let max_tokens = opts
276 .and_then(|o| o.max_tokens)
277 .unwrap_or(if self.config.max_tokens > 0 {
278 self.config.max_tokens
279 } else {
280 DEFAULT_MAX_TOKENS
281 });
282
283 let temperature = opts.and_then(|o| o.temperature).or(self.config.temperature);
284
285 MessageOptions {
286 max_tokens: Some(max_tokens),
287 temperature,
288 ..Default::default()
289 }
290 }
291
292 fn extract_text(&self, message: &LLMMessage) -> String {
294 use crate::client::models::Content;
295
296 let mut text = String::new();
297 for block in &message.content {
298 if let Content::Text(t) = block {
299 text.push_str(t);
300 }
301 }
302 text
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an Anthropic config with `max_tokens = 4096` and every optional
    /// field unset, so each case only varies the API key and model.
    fn anthropic_config(api_key: &str, model: &str) -> StatelessConfig {
        StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: api_key.to_string(),
            model: model.to_string(),
            base_url: None,
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
            azure_resource: None,
            azure_deployment: None,
            azure_api_version: None,
            bedrock_region: None,
            bedrock_access_key_id: None,
            bedrock_secret_access_key: None,
            bedrock_session_token: None,
        }
    }

    #[test]
    fn test_config_validation() {
        // An empty API key is rejected.
        assert!(anthropic_config("", "claude-3").validate().is_err());

        // An empty model name is rejected.
        assert!(anthropic_config("test-key", "").validate().is_err());

        // The convenience constructor yields a valid config.
        let valid = StatelessConfig::anthropic("test-key", "claude-3");
        assert!(valid.validate().is_ok());
    }

    #[test]
    fn test_request_options_builder() {
        let opts = RequestOptions::new()
            .with_model("gpt-4")
            .with_max_tokens(2048)
            .with_system_prompt("Be helpful")
            .with_temperature(0.7);

        assert_eq!(opts.model.as_deref(), Some("gpt-4"));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system_prompt.as_deref(), Some("Be helpful"));
        assert_eq!(opts.temperature, Some(0.7));
    }

    #[test]
    fn test_config_builder() {
        let config = StatelessConfig::anthropic("key", "model")
            .with_max_tokens(8192)
            .with_system_prompt("You are helpful")
            .with_temperature(0.5);

        assert_eq!(config.api_key, "key");
        assert_eq!(config.model, "model");
        assert_eq!(config.max_tokens, 8192);
        assert_eq!(config.system_prompt.as_deref(), Some("You are helpful"));
        assert_eq!(config.temperature, Some(0.5));
    }
}