agent_core/controller/stateless/executor.rs

1use tokio_util::sync::CancellationToken;
2
3use crate::client::models::{Message as LLMMessage, MessageOptions, StreamEvent};
4use crate::client::providers::anthropic::AnthropicProvider;
5use crate::client::providers::openai::OpenAIProvider;
6use crate::client::LLMClient;
7
8use crate::controller::session::LLMProvider;
9
10use super::types::{
11    RequestOptions, StatelessConfig, StatelessError, StatelessResult, StreamCallback,
12    DEFAULT_MAX_TOKENS,
13};
14
15/// Stateless executor for single LLM requests without session state.
16/// Safe for concurrent use - multiple tasks can call execute simultaneously.
17pub struct StatelessExecutor {
18    client: LLMClient,
19    config: StatelessConfig,
20}
21
22impl StatelessExecutor {
23    /// Creates a new stateless executor with the given configuration.
24    pub fn new(config: StatelessConfig) -> Result<Self, StatelessError> {
25        config.validate()?;
26
27        let client = match config.provider {
28            LLMProvider::Anthropic => {
29                let provider =
30                    AnthropicProvider::new(config.api_key.clone(), config.model.clone());
31                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
32                    op: "init_client".to_string(),
33                    message: format!("failed to initialize LLM client: {}", e),
34                })?
35            }
36            LLMProvider::OpenAI => {
37                let provider = OpenAIProvider::new(config.api_key.clone(), config.model.clone());
38                LLMClient::new(Box::new(provider)).map_err(|e| StatelessError::ExecutionFailed {
39                    op: "init_client".to_string(),
40                    message: format!("failed to initialize LLM client: {}", e),
41                })?
42            }
43        };
44
45        Ok(Self { client, config })
46    }
47
48    /// Sends a single request to the LLM and waits for the complete response.
49    /// This is the simplest API - use execute_stream for progress feedback.
50    pub async fn execute(
51        &self,
52        input: &str,
53        options: Option<RequestOptions>,
54    ) -> Result<StatelessResult, StatelessError> {
55        if input.is_empty() {
56            return Err(StatelessError::EmptyInput);
57        }
58
59        let msg_opts = self.build_message_options(options.as_ref());
60        let mut messages = Vec::new();
61
62        // Add system prompt if configured
63        let system_prompt = options
64            .as_ref()
65            .and_then(|o| o.system_prompt.as_ref())
66            .or(self.config.system_prompt.as_ref());
67
68        if let Some(prompt) = system_prompt {
69            messages.push(LLMMessage::system(prompt));
70        }
71
72        // Add user message
73        messages.push(LLMMessage::user(input));
74
75        // Send request
76        let response = self
77            .client
78            .send_message(&messages, &msg_opts)
79            .await
80            .map_err(|e| StatelessError::ExecutionFailed {
81                op: "send_message".to_string(),
82                message: e.to_string(),
83            })?;
84
85        // Extract text from response
86        let text = self.extract_text(&response);
87
88        Ok(StatelessResult {
89            text,
90            input_tokens: 0,  // Non-streaming doesn't provide usage
91            output_tokens: 0, // Non-streaming doesn't provide usage
92            model: self.config.model.clone(),
93            stop_reason: None,
94        })
95    }
96
97    /// Sends a request and streams the response via callback.
98    /// The callback is called for each text chunk as it arrives.
99    /// Returns the complete Result after streaming finishes.
100    pub async fn execute_stream(
101        &self,
102        input: &str,
103        mut callback: StreamCallback,
104        options: Option<RequestOptions>,
105        cancel_token: Option<CancellationToken>,
106    ) -> Result<StatelessResult, StatelessError> {
107        use futures::StreamExt;
108
109        if input.is_empty() {
110            return Err(StatelessError::EmptyInput);
111        }
112
113        let msg_opts = self.build_message_options(options.as_ref());
114        let mut messages = Vec::new();
115
116        // Add system prompt if configured
117        let system_prompt = options
118            .as_ref()
119            .and_then(|o| o.system_prompt.as_ref())
120            .or(self.config.system_prompt.as_ref());
121
122        if let Some(prompt) = system_prompt {
123            messages.push(LLMMessage::system(prompt));
124        }
125
126        // Add user message
127        messages.push(LLMMessage::user(input));
128
129        // Create streaming request
130        let mut stream = self
131            .client
132            .send_message_stream(&messages, &msg_opts)
133            .await
134            .map_err(|e| StatelessError::ExecutionFailed {
135                op: "create_stream".to_string(),
136                message: e.to_string(),
137            })?;
138
139        // Process stream events
140        let mut result = StatelessResult {
141            model: self.config.model.clone(),
142            ..Default::default()
143        };
144        let mut text_builder = String::new();
145        let cancel = cancel_token.unwrap_or_else(CancellationToken::new);
146
147        loop {
148            tokio::select! {
149                _ = cancel.cancelled() => {
150                    return Err(StatelessError::Cancelled);
151                }
152                event = stream.next() => {
153                    match event {
154                        Some(Ok(stream_event)) => {
155                            match stream_event {
156                                StreamEvent::MessageStart { model, .. } => {
157                                    result.model = model;
158                                }
159                                StreamEvent::TextDelta { text, .. } => {
160                                    text_builder.push_str(&text);
161                                    // Call the callback
162                                    if callback(&text).is_err() {
163                                        return Err(StatelessError::StreamInterrupted);
164                                    }
165                                }
166                                StreamEvent::MessageDelta { stop_reason, usage } => {
167                                    if let Some(usage) = usage {
168                                        result.input_tokens = usage.input_tokens as i64;
169                                        result.output_tokens = usage.output_tokens as i64;
170                                    }
171                                    result.stop_reason = stop_reason;
172                                }
173                                StreamEvent::MessageStop => {
174                                    break;
175                                }
176                                // Ignore other events (tool use, etc.)
177                                _ => {}
178                            }
179                        }
180                        Some(Err(e)) => {
181                            return Err(StatelessError::ExecutionFailed {
182                                op: "streaming".to_string(),
183                                message: e.to_string(),
184                            });
185                        }
186                        None => {
187                            // Stream ended
188                            break;
189                        }
190                    }
191                }
192            }
193        }
194
195        result.text = text_builder;
196        Ok(result)
197    }
198
199    /// Builds MessageOptions from config and request options.
200    fn build_message_options(&self, opts: Option<&RequestOptions>) -> MessageOptions {
201        let max_tokens = opts
202            .and_then(|o| o.max_tokens)
203            .unwrap_or(if self.config.max_tokens > 0 {
204                self.config.max_tokens
205            } else {
206                DEFAULT_MAX_TOKENS
207            });
208
209        let temperature = opts
210            .and_then(|o| o.temperature)
211            .or(self.config.temperature);
212
213        MessageOptions {
214            max_tokens: Some(max_tokens),
215            temperature,
216            ..Default::default()
217        }
218    }
219
220    /// Extracts text from a LLMClient message response.
221    fn extract_text(&self, message: &LLMMessage) -> String {
222        use crate::client::models::Content;
223
224        let mut text = String::new();
225        for block in &message.content {
226            if let Content::Text(t) = block {
227                text.push_str(&t);
228            }
229        }
230        text
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an Anthropic config with the given credentials.
    ///
    /// The remaining fields are fixed: `max_tokens` 4096, no system prompt, no temperature.
    fn anthropic_config(api_key: &str, model: &str) -> StatelessConfig {
        StatelessConfig {
            provider: LLMProvider::Anthropic,
            api_key: api_key.to_string(),
            model: model.to_string(),
            max_tokens: 4096,
            system_prompt: None,
            temperature: None,
        }
    }

    #[test]
    fn test_config_validation() {
        // An empty API key must be rejected.
        assert!(anthropic_config("", "claude-3").validate().is_err());

        // An empty model name must be rejected.
        assert!(anthropic_config("test-key", "").validate().is_err());

        // The provider-specific constructor yields a valid config.
        let valid = StatelessConfig::anthropic("test-key", "claude-3");
        assert!(valid.validate().is_ok());
    }

    #[test]
    fn test_request_options_builder() {
        let opts = RequestOptions::new()
            .with_model("gpt-4")
            .with_max_tokens(2048)
            .with_system_prompt("Be helpful")
            .with_temperature(0.7);

        assert_eq!(opts.model.as_deref(), Some("gpt-4"));
        assert_eq!(opts.max_tokens, Some(2048));
        assert_eq!(opts.system_prompt.as_deref(), Some("Be helpful"));
        assert_eq!(opts.temperature, Some(0.7));
    }

    #[test]
    fn test_config_builder() {
        let config = StatelessConfig::anthropic("key", "model")
            .with_max_tokens(8192)
            .with_system_prompt("You are helpful")
            .with_temperature(0.5);

        assert_eq!(config.api_key, "key");
        assert_eq!(config.model, "model");
        assert_eq!(config.max_tokens, 8192);
        assert_eq!(config.system_prompt.as_deref(), Some("You are helpful"));
        assert_eq!(config.temperature, Some(0.5));
    }
}