// agent_air_runtime/controller/stateless/executor.rs

use tokio_util::sync::CancellationToken;

use crate::client::LLMClient;
use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;

use crate::controller::session::LLMProvider;

use super::types::{
    DEFAULT_MAX_TOKENS, RequestOptions, StatelessConfig, StatelessError, StatelessResult,
    StreamCallback,
};

/// Stateless executor for single LLM requests without session state.
/// Safe for concurrent use: multiple tasks can call `execute` simultaneously.
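///
/// # Example
///
/// A minimal sketch of the non-streaming path. The key and model are
/// placeholders, and the import path assumes these types are re-exported
/// from the `stateless` module, so the example is marked `ignore`:
///
/// ```ignore
/// use agent_air_runtime::controller::stateless::{StatelessConfig, StatelessExecutor};
///
/// let config = StatelessConfig::anthropic("sk-placeholder", "claude-3");
/// let executor = StatelessExecutor::new(config)?;
/// let result = executor.execute("Summarize this in one line.", None).await?;
/// println!("{}", result.text);
/// ```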
pub struct StatelessExecutor {
    client: LLMClient,
    config: StatelessConfig,
}

impl StatelessExecutor {
    /// Creates a new stateless executor with the given configuration.
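    ///
    /// Provider-specific requirements: Bedrock needs `bedrock_region`,
    /// `bedrock_access_key_id`, and `bedrock_secret_access_key`; for OpenAI,
    /// an Azure configuration (`azure_resource` + `azure_deployment`) takes
    /// precedence over a custom `base_url`. Missing requirements surface as
    /// `StatelessError::ExecutionFailed`.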
    pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
        config.validate()?;

        let client = match config.provider {
            LLMProvider::Anthropic => {
                let provider = AnthropicProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::OpenAI => {
                // Check for Azure configuration first
                let provider = if let (Some(resource), Some(deployment)) =
                    (&config.azure_resource, &config.azure_deployment)
                {
                    let api_version = config
                        .azure_api_version
                        .clone()
                        .unwrap_or_else(|| "2024-10-21".to_string());
                    OpenAIProvider::azure(
                        config.api_key.clone(),
                        resource.clone(),
                        deployment.clone(),
                        api_version,
                    )
                } else if let Some(base_url) = &config.base_url {
                    OpenAIProvider::with_base_url(
                        config.api_key.clone(),
                        config.model.clone(),
                        base_url.clone(),
                    )
                } else {
                    OpenAIProvider::new(config.api_key.clone(), config.model.clone())
                };
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::Google => {
                let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::Cohere => {
                let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::Bedrock => {
                let region = config.bedrock_region.clone().ok_or_else(|| {
                    StatelessError::ExecutionFailed {
                        op: "init_client".to_string(),
                        message: "Bedrock requires bedrock_region".to_string(),
                    }
                })?;
                let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
                    StatelessError::ExecutionFailed {
                        op: "init_client".to_string(),
                        message: "Bedrock requires bedrock_access_key_id".to_string(),
                    }
                })?;
                let secret_access_key =
                    config.bedrock_secret_access_key.clone().ok_or_else(|| {
                        StatelessError::ExecutionFailed {
                            op: "init_client".to_string(),
                            message: "Bedrock requires bedrock_secret_access_key".to_string(),
                        }
                    })?;

                let credentials = match &config.bedrock_session_token {
                    Some(token) => BedrockCredentials::with_session_token(
                        access_key_id,
                        secret_access_key,
                        token.clone(),
                    ),
                    None => BedrockCredentials::new(access_key_id, secret_access_key),
                };

                let provider = BedrockProvider::new(credentials, region, config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
        };

        Ok(Self { client, config })
    }

    /// Sends a single request to the LLM and waits for the complete response.
    /// This is the simplest API; use `execute_stream` for progress feedback.
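    ///
    /// Note: the non-streaming path does not report token usage, so
    /// `input_tokens` and `output_tokens` are always zero in the result.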
    pub async fn execute(
        &self,
        input: &str,
        options: Option<RequestOptions>,
    ) -> Result<StatelessResult, StatelessError> {
        if input.is_empty() {
            return Err(StatelessError::EmptyInput);
        }

        let msg_opts = self.build_message_options(options.as_ref());
        let mut messages = Vec::new();

        // Add system prompt if configured (per-request options override the config)
        let system_prompt = options
            .as_ref()
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());

        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }

        // Add user message
        messages.push(LLMMessage::user(input));

        // Send request
        let response = self
            .client
            .send_message(&messages, &msg_opts)
            .await
            .map_err(|e| StatelessError::ExecutionFailed {
                op: "send_message".to_string(),
                message: e.to_string(),
            })?;

        // Extract text from response
        let text = self.extract_text(&response);

        Ok(StatelessResult {
            text,
            input_tokens: 0,  // Non-streaming doesn't provide usage
            output_tokens: 0, // Non-streaming doesn't provide usage
            model: self.config.model.clone(),
            stop_reason: None,
        })
    }

    /// Sends a request and streams the response via callback.
    /// The callback is invoked for each text chunk as it arrives; if it
    /// returns an error, streaming stops with `StatelessError::StreamInterrupted`.
    /// Cancelling the optional token aborts with `StatelessError::Cancelled`.
    /// Returns the complete `StatelessResult` after streaming finishes.
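    ///
    /// A hedged sketch of a call site, marked `ignore` because the concrete
    /// shape of `StreamCallback` is defined in `super::types` and is assumed
    /// here to be a boxed `FnMut(&str) -> Result<...>`, matching how it is
    /// invoked below:
    ///
    /// ```ignore
    /// let cancel = CancellationToken::new();
    /// let callback: StreamCallback = Box::new(|chunk: &str| {
    ///     print!("{}", chunk); // forward each text delta as it arrives
    ///     Ok(())
    /// });
    /// let result = executor
    ///     .execute_stream("Tell me a story.", callback, None, Some(cancel.clone()))
    ///     .await?;
    /// ```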
    pub async fn execute_stream(
        &self,
        input: &str,
        mut callback: StreamCallback,
        options: Option<RequestOptions>,
        cancel_token: Option<CancellationToken>,
    ) -> Result<StatelessResult, StatelessError> {
        use futures::StreamExt;

        if input.is_empty() {
            return Err(StatelessError::EmptyInput);
        }

        let msg_opts = self.build_message_options(options.as_ref());
        let mut messages = Vec::new();

        // Add system prompt if configured
        let system_prompt = options
            .as_ref()
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());

        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }

        // Add user message
        messages.push(LLMMessage::user(input));

        // Create streaming request
        let mut stream = self
            .client
            .send_message_stream(&messages, &msg_opts)
            .await
            .map_err(|e| StatelessError::ExecutionFailed {
                op: "create_stream".to_string(),
                message: e.to_string(),
            })?;

        // Process stream events
        let mut result = StatelessResult {
            model: self.config.model.clone(),
            ..Default::default()
        };
        let mut text_builder = String::new();
        // A default token is never cancelled, so no token means no cancellation
        let cancel = cancel_token.unwrap_or_default();

        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    return Err(StatelessError::Cancelled);
                }
                event = stream.next() => {
                    match event {
                        Some(Ok(stream_event)) => {
                            match stream_event {
                                StreamEvent::MessageStart { model, .. } => {
                                    result.model = model;
                                }
                                StreamEvent::TextDelta { text, .. } => {
                                    text_builder.push_str(&text);
                                    // Forward the chunk; a callback error aborts the stream
                                    if callback(&text).is_err() {
                                        return Err(StatelessError::StreamInterrupted);
                                    }
                                }
                                StreamEvent::MessageDelta { stop_reason, usage } => {
                                    if let Some(usage) = usage {
                                        result.input_tokens = usage.input_tokens as i64;
                                        result.output_tokens = usage.output_tokens as i64;
                                    }
                                    result.stop_reason = stop_reason;
                                }
                                StreamEvent::MessageStop => {
                                    break;
                                }
                                // Ignore other events (tool use, etc.)
                                _ => {}
                            }
                        }
                        Some(Err(e)) => {
                            return Err(StatelessError::ExecutionFailed {
                                op: "streaming".to_string(),
                                message: e.to_string(),
                            });
                        }
                        None => {
                            // Stream ended without an explicit MessageStop
                            break;
                        }
                    }
                }
            }
        }

        result.text = text_builder;
        Ok(result)
    }

    /// Builds MessageOptions from config and request options.
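    ///
    /// Precedence for `max_tokens`: the per-request value, then a positive
    /// `config.max_tokens`, then `DEFAULT_MAX_TOKENS`. Temperature follows
    /// the same request-over-config order.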
    fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
        let max_tokens = opts
            .and_then(|o| o.max_tokens)
            .unwrap_or(if self.config.max_tokens > 0 {
                self.config.max_tokens
            } else {
                DEFAULT_MAX_TOKENS
            });

        let temperature = opts.and_then(|o| o.temperature).or(self.config.temperature);

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature,
            ..Default::default()
        }
    }

    /// Extracts text from an LLMClient message response.
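    /// Text blocks are concatenated in order; non-text content blocks
    /// (tool use, etc.) are skipped.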
    fn extract_text(&self, message: &LLMMessage) -> String {
        use crate::client::models::Content;

        let mut text = String::new();
        for block in &message.content {
            if let Content::Text(t) = block {
                text.push_str(t);
            }
        }
        text
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_validation() {
        // Missing API key
        let config = StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: "".to_string(),
            model: "claude-3".to_string(),
            base_url: None,
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
            azure_resource: None,
            azure_deployment: None,
            azure_api_version: None,
            bedrock_region: None,
            bedrock_access_key_id: None,
            bedrock_secret_access_key: None,
            bedrock_session_token: None,
        };
        assert!(config.validate().is_err());

        // Missing model
        let config = StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: "test-key".to_string(),
            model: "".to_string(),
            base_url: None,
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
            azure_resource: None,
            azure_deployment: None,
            azure_api_version: None,
            bedrock_region: None,
            bedrock_access_key_id: None,
            bedrock_secret_access_key: None,
            bedrock_session_token: None,
        };
        assert!(config.validate().is_err());

        // Valid config
        let config = StatelessConfig::anthropic("test-key", "claude-3");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_request_options_builder() {
        let opts = RequestOptions::new()
            .with_model("gpt-4")
            .with_max_tokens(2048)
            .with_system_prompt("Be helpful")
            .with_temperature(0.7);

        assert_eq!(opts.model, Some("gpt-4".to_string()));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system_prompt, Some("Be helpful".to_string()));
        assert_eq!(opts.temperature, Some(0.7));
    }

    #[test]
    fn test_config_builder() {
        let config = StatelessConfig::anthropic("key", "model")
            .with_max_tokens(8192)
            .with_system_prompt("You are helpful")
            .with_temperature(0.5);

        assert_eq!(config.api_key, "key");
        assert_eq!(config.model, "model");
        assert_eq!(config.max_tokens, 8192);
        assert_eq!(config.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(config.temperature, Some(0.5));
    }
}
382}