// minion_engine/steps/chat.rs
1use std::time::Duration;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6
7use crate::config::StepConfig;
8use crate::engine::context::{ChatMessage, Context};
9use crate::error::StepError;
10use crate::workflow::schema::StepDef;
11
12use super::{ChatOutput, StepExecutor, StepOutput};
13
/// Truncation strategy for chat history (Story 5.2).
///
/// Parsed from step config by `TruncationStrategy::from_config` and applied to
/// a message list by `truncate_messages`.
#[derive(Debug, Clone)]
pub enum TruncationStrategy {
    /// Keep all messages (default when no strategy is configured)
    None,
    /// Keep only the last N (most recent) messages
    Last(usize),
    /// Keep only the first N (oldest) messages
    First(usize),
    /// Keep the first `first` and last `last` messages, dropping the middle
    FirstLast { first: usize, last: usize },
    /// Drop oldest messages until total estimated tokens <= max_tokens
    /// (token estimate comes from `estimate_tokens`)
    SlidingWindow { max_tokens: usize },
}
28
29impl TruncationStrategy {
30    /// Parse from StepConfig keys: truncation_strategy, truncation_count, truncation_first,
31    /// truncation_last, truncation_max_tokens
32    pub fn from_config(config: &crate::config::StepConfig) -> Self {
33        match config.get_str("truncation_strategy") {
34            Some("last") => {
35                let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
36                TruncationStrategy::Last(n)
37            }
38            Some("first") => {
39                let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
40                TruncationStrategy::First(n)
41            }
42            Some("first_last") => {
43                let first = config.get_u64("truncation_first").unwrap_or(2) as usize;
44                let last = config.get_u64("truncation_last").unwrap_or(5) as usize;
45                TruncationStrategy::FirstLast { first, last }
46            }
47            Some("sliding_window") => {
48                let max_tokens =
49                    config.get_u64("truncation_max_tokens").unwrap_or(50_000) as usize;
50                TruncationStrategy::SlidingWindow { max_tokens }
51            }
52            _ => TruncationStrategy::None,
53        }
54    }
55}
56
/// Rough token estimate: whitespace-separated word count scaled by 1.3 and
/// rounded up. Only used to enforce the sliding-window truncation budget.
fn estimate_tokens(text: &str) -> usize {
    let word_count = text.split_whitespace().count() as f64;
    (word_count * 1.3).ceil() as usize
}
62
63/// Apply truncation to a list of messages, returning the subset to send
64pub fn truncate_messages(
65    messages: &[ChatMessage],
66    strategy: &TruncationStrategy,
67) -> Vec<ChatMessage> {
68    match strategy {
69        TruncationStrategy::None => messages.to_vec(),
70        TruncationStrategy::Last(n) => {
71            let start = messages.len().saturating_sub(*n);
72            messages[start..].to_vec()
73        }
74        TruncationStrategy::First(n) => {
75            messages[..messages.len().min(*n)].to_vec()
76        }
77        TruncationStrategy::FirstLast { first, last } => {
78            let len = messages.len();
79            let first_end = (*first).min(len);
80            let last_start = len.saturating_sub(*last);
81            if first_end >= last_start {
82                // Overlap or adjacent — return all
83                messages.to_vec()
84            } else {
85                let mut result = messages[..first_end].to_vec();
86                result.extend_from_slice(&messages[last_start..]);
87                result
88            }
89        }
90        TruncationStrategy::SlidingWindow { max_tokens } => {
91            // Greedily include messages from oldest to newest until token budget exceeded
92            // Then drop from the front until we fit
93            let total_tokens: usize =
94                messages.iter().map(|m| estimate_tokens(&m.content)).sum();
95            if total_tokens <= *max_tokens {
96                return messages.to_vec();
97            }
98            let mut tokens_used = total_tokens;
99            let mut drop_count = 0;
100            for msg in messages.iter() {
101                if tokens_used <= *max_tokens {
102                    break;
103                }
104                tokens_used -= estimate_tokens(&msg.content);
105                drop_count += 1;
106            }
107            messages[drop_count..].to_vec()
108        }
109    }
110}
111
/// Executes `chat` workflow steps by calling an LLM provider
/// (Anthropic by default, or OpenAI) over HTTP.
pub struct ChatExecutor;

