
agent_core/controller/stateless/executor.rs

use tokio_util::sync::CancellationToken;

use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
use crate::client::providers::anthropic::AnthropicProvider;
use crate::client::providers::bedrock::{BedrockCredentials, BedrockProvider};
use crate::client::providers::cohere::CohereProvider;
use crate::client::providers::gemini::GeminiProvider;
use crate::client::providers::openai::OpenAIProvider;
use crate::client::LLMClient;

use crate::controller::session::LLMProvider;

use super::types::{
    RequestOptions, StatelessConfig, StatelessError, StatelessResult, StreamCallback,
    DEFAULT_MAX_TOKENS,
};

/// Stateless executor for single LLM requests without session state.
/// Safe for concurrent use: multiple tasks can call `execute` simultaneously.
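///
/// # Example
///
/// ```no_run
/// // Sketch only: shares one executor across Tokio tasks via `Arc`. The hidden
/// // setup and the `Send + Sync` bounds this relies on are assumptions here,
/// // justified by the concurrency note above.
/// # use std::sync::Arc;
/// # async fn demo(executor: StatelessExecutor) {
/// let executor = Arc::new(executor);
/// for i in 0..4 {
///     let exec = Arc::clone(&executor);
///     tokio::spawn(async move {
///         let _ = exec.execute(&format!("task {i}"), None).await;
///     });
/// }
/// # }
/// ```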
pub struct StatelessExecutor {
    client: LLMClient,
    config: StatelessConfig,
}

impl StatelessExecutor {
    /// Creates a new stateless executor with the given configuration.
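    ///
    /// # Example
    ///
    /// ```no_run
    /// // Sketch only: the import path is an assumption; the builder calls
    /// // mirror the tests at the bottom of this file.
    /// use agent_core::controller::stateless::{StatelessConfig, StatelessExecutor};
    ///
    /// let config = StatelessConfig::anthropic("api-key", "claude-3")
    ///     .with_max_tokens(8192)
    ///     .with_temperature(0.5);
    /// let executor = StatelessExecutor::new(config).expect("config should validate");
    /// ```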
    pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
        config.validate()?;

        let client = match config.provider {
            LLMProvider::Anthropic => {
                let provider =
                    AnthropicProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::OpenAI => {
                // Check for Azure configuration first
                let provider = if let (Some(resource), Some(deployment)) =
                    (&config.azure_resource, &config.azure_deployment)
                {
                    let api_version = config
                        .azure_api_version
                        .clone()
                        .unwrap_or_else(|| "2024-10-21".to_string());
                    OpenAIProvider::azure(
                        config.api_key.clone(),
                        resource.clone(),
                        deployment.clone(),
                        api_version,
                    )
                } else if let Some(base_url) = &config.base_url {
                    OpenAIProvider::with_base_url(
                        config.api_key.clone(),
                        config.model.clone(),
                        base_url.clone(),
                    )
                } else {
                    OpenAIProvider::new(config.api_key.clone(), config.model.clone())
                };
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::Google => {
                let provider = GeminiProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::Cohere => {
                let provider = CohereProvider::new(config.api_key.clone(), config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
            LLMProvider::Bedrock => {
                let region = config.bedrock_region.clone().ok_or_else(|| {
                    StatelessError::ExecutionFailed {
                        op: "init_client".to_string(),
                        message: "Bedrock requires bedrock_region".to_string(),
                    }
                })?;
                let access_key_id = config.bedrock_access_key_id.clone().ok_or_else(|| {
                    StatelessError::ExecutionFailed {
                        op: "init_client".to_string(),
                        message: "Bedrock requires bedrock_access_key_id".to_string(),
                    }
                })?;
                let secret_access_key = config.bedrock_secret_access_key.clone().ok_or_else(|| {
                    StatelessError::ExecutionFailed {
                        op: "init_client".to_string(),
                        message: "Bedrock requires bedrock_secret_access_key".to_string(),
                    }
                })?;

                let credentials = match &config.bedrock_session_token {
                    Some(token) => BedrockCredentials::with_session_token(
                        access_key_id,
                        secret_access_key,
                        token.clone(),
                    ),
                    None => BedrockCredentials::new(access_key_id, secret_access_key),
                };

                let provider = BedrockProvider::new(credentials, region, config.model.clone());
                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
                    op: "init_client".to_string(),
                    message: format!("failed to initialize LLM client: {}", e),
                })?
            }
        };

        Ok(Self { client, config })
    }

    /// Sends a single request to the LLM and waits for the complete response.
    /// This is the simplest API; use `execute_stream` for progress feedback.
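    ///
    /// # Example
    ///
    /// ```no_run
    /// // Sketch only: the builder calls mirror the tests below; error handling
    /// // is elided behind the hidden wrapper.
    /// # async fn demo(executor: &StatelessExecutor) -> Result<(), StatelessError> {
    /// let opts = RequestOptions::new()
    ///     .with_system_prompt("Be concise")
    ///     .with_max_tokens(1024);
    /// let result = executor.execute("Explain borrowing in Rust", Some(opts)).await?;
    /// println!("{}", result.text);
    /// # Ok(())
    /// # }
    /// ```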
    pub async fn execute(
        &self,
        input: &str,
        options: Option<RequestOptions>,
    ) -> Result<StatelessResult, StatelessError> {
        if input.is_empty() {
            return Err(StatelessError::EmptyInput);
        }

        let msg_opts = self.build_message_options(options.as_ref());
        let mut messages = Vec::new();

        // Add system prompt if configured; request options override the config-level default
        let system_prompt = options
            .as_ref()
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());

        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }

        // Add user message
        messages.push(LLMMessage::user(input));

        // Send request
        let response = self
            .client
            .send_message(&messages, &msg_opts)
            .await
            .map_err(|e| StatelessError::ExecutionFailed {
                op: "send_message".to_string(),
                message: e.to_string(),
            })?;

        // Extract text from response
        let text = self.extract_text(&response);

        Ok(StatelessResult {
            text,
            input_tokens: 0,  // Non-streaming responses don't report usage
            output_tokens: 0, // Non-streaming responses don't report usage
            model: self.config.model.clone(),
            stop_reason: None,
        })
    }

    /// Sends a request and streams the response via callback.
    /// The callback is invoked for each text chunk as it arrives.
    /// Returns the complete `StatelessResult` after streaming finishes.
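    ///
    /// # Example
    ///
    /// ```no_run
    /// // Sketch only: assumes `StreamCallback` is a boxed `FnMut(&str)` closure
    /// // returning a `Result`, as the call site below suggests.
    /// # async fn demo(executor: &StatelessExecutor) -> Result<(), StatelessError> {
    /// use tokio_util::sync::CancellationToken;
    ///
    /// let cancel = CancellationToken::new();
    /// let result = executor
    ///     .execute_stream(
    ///         "Tell me a story",
    ///         Box::new(|chunk| {
    ///             print!("{chunk}");
    ///             Ok(())
    ///         }),
    ///         None,
    ///         Some(cancel.clone()), // calling cancel.cancel() elsewhere aborts the stream
    ///     )
    ///     .await?;
    /// println!("\n[{} output tokens]", result.output_tokens);
    /// # Ok(())
    /// # }
    /// ```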
    pub async fn execute_stream(
        &self,
        input: &str,
        mut callback: StreamCallback,
        options: Option<RequestOptions>,
        cancel_token: Option<CancellationToken>,
    ) -> Result<StatelessResult, StatelessError> {
        use futures::StreamExt;

        if input.is_empty() {
            return Err(StatelessError::EmptyInput);
        }

        let msg_opts = self.build_message_options(options.as_ref());
        let mut messages = Vec::new();

        // Add system prompt if configured; request options override the config-level default
        let system_prompt = options
            .as_ref()
            .and_then(|o| o.system_prompt.as_ref())
            .or(self.config.system_prompt.as_ref());

        if let Some(prompt) = system_prompt {
            messages.push(LLMMessage::system(prompt));
        }

        // Add user message
        messages.push(LLMMessage::user(input));

        // Create streaming request
        let mut stream = self
            .client
            .send_message_stream(&messages, &msg_opts)
            .await
            .map_err(|e| StatelessError::ExecutionFailed {
                op: "create_stream".to_string(),
                message: e.to_string(),
            })?;

        // Process stream events
        let mut result = StatelessResult {
            model: self.config.model.clone(),
            ..Default::default()
        };
        let mut text_builder = String::new();
        // Without a caller-supplied token, fall back to a fresh token that is never cancelled
        let cancel = cancel_token.unwrap_or_else(CancellationToken::new);

        loop {
            tokio::select! {
                _ = cancel.cancelled() => {
                    return Err(StatelessError::Cancelled);
                }
                event = stream.next() => {
                    match event {
                        Some(Ok(stream_event)) => {
                            match stream_event {
                                StreamEvent::MessageStart { model, .. } => {
                                    result.model = model;
                                }
                                StreamEvent::TextDelta { text, .. } => {
                                    text_builder.push_str(&text);
                                    // Forward the chunk; a callback error aborts the stream
                                    if callback(&text).is_err() {
                                        return Err(StatelessError::StreamInterrupted);
                                    }
                                }
                                StreamEvent::MessageDelta { stop_reason, usage } => {
                                    if let Some(usage) = usage {
                                        result.input_tokens = usage.input_tokens as i64;
                                        result.output_tokens = usage.output_tokens as i64;
                                    }
                                    result.stop_reason = stop_reason;
                                }
                                StreamEvent::MessageStop => {
                                    break;
                                }
                                // Ignore other events (tool use, etc.)
                                _ => {}
                            }
                        }
                        Some(Err(e)) => {
                            return Err(StatelessError::ExecutionFailed {
                                op: "streaming".to_string(),
                                message: e.to_string(),
                            });
                        }
                        None => {
                            // Stream ended without an explicit MessageStop
                            break;
                        }
                    }
                }
            }
        }

        result.text = text_builder;
        Ok(result)
    }

    /// Builds `MessageOptions` from config and request options. Request options
    /// take precedence; `DEFAULT_MAX_TOKENS` is the final fallback for `max_tokens`.
    fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
        let max_tokens = opts
            .and_then(|o| o.max_tokens)
            .unwrap_or(if self.config.max_tokens > 0 {
                self.config.max_tokens
            } else {
                DEFAULT_MAX_TOKENS
            });

        let temperature = opts
            .and_then(|o| o.temperature)
            .or(self.config.temperature);

        MessageOptions {
            max_tokens: Some(max_tokens),
            temperature,
            ..Default::default()
        }
    }

    /// Extracts text from an LLMClient message response.
    fn extract_text(&self, message: &LLMMessage) -> String {
        use crate::client::models::Content;

        let mut text = String::new();
        for block in &message.content {
            if let Content::Text(t) = block {
                text.push_str(t);
            }
        }
        text
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_validation() {
        // Missing API key
        let config = StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: "".to_string(),
            model: "claude-3".to_string(),
            base_url: None,
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
            azure_resource: None,
            azure_deployment: None,
            azure_api_version: None,
            bedrock_region: None,
            bedrock_access_key_id: None,
            bedrock_secret_access_key: None,
            bedrock_session_token: None,
        };
        assert!(config.validate().is_err());

        // Missing model
        let config = StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: "test-key".to_string(),
            model: "".to_string(),
            base_url: None,
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
            azure_resource: None,
            azure_deployment: None,
            azure_api_version: None,
            bedrock_region: None,
            bedrock_access_key_id: None,
            bedrock_secret_access_key: None,
            bedrock_session_token: None,
        };
        assert!(config.validate().is_err());

        // Valid config
        let config = StatelessConfig::anthropic("test-key", "claude-3");
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_request_options_builder() {
        let opts = RequestOptions::new()
            .with_model("gpt-4")
            .with_max_tokens(2048)
            .with_system_prompt("Be helpful")
            .with_temperature(0.7);

        assert_eq!(opts.model, Some("gpt-4".to_string()));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system_prompt, Some("Be helpful".to_string()));
        assert_eq!(opts.temperature, Some(0.7));
    }

    #[test]
    fn test_config_builder() {
        let config = StatelessConfig::anthropic("key", "model")
            .with_max_tokens(8192)
            .with_system_prompt("You are helpful")
            .with_temperature(0.5);

        assert_eq!(config.api_key, "key");
        assert_eq!(config.model, "model");
        assert_eq!(config.max_tokens, 8192);
        assert_eq!(config.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(config.temperature, Some(0.5));
    }
}