//! Chat step executor (`minion_engine/steps/chat.rs`).
//!
//! Sends a rendered prompt (plus optional, truncatable session history)
//! to an LLM provider via the Rig crate.
use std::time::Duration;

use async_trait::async_trait;

use crate::config::StepConfig;
use crate::engine::context::{ChatMessage, Context};
use crate::error::StepError;
use crate::workflow::schema::StepDef;

use super::{ChatOutput, StepExecutor, StepOutput};

// ── Rig imports ──────────────────────────────────────────────────
use rig::client::CompletionClient;
use rig::completion::{CompletionError, CompletionModel, CompletionResponse};
use rig::message::{AssistantContent, Message};
16
/// Truncation strategy for chat history (Story 5.2).
///
/// Controls which subset of a session's stored messages is sent to the
/// model on each call; applied by [`truncate_messages`]. Parsed from step
/// config by [`TruncationStrategy::from_config`].
#[derive(Debug, Clone)]
pub enum TruncationStrategy {
    /// Keep all messages (default)
    None,
    /// Keep the last N messages
    Last(usize),
    /// Keep the first N messages
    First(usize),
    /// Keep the first `first` and last `last` messages
    FirstLast { first: usize, last: usize },
    /// Drop oldest messages until total estimated tokens <= max_tokens
    SlidingWindow { max_tokens: usize },
}

32impl TruncationStrategy {
33    /// Parse from StepConfig keys: truncation_strategy, truncation_count, truncation_first,
34    /// truncation_last, truncation_max_tokens
35    pub fn from_config(config: &crate::config::StepConfig) -> Self {
36        match config.get_str("truncation_strategy") {
37            Some("last") => {
38                let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
39                TruncationStrategy::Last(n)
40            }
41            Some("first") => {
42                let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
43                TruncationStrategy::First(n)
44            }
45            Some("first_last") => {
46                let first = config.get_u64("truncation_first").unwrap_or(2) as usize;
47                let last = config.get_u64("truncation_last").unwrap_or(5) as usize;
48                TruncationStrategy::FirstLast { first, last }
49            }
50            Some("sliding_window") => {
51                let max_tokens =
52                    config.get_u64("truncation_max_tokens").unwrap_or(50_000) as usize;
53                TruncationStrategy::SlidingWindow { max_tokens }
54            }
55            _ => TruncationStrategy::None,
56        }
57    }
58}
59
/// Estimate token count using a simple word-based heuristic (words * 1.3).
///
/// This is deliberately cheap and approximate; it only needs to be good
/// enough for the sliding-window truncation budget.
fn estimate_tokens(text: &str) -> usize {
    let word_count = text.split_whitespace().count() as f64;
    (word_count * 1.3).ceil() as usize
}