#[async_trait]
impl StepExecutor for ChatExecutor {
    /// Render the step's prompt template, send it (preceded by optional
    /// session history) to the configured provider, and return the reply.
    ///
    /// Config keys read: `provider`, `model`, `max_tokens`, `temperature`,
    /// `api_key_env`, `timeout`, `anthropic_base_url` / `openai_base_url`
    /// (overridable for testing), `session`, and the `truncation_*` keys
    /// (see `TruncationStrategy::from_config`).
    ///
    /// Errors with `StepError::Fail` when the API key env var is unset, the
    /// step has no `prompt` field, the HTTP client cannot be built, or the
    /// provider call fails.
    async fn execute(
        &self,
        step: &StepDef,
        config: &StepConfig,
        ctx: &Context,
    ) -> Result<StepOutput, StepError> {
        // Any provider value other than "openai" falls through to Anthropic.
        let provider = config.get_str("provider").unwrap_or("anthropic");
        let model = config.get_str("model").unwrap_or(match provider {
            "openai" => "gpt-4o-mini",
            _ => "claude-3-haiku-20240307",
        });
        let max_tokens = config.get_u64("max_tokens").unwrap_or(1024);
        let temperature = config
            .values
            .get("temperature")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.0);
        // Only the env var NAME comes from config; the key itself is read
        // from the process environment below.
        let api_key_env = config.get_str("api_key_env").unwrap_or(match provider {
            "openai" => "OPENAI_API_KEY",
            _ => "ANTHROPIC_API_KEY",
        });
        let timeout = config
            .get_duration("timeout")
            .unwrap_or(Duration::from_secs(120));

        // Allow base URL override for testing
        let anthropic_base = config
            .get_str("anthropic_base_url")
            .unwrap_or("https://api.anthropic.com");
        let openai_base = config
            .get_str("openai_base_url")
            .unwrap_or("https://api.openai.com");

        let api_key = std::env::var(api_key_env).map_err(|_| {
            StepError::Fail(format!(
                "API key not found: environment variable '{}' is not set",
                api_key_env
            ))
        })?;

        let prompt_template = step
            .prompt
            .as_ref()
            .ok_or_else(|| StepError::Fail("chat step missing 'prompt' field".into()))?;

        let prompt = ctx.render_template(prompt_template)?;

        // Story 5.1 + 5.2: Build message list from chat history with optional truncation
        let session_name = config.get_str("session");
        let truncation = TruncationStrategy::from_config(config);
        let mut messages: Vec<serde_json::Value> = if let Some(session) = session_name {
            let history = ctx.get_chat_messages(session);
            let truncated = truncate_messages(&history, &truncation);
            truncated
                .into_iter()
                .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
                .collect()
        } else {
            Vec::new()
        };
        // The freshly rendered prompt is always appended as the final user turn.
        messages.push(serde_json::json!({"role": "user", "content": prompt}));

        let client = Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| StepError::Fail(format!("Failed to create HTTP client: {e}")))?;

        let output = match provider {
            "openai" => {
                let url = format!("{}/v1/chat/completions", openai_base);
                call_openai(&client, &api_key, model, &messages, max_tokens, temperature, &url).await?
            }
            _ => {
                let url = format!("{}/v1/messages", anthropic_base);
                call_anthropic(&client, &api_key, model, &messages, max_tokens, temperature, &url).await?
            }
        };

        // Story 5.1: Store sent message and response in chat history
        // (only the rendered prompt and reply are stored — not the truncated
        // history that was actually transmitted).
        if let Some(session) = session_name {
            let response_text = output.text().to_string();
            ctx.append_chat_messages(
                session,
                vec![
                    ChatMessage { role: "user".to_string(), content: prompt },
                    ChatMessage { role: "assistant".to_string(), content: response_text },
                ],
            );
        }

        Ok(output)
    }
}
209
210async fn call_anthropic(
211    client: &Client,
212    api_key: &str,
213    model: &str,
214    messages: &[serde_json::Value],
215    max_tokens: u64,
216    temperature: f64,
217    url: &str,
218) -> Result<StepOutput, StepError> {
219    let body = serde_json::json!({
220        "model": model,
221        "max_tokens": max_tokens,
222        "temperature": temperature,
223        "messages": messages,
224    });
225
226    let response = client
227        .post(url)
228        .header("x-api-key", api_key)
229        .header("anthropic-version", "2023-06-01")
230        .header("content-type", "application/json")
231        .json(&body)
232        .send()
233        .await
234        .map_err(|e| StepError::Fail(format!("Anthropic API request failed: {e}")))?;
235
236    if !response.status().is_success() {
237        let status = response.status();
238        let text = response.text().await.unwrap_or_default();
239        return Err(StepError::Fail(format!(
240            "Anthropic API error ({}): {}",
241            status, text
242        )));
243    }
244
245    #[derive(Deserialize)]
246    struct AnthropicResponse {
247        model: String,
248        content: Vec<AnthropicContent>,
249        usage: AnthropicUsage,
250    }
251    #[derive(Deserialize)]
252    struct AnthropicContent {
253        text: String,
254    }
255    #[derive(Deserialize)]
256    struct AnthropicUsage {
257        input_tokens: u64,
258        output_tokens: u64,
259    }
260
261    let resp: AnthropicResponse = response
262        .json()
263        .await
264        .map_err(|e| StepError::Fail(format!("Failed to parse Anthropic response: {e}")))?;
265
266    let text = resp
267        .content
268        .into_iter()
269        .map(|c| c.text)
270        .collect::<Vec<_>>()
271        .join("\n");
272
273    Ok(StepOutput::Chat(ChatOutput {
274        response: text,
275        model: resp.model,
276        input_tokens: resp.usage.input_tokens,
277        output_tokens: resp.usage.output_tokens,
278    }))
279}
280
281async fn call_openai(
282    client: &Client,
283    api_key: &str,
284    model: &str,
285    messages: &[serde_json::Value],
286    max_tokens: u64,
287    temperature: f64,
288    url: &str,
289) -> Result<StepOutput, StepError> {
290    let body = serde_json::json!({
291        "model": model,
292        "max_tokens": max_tokens,
293        "temperature": temperature,
294        "messages": messages,
295    });
296
297    let response = client
298        .post(url)
299        .header("Authorization", format!("Bearer {}", api_key))
300        .header("content-type", "application/json")
301        .json(&body)
302        .send()
303        .await
304        .map_err(|e| StepError::Fail(format!("OpenAI API request failed: {e}")))?;
305
306    if !response.status().is_success() {
307        let status = response.status();
308        let text = response.text().await.unwrap_or_default();
309        return Err(StepError::Fail(format!(
310            "OpenAI API error ({}): {}",
311            status, text
312        )));
313    }
314
315    #[derive(Deserialize)]
316    struct OpenAIResponse {
317        model: String,
318        choices: Vec<OpenAIChoice>,
319        usage: OpenAIUsage,
320    }
321    #[derive(Deserialize)]
322    struct OpenAIChoice {
323        message: OpenAIMessage,
324    }
325    #[derive(Deserialize)]
326    struct OpenAIMessage {
327        content: String,
328    }
329    #[derive(Deserialize)]
330    struct OpenAIUsage {
331        prompt_tokens: u64,
332        completion_tokens: u64,
333    }
334
335    let resp: OpenAIResponse = response
336        .json()
337        .await
338        .map_err(|e| StepError::Fail(format!("Failed to parse OpenAI response: {e}")))?;
339
340    let text = resp
341        .choices
342        .into_iter()
343        .map(|c| c.message.content)
344        .collect::<Vec<_>>()
345        .join("\n");
346
347    Ok(StepOutput::Chat(ChatOutput {
348        response: text,
349        model: resp.model,
350        input_tokens: resp.usage.prompt_tokens,
351        output_tokens: resp.usage.completion_tokens,
352    }))
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a minimal chat `StepDef`; `prompt` is the only field that varies
    /// between tests, so everything else is defaulted to empty/None here.
    /// (Previously two tests duplicated this 19-field literal inline.)
    fn make_step_with_prompt(prompt: Option<&str>) -> StepDef {
        StepDef {
            name: "test_chat".to_string(),
            step_type: crate::workflow::schema::StepType::Chat,
            run: None,
            prompt: prompt.map(str::to_string),
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    fn make_step(prompt: &str) -> StepDef {
        make_step_with_prompt(Some(prompt))
    }

    /// `std::env::set_var` is `unsafe` as of Rust 2024; funnel every test
    /// through one place so the unsafe handling stays consistent (previously
    /// only one of three call sites used an unsafe block).
    fn set_env(key: &str, value: &str) {
        // SAFETY: tests only set process-local test values; no other thread
        // depends on concurrently reading these variables mid-modification.
        unsafe { std::env::set_var(key, value) }
    }

    #[tokio::test]
    async fn chat_missing_api_key_friendly_error() {
        // Point api_key_env at a variable that definitely won't be set, so the
        // "missing key" path triggers regardless of the host environment.
        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "api_key_env".to_string(),
            serde_json::Value::String("DEFINITELY_NOT_SET_API_KEY_XYZ123".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("DEFINITELY_NOT_SET_API_KEY_XYZ123"),
            "Error should mention env var name: {}", err
        );
    }

    #[tokio::test]
    async fn chat_missing_prompt_field_error() {
        set_env("ANTHROPIC_API_KEY", "test-key");
        let step = make_step_with_prompt(None); // 'prompt' deliberately missing
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("prompt"), "Error should mention prompt: {}", err);
    }

    #[tokio::test]
    async fn chat_mock_anthropic_response() {
        // Use a wiremock server to mock the Anthropic API
        use wiremock::{MockServer, Mock, ResponseTemplate};
        use wiremock::matchers::{method, path};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Hello from mock!"}],
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .mount(&mock_server)
            .await;

        set_env("ANTHROPIC_API_KEY", "test-key");

        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "anthropic_base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        let result = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();
        assert_eq!(result.text(), "Hello from mock!");
        if let StepOutput::Chat(o) = result {
            assert_eq!(o.model, "claude-3-haiku-20240307");
            assert_eq!(o.input_tokens, 10);
            assert_eq!(o.output_tokens, 5);
        } else {
            panic!("Expected Chat output");
        }
    }

    /// Alternating user/assistant messages with content "message {i}".
    fn make_messages(count: usize) -> Vec<ChatMessage> {
        (0..count)
            .map(|i| ChatMessage {
                role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
                content: format!("message {}", i),
            })
            .collect()
    }

    #[test]
    fn truncation_last_keeps_n_messages() {
        let msgs = make_messages(50);
        let result = truncate_messages(&msgs, &TruncationStrategy::Last(10));
        assert_eq!(result.len(), 10);
        assert_eq!(result[0].content, "message 40");
        assert_eq!(result[9].content, "message 49");
    }

    #[test]
    fn truncation_first_last_keeps_first_and_last() {
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::FirstLast { first: 2, last: 5 });
        assert_eq!(result.len(), 7);
        assert_eq!(result[0].content, "message 0");
        assert_eq!(result[1].content, "message 1");
        assert_eq!(result[2].content, "message 45");
    }

    #[test]
    fn truncation_sliding_window_fits_within_tokens() {
        // Each message "message X" is ~1-2 words → ~2-3 estimated tokens
        // Build 50 messages; set max_tokens low enough to drop some
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::SlidingWindow { max_tokens: 50 });
        // Total tokens of 50 messages would exceed 50; result should be smaller
        let total: usize = result.iter().map(|m| estimate_tokens(&m.content)).sum();
        assert!(total <= 50, "Expected tokens <= 50, got {}", total);
    }

    #[test]
    fn truncation_none_returns_all() {
        let msgs = make_messages(10);
        let result = truncate_messages(&msgs, &TruncationStrategy::None);
        assert_eq!(result.len(), 10);
    }

    #[tokio::test]
    async fn chat_history_stores_messages_and_resends_on_second_call() {
        use wiremock::{Mock, MockServer, ResponseTemplate};
        use wiremock::matchers::{method, path};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Response text"}],
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .expect(2)
            .mount(&mock_server)
            .await;

        set_env("ANTHROPIC_API_KEY", "test-key");

        let step = make_step("First message");
        let mut config_values = HashMap::new();
        config_values.insert(
            "anthropic_base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        config_values.insert(
            "session".to_string(),
            serde_json::Value::String("review".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        // First call — stores user + assistant messages
        let _result1 = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();

        // After first call, history should have 2 messages
        let history = ctx.get_chat_messages("review");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "First message");
        assert_eq!(history[1].role, "assistant");

        // Second call — history is sent along with new message
        let step2 = make_step("Second message");
        let _result2 = ChatExecutor.execute(&step2, &config, &ctx).await.unwrap();

        // History now has 4 messages
        let history2 = ctx.get_chat_messages("review");
        assert_eq!(history2.len(), 4);
    }
}