66/// Apply truncation to a list of messages, returning the subset to send
67pub fn truncate_messages(
68    messages: &[ChatMessage],
69    strategy: &TruncationStrategy,
70) -> Vec<ChatMessage> {
71    match strategy {
72        TruncationStrategy::None => messages.to_vec(),
73        TruncationStrategy::Last(n) => {
74            let start = messages.len().saturating_sub(*n);
75            messages[start..].to_vec()
76        }
77        TruncationStrategy::First(n) => {
78            messages[..messages.len().min(*n)].to_vec()
79        }
80        TruncationStrategy::FirstLast { first, last } => {
81            let len = messages.len();
82            let first_end = (*first).min(len);
83            let last_start = len.saturating_sub(*last);
84            if first_end >= last_start {
85                // Overlap or adjacent — return all
86                messages.to_vec()
87            } else {
88                let mut result = messages[..first_end].to_vec();
89                result.extend_from_slice(&messages[last_start..]);
90                result
91            }
92        }
93        TruncationStrategy::SlidingWindow { max_tokens } => {
94            // Greedily include messages from oldest to newest until token budget exceeded
95            // Then drop from the front until we fit
96            let total_tokens: usize =
97                messages.iter().map(|m| estimate_tokens(&m.content)).sum();
98            if total_tokens <= *max_tokens {
99                return messages.to_vec();
100            }
101            let mut tokens_used = total_tokens;
102            let mut drop_count = 0;
103            for msg in messages.iter() {
104                if tokens_used <= *max_tokens {
105                    break;
106                }
107                tokens_used -= estimate_tokens(&msg.content);
108                drop_count += 1;
109            }
110            messages[drop_count..].to_vec()
111        }
112    }
113}
114
115// ── Rig helper functions ─────────────────────────────────────────
116
117/// Convert internal ChatMessage list to Rig Message vector
118fn to_rig_messages(history: &[ChatMessage]) -> Vec<Message> {
119    history
120        .iter()
121        .map(|m| match m.role.as_str() {
122            "assistant" => {
123                Message::from(AssistantContent::text(&m.content))
124            }
125            _ => {
126                // user, system, or any other role → treat as user message
127                Message::from(m.content.as_str())
128            }
129        })
130        .collect()
131}
132
133/// Extract text and token usage from a Rig CompletionResponse
134fn extract_chat_output<T>(response: CompletionResponse<T>, model: &str) -> ChatOutput {
135    let text = response
136        .choice
137        .iter()
138        .filter_map(|c| {
139            if let AssistantContent::Text(t) = c {
140                Some(t.text.clone())
141            } else {
142                None
143            }
144        })
145        .collect::<Vec<_>>()
146        .join("\n");
147
148    ChatOutput {
149        response: text,
150        model: model.to_string(),
151        input_tokens: response.usage.input_tokens,
152        output_tokens: response.usage.output_tokens,
153    }
154}
155
156/// Map Rig CompletionError to StepError
157fn map_rig_error(provider: &str, err: CompletionError) -> StepError {
158    StepError::Fail(format!("{} API error: {}", provider, err))
159}
160
/// Map a Rig client-build error to a StepError::Fail with provider context.
fn map_build_error(provider: &str, err: impl std::fmt::Display) -> StepError {
    StepError::Fail(format!("Failed to build {} client: {}", provider, err))
}

/// Macro to avoid repeating completion_request → send → extract in every match arm.
/// Each provider arm creates its own `client`, then invokes this macro.
///
/// Written as a macro (not a function) because every provider's completion
/// model is a distinct concrete type. Expands to an expression of type
/// `Result<StepOutput, StepError>`; the embedded `.await` and `?` mean it
/// may only be invoked inside an async block whose error type is StepError.
macro_rules! send_completion {
    ($client:expr, $model_name:expr, $prompt:expr, $messages:expr,
     $temperature:expr, $max_tokens:expr, $provider:expr) => {{
        let model = $client.completion_model($model_name);
        let resp: Result<_, CompletionError> = model
            .completion_request($prompt)
            .messages($messages)
            .temperature($temperature)
            .max_tokens($max_tokens)
            .send()
            .await;
        // Provider name is baked into the error for diagnosability.
        let resp = resp.map_err(|e| map_rig_error($provider, e))?;
        Ok::<StepOutput, StepError>(StepOutput::Chat(extract_chat_output(resp, $model_name)))
    }};
}

/// Call an LLM via Rig — unified multi-provider completion.
///
/// * `provider` — provider id (`"anthropic"`, `"openai"`, `"ollama"`, …);
///   any unrecognized value is treated as an OpenAI-compatible endpoint
///   and then REQUIRES `base_url` (see the fallback arm).
/// * `api_key` — may be empty for Ollama, which authenticates with nothing.
/// * `base_url` — optional endpoint override; Ollama defaults to
///   `http://localhost:11434` when unset.
/// * `messages` — prior chat history, already converted to Rig messages.
/// * `timeout` — wall-clock limit covering client build + request;
///   elapsing maps to `StepError::Timeout`.
#[allow(clippy::too_many_arguments)]
async fn call_via_rig(
    provider: &str,
    model_name: &str,
    api_key: &str,
    base_url: Option<&str>,
    messages: Vec<Message>,
    prompt: &str,
    temperature: f64,
    max_tokens: u64,
    timeout: Duration,
) -> Result<StepOutput, StepError> {
    tokio::time::timeout(timeout, async {
        match provider {
            // ── Anthropic ────────────────────────────────────────
            "anthropic" => {
                let mut builder = rig::providers::anthropic::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("anthropic", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "anthropic")
            }

            // ── OpenAI (Chat Completions API — LiteLLM compatible) ──
            "openai" => {
                let mut builder = rig::providers::openai::CompletionsClient::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("openai", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "openai")
            }

            // ── Ollama (local, no API key) ───────────────────────
            "ollama" => {
                let mut builder = rig::providers::ollama::Client::builder()
                    .api_key(rig::client::Nothing);
                // Unlike other providers, a base_url is always set here.
                let url = base_url.unwrap_or("http://localhost:11434");
                builder = builder.base_url(url);
                let client = builder.build().map_err(|e| map_build_error("ollama", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "ollama")
            }

            // ── Groq ─────────────────────────────────────────────
            "groq" => {
                let mut builder = rig::providers::groq::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("groq", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "groq")
            }

            // ── DeepSeek ─────────────────────────────────────────
            "deepseek" => {
                let mut builder = rig::providers::deepseek::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("deepseek", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "deepseek")
            }

            // ── Google Gemini ────────────────────────────────────
            "gemini" | "google" => {
                let mut builder = rig::providers::gemini::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("gemini", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "gemini")
            }

            // ── Cohere ───────────────────────────────────────────
            "cohere" => {
                let mut builder = rig::providers::cohere::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("cohere", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "cohere")
            }

            // ── Perplexity ───────────────────────────────────────
            "perplexity" => {
                let mut builder = rig::providers::perplexity::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("perplexity", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "perplexity")
            }

            // ── xAI (Grok) ──────────────────────────────────────
            "xai" | "grok" => {
                let mut builder = rig::providers::xai::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("xai", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "xai")
            }

            // ── Mistral ─────────────────────────────────────────
            "mistral" => {
                let mut builder = rig::providers::mistral::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("mistral", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "mistral")
            }

            // ── Any other: OpenAI-compatible with custom base_url ──
            // This covers LiteLLM, vLLM, Azure (via base_url), or
            // any service that implements the OpenAI Chat Completions API.
            other => {
                let url = base_url.ok_or_else(|| StepError::Fail(format!(
                    "Unknown provider '{}': set 'base_url' to use as OpenAI-compatible endpoint",
                    other
                )))?;
                let builder = rig::providers::openai::CompletionsClient::builder()
                    .api_key(api_key)
                    .base_url(url);
                let client = builder.build().map_err(|e| map_build_error(other, e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, other)
            }
        }
    })
    .await
    // Outer Err = the timeout elapsed; the inner Result is returned as-is.
    .map_err(|_| StepError::Timeout(timeout))?
}

// ── ChatExecutor ─────────────────────────────────────────────────

/// Step executor for `chat` steps: resolves provider/model/key config,
/// renders the prompt, and calls the LLM via Rig (see [`call_via_rig`]).
pub struct ChatExecutor;

332#[async_trait]
333impl StepExecutor for ChatExecutor {
334    async fn execute(
335        &self,
336        step: &StepDef,
337        config: &StepConfig,
338        ctx: &Context,
339    ) -> Result<StepOutput, StepError> {
340        let provider = config.get_str("provider").unwrap_or("anthropic");
341        let model = config.get_str("model").unwrap_or(match provider {
342            "openai" => "gpt-4o-mini",
343            "ollama" => "llama3.2",
344            "groq" => "llama-3.3-70b-versatile",
345            "deepseek" => "deepseek-chat",
346            "gemini" | "google" => "gemini-2.0-flash",
347            _ => "claude-3-haiku-20240307",
348        });
349        let max_tokens = config.get_u64("max_tokens").unwrap_or(1024);
350        let temperature = config
351            .values
352            .get("temperature")
353            .and_then(|v| v.as_f64())
354            .unwrap_or(0.0);
355        let timeout = config
356            .get_duration("timeout")
357            .unwrap_or(Duration::from_secs(120));
358
359        // Resolve API key (Ollama doesn't need one)
360        let api_key = if provider == "ollama" {
361            String::new()
362        } else {
363            let api_key_env = config.get_str("api_key_env").unwrap_or(match provider {
364                "openai" => "OPENAI_API_KEY",
365                "groq" => "GROQ_API_KEY",
366                "deepseek" => "DEEPSEEK_API_KEY",
367                "gemini" | "google" => "GEMINI_API_KEY",
368                "cohere" => "COHERE_API_KEY",
369                "perplexity" => "PERPLEXITY_API_KEY",
370                "xai" | "grok" => "XAI_API_KEY",
371                "mistral" => "MISTRAL_API_KEY",
372                _ => "ANTHROPIC_API_KEY",
373            });
374            std::env::var(api_key_env).map_err(|_| {
375                StepError::Fail(format!(
376                    "API key not found: environment variable '{}' is not set",
377                    api_key_env
378                ))
379            })?
380        };
381
382        // Resolve base_url: generic > provider-specific > default
383        let base_url: Option<String> = config
384            .get_str("base_url")
385            .map(String::from)
386            .or_else(|| {
387                // Backward compatibility with old per-provider config keys
388                match provider {
389                    "anthropic" => config.get_str("anthropic_base_url").map(String::from),
390                    "openai" => config.get_str("openai_base_url").map(String::from),
391                    _ => None,
392                }
393            });
394
395        let prompt_template = step
396            .prompt
397            .as_ref()
398            .ok_or_else(|| StepError::Fail("chat step missing 'prompt' field".into()))?;
399
400        let prompt = ctx.render_template(prompt_template)?;
401
402        // Story 5.1 + 5.2: Build message list from chat history with optional truncation
403        let session_name = config.get_str("session");
404        let truncation = TruncationStrategy::from_config(config);
405        let rig_messages: Vec<Message> = if let Some(session) = session_name {
406            let history = ctx.get_chat_messages(session);
407            let truncated = truncate_messages(&history, &truncation);
408            to_rig_messages(&truncated)
409        } else {
410            Vec::new()
411        };
412
413        let output = call_via_rig(
414            provider,
415            model,
416            &api_key,
417            base_url.as_deref(),
418            rig_messages,
419            &prompt,
420            temperature,
421            max_tokens,
422            timeout,
423        )
424        .await?;
425
426        // Story 5.1: Store sent message and response in chat history
427        if let Some(session) = session_name {
428            let response_text = output.text().to_string();
429            ctx.append_chat_messages(
430                session,
431                vec![
432                    ChatMessage { role: "user".to_string(), content: prompt },
433                    ChatMessage { role: "assistant".to_string(), content: response_text },
434                ],
435            );
436        }
437
438        Ok(output)
439    }
440}
441
// ── Tests ────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal chat StepDef with the given prompt; every other field empty.
    fn make_step(prompt: &str) -> StepDef {
        StepDef {
            name: "test_chat".to_string(),
            step_type: crate::workflow::schema::StepType::Chat,
            run: None,
            prompt: Some(prompt.to_string()),
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    #[tokio::test]
    async fn chat_missing_api_key_friendly_error() {
        // Point api_key_env at a custom env var that definitely won't be set
        // so the executor fails at key resolution with a friendly message.
        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "api_key_env".to_string(),
            serde_json::Value::String("DEFINITELY_NOT_SET_API_KEY_XYZ123".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("DEFINITELY_NOT_SET_API_KEY_XYZ123"),
            "Error should mention env var name: {}", err
        );
    }

    #[tokio::test]
    async fn chat_missing_prompt_field_error() {
        unsafe { std::env::set_var("ANTHROPIC_API_KEY", "test-key"); }
        // Reuse make_step and blank out the prompt instead of duplicating
        // the whole StepDef literal.
        let mut step = make_step("");
        step.prompt = None; // missing!
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("prompt"), "Error should mention prompt: {}", err);
    }

    #[tokio::test]
    async fn chat_mock_anthropic_response() {
        // Rig's Anthropic client sends POST to /v1/messages with the same format
        // as the raw API. We mock the endpoint using wiremock.
        use wiremock::{MockServer, Mock, ResponseTemplate};
        use wiremock::matchers::{method, path};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "id": "msg_mock123",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Hello from mock!"}],
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "stop_reason": "end_turn",
            "stop_sequence": null
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .mount(&mock_server)
            .await;

        unsafe { std::env::set_var("ANTHROPIC_API_KEY", "test-key"); }

        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        // Use Rig's base_url config to route to wiremock
        config_values.insert(
            "base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        let result = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();
        assert_eq!(result.text(), "Hello from mock!");
        if let StepOutput::Chat(o) = result {
            assert_eq!(o.model, "claude-3-haiku-20240307");
            assert_eq!(o.input_tokens, 10);
            assert_eq!(o.output_tokens, 5);
        } else {
            panic!("Expected Chat output");
        }
    }

    /// Build `count` alternating user/assistant messages ("message 0", …).
    fn make_messages(count: usize) -> Vec<ChatMessage> {
        (0..count)
            .map(|i| ChatMessage {
                role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
                content: format!("message {}", i),
            })
            .collect()
    }

    #[test]
    fn truncation_last_keeps_n_messages() {
        let msgs = make_messages(50);
        let result = truncate_messages(&msgs, &TruncationStrategy::Last(10));
        assert_eq!(result.len(), 10);
        assert_eq!(result[0].content, "message 40");
        assert_eq!(result[9].content, "message 49");
    }

    #[test]
    fn truncation_first_last_keeps_first_and_last() {
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::FirstLast { first: 2, last: 5 });
        assert_eq!(result.len(), 7);
        assert_eq!(result[0].content, "message 0");
        assert_eq!(result[1].content, "message 1");
        assert_eq!(result[2].content, "message 45");
    }

    #[test]
    fn truncation_sliding_window_fits_within_tokens() {
        // Each message "message X" is ~1-2 words → ~2-3 estimated tokens
        // Build 50 messages; set max_tokens low enough to drop some
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::SlidingWindow { max_tokens: 50 });
        // Total tokens of 50 messages would exceed 50; result should be smaller
        let total: usize = result.iter().map(|m| estimate_tokens(&m.content)).sum();
        assert!(total <= 50, "Expected tokens <= 50, got {}", total);
    }

    #[test]
    fn truncation_none_returns_all() {
        let msgs = make_messages(10);
        let result = truncate_messages(&msgs, &TruncationStrategy::None);
        assert_eq!(result.len(), 10);
    }

    #[tokio::test]
    async fn chat_history_stores_messages_and_resends_on_second_call() {
        use wiremock::{Mock, MockServer, ResponseTemplate};
        use wiremock::matchers::{method, path};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "id": "msg_mock456",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Response text"}],
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "stop_reason": "end_turn",
            "stop_sequence": null
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .expect(2)
            .mount(&mock_server)
            .await;

        unsafe { std::env::set_var("ANTHROPIC_API_KEY", "test-key"); }

        let step = make_step("First message");
        let mut config_values = HashMap::new();
        config_values.insert(
            "base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        config_values.insert(
            "session".to_string(),
            serde_json::Value::String("review".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        // First call — stores user + assistant messages
        let _result1 = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();

        // After first call, history should have 2 messages
        let history = ctx.get_chat_messages("review");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "First message");
        assert_eq!(history[1].role, "assistant");

        // Second call — history is sent along with new message
        let step2 = make_step("Second message");
        let _result2 = ChatExecutor.execute(&step2, &config, &ctx).await.unwrap();

        // History now has 4 messages
        let history2 = ctx.get_chat_messages("review");
        assert_eq!(history2.len(), 4);
    }

    #[test]
    fn to_rig_messages_converts_correctly() {
        let history = vec![
            ChatMessage { role: "user".to_string(), content: "Hello".to_string() },
            ChatMessage { role: "assistant".to_string(), content: "Hi!".to_string() },
            ChatMessage { role: "user".to_string(), content: "How are you?".to_string() },
        ];
        let rig_msgs = to_rig_messages(&history);
        assert_eq!(rig_msgs.len(), 3);

        // Verify user messages
        match &rig_msgs[0] {
            Message::User { .. } => {},
            _ => panic!("Expected User message at index 0"),
        }

        // Verify assistant messages
        match &rig_msgs[1] {
            Message::Assistant { .. } => {},
            _ => panic!("Expected Assistant message at index 1"),
        }
    }

    #[tokio::test]
    async fn ollama_does_not_require_api_key() {
        // Verify that "ollama" provider skips the API key check.
        // Consistent with the other async tests, use #[tokio::test] instead
        // of hand-building a second runtime.
        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "provider".to_string(),
            serde_json::Value::String("ollama".to_string()),
        );
        // No api_key_env set — should not fail at config resolution
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        // We can't fully execute without Ollama running, but we verify
        // the provider is recognized and no API key error is raised.
        // The execute will fail at the HTTP level (connection refused)
        // rather than at the "API key not found" level.
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        // Should NOT contain "API key not found"
        assert!(
            !err.contains("API key not found"),
            "Ollama should not require API key, but got: {}",
            err
        );
    }
}