use std::time::Duration;

use async_trait::async_trait;

use crate::config::StepConfig;
use crate::engine::context::{ChatMessage, Context};
use crate::error::StepError;
use crate::workflow::schema::StepDef;

use super::{ChatOutput, StepExecutor, StepOutput};

use rig::client::CompletionClient;
use rig::completion::{CompletionError, CompletionModel, CompletionResponse};
use rig::message::{AssistantContent, Message};
#[derive(Debug, Clone)]
pub enum TruncationStrategy {
    /// Send the full history unchanged.
    None,
    /// Keep only the last `n` messages.
    Last(usize),
    /// Keep only the first `n` messages.
    First(usize),
    /// Keep the first `first` and the last `last` messages, dropping the middle.
    FirstLast { first: usize, last: usize },
    /// Drop the oldest messages until the estimated token count fits `max_tokens`.
    SlidingWindow { max_tokens: usize },
}

impl TruncationStrategy {
    /// Read the strategy from step config keys (`truncation_strategy`,
    /// `truncation_count`, `truncation_first`, `truncation_last`,
    /// `truncation_max_tokens`), defaulting to `None`.
    pub fn from_config(config: &StepConfig) -> Self {
        match config.get_str("truncation_strategy") {
            Some("last") => {
                let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
                TruncationStrategy::Last(n)
            }
            Some("first") => {
                let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
                TruncationStrategy::First(n)
            }
            Some("first_last") => {
                let first = config.get_u64("truncation_first").unwrap_or(2) as usize;
                let last = config.get_u64("truncation_last").unwrap_or(5) as usize;
                TruncationStrategy::FirstLast { first, last }
            }
            Some("sliding_window") => {
                let max_tokens =
                    config.get_u64("truncation_max_tokens").unwrap_or(50_000) as usize;
                TruncationStrategy::SlidingWindow { max_tokens }
            }
            _ => TruncationStrategy::None,
        }
    }
}

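/// Rough token estimate: about 1.3 tokens per whitespace-separated word.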
fn estimate_tokens(text: &str) -> usize {
    let words = text.split_whitespace().count();
    ((words as f64) * 1.3).ceil() as usize
}

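/// Apply a truncation strategy to a message history, returning the kept messages.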
pub fn truncate_messages(
    messages: &[ChatMessage],
    strategy: &TruncationStrategy,
) -> Vec<ChatMessage> {
    match strategy {
        TruncationStrategy::None => messages.to_vec(),
        TruncationStrategy::Last(n) => {
            let start = messages.len().saturating_sub(*n);
            messages[start..].to_vec()
        }
        TruncationStrategy::First(n) => {
            messages[..messages.len().min(*n)].to_vec()
        }
        TruncationStrategy::FirstLast { first, last } => {
            let len = messages.len();
            let first_end = (*first).min(len);
            let last_start = len.saturating_sub(*last);
            if first_end >= last_start {
                // The two ranges overlap: the history is already short enough.
                messages.to_vec()
            } else {
                let mut result = messages[..first_end].to_vec();
                result.extend_from_slice(&messages[last_start..]);
                result
            }
        }
        TruncationStrategy::SlidingWindow { max_tokens } => {
            let total_tokens: usize =
                messages.iter().map(|m| estimate_tokens(&m.content)).sum();
            if total_tokens <= *max_tokens {
                return messages.to_vec();
            }
            // Drop the oldest messages until the estimate fits the budget.
            let mut tokens_used = total_tokens;
            let mut drop_count = 0;
            for msg in messages.iter() {
                if tokens_used <= *max_tokens {
                    break;
                }
                tokens_used -= estimate_tokens(&msg.content);
                drop_count += 1;
            }
            messages[drop_count..].to_vec()
        }
    }
}

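/// Convert stored history into rig `Message` values. Assistant turns become
/// assistant messages; every other role is sent as a user message.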
fn to_rig_messages(history: &[ChatMessage]) -> Vec<Message> {
    history
        .iter()
        .map(|m| match m.role.as_str() {
            "assistant" => Message::from(AssistantContent::text(&m.content)),
            _ => Message::from(m.content.as_str()),
        })
        .collect()
}

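/// Collect the text parts of a completion response into a `ChatOutput`,
/// joining multiple text blocks with newlines and carrying token usage through.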
fn extract_chat_output<T>(response: CompletionResponse<T>, model: &str) -> ChatOutput {
    let text = response
        .choice
        .iter()
        .filter_map(|c| {
            if let AssistantContent::Text(t) = c {
                Some(t.text.clone())
            } else {
                None
            }
        })
        .collect::<Vec<_>>()
        .join("\n");

    ChatOutput {
        response: text,
        model: model.to_string(),
        input_tokens: response.usage.input_tokens,
        output_tokens: response.usage.output_tokens,
    }
}

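/// Wrap a rig completion error in a `StepError` tagged with the provider name.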
fn map_rig_error(provider: &str, err: CompletionError) -> StepError {
    StepError::Fail(format!("{} API error: {}", provider, err))
}

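/// Wrap a client construction error in a `StepError` tagged with the provider name.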
fn map_build_error(provider: &str, err: impl std::fmt::Display) -> StepError {
    StepError::Fail(format!("Failed to build {} client: {}", provider, err))
}

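/// Shared request path for every provider: build the completion model, send the
/// prompt plus history, and convert the response into a `StepOutput::Chat`.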
macro_rules! send_completion {
    ($client:expr, $model_name:expr, $prompt:expr, $messages:expr,
     $temperature:expr, $max_tokens:expr, $provider:expr) => {{
        let model = $client.completion_model($model_name);
        let resp: Result<_, CompletionError> = model
            .completion_request($prompt)
            .messages($messages)
            .temperature($temperature)
            .max_tokens($max_tokens)
            .send()
            .await;
        let resp = resp.map_err(|e| map_rig_error($provider, e))?;
        Ok::<StepOutput, StepError>(StepOutput::Chat(extract_chat_output(resp, $model_name)))
    }};
}

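/// Dispatch a chat completion to the named provider via rig, honoring an
/// optional `base_url` override and wrapping the whole call in a timeout.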
#[allow(clippy::too_many_arguments)]
async fn call_via_rig(
    provider: &str,
    model_name: &str,
    api_key: &str,
    base_url: Option<&str>,
    messages: Vec<Message>,
    prompt: &str,
    temperature: f64,
    max_tokens: u64,
    timeout: Duration,
) -> Result<StepOutput, StepError> {
    tokio::time::timeout(timeout, async {
        match provider {
            "anthropic" => {
                let mut builder = rig::providers::anthropic::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("anthropic", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "anthropic")
            }

            "openai" => {
                let mut builder = rig::providers::openai::CompletionsClient::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("openai", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "openai")
            }

            "ollama" => {
                // Ollama is a local server: no API key, and a default base URL.
                let mut builder = rig::providers::ollama::Client::builder()
                    .api_key(rig::client::Nothing);
                let url = base_url.unwrap_or("http://localhost:11434");
                builder = builder.base_url(url);
                let client = builder.build().map_err(|e| map_build_error("ollama", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "ollama")
            }

            "groq" => {
                let mut builder = rig::providers::groq::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("groq", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "groq")
            }

            "deepseek" => {
                let mut builder = rig::providers::deepseek::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("deepseek", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "deepseek")
            }

            "gemini" | "google" => {
                let mut builder = rig::providers::gemini::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("gemini", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "gemini")
            }

            "cohere" => {
                let mut builder = rig::providers::cohere::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("cohere", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "cohere")
            }

            "perplexity" => {
                let mut builder = rig::providers::perplexity::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("perplexity", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "perplexity")
            }

            "xai" | "grok" => {
                let mut builder = rig::providers::xai::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("xai", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "xai")
            }

            "mistral" => {
                let mut builder = rig::providers::mistral::Client::builder()
                    .api_key(api_key);
                if let Some(url) = base_url {
                    builder = builder.base_url(url);
                }
                let client = builder.build().map_err(|e| map_build_error("mistral", e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, "mistral")
            }

            other => {
                // Any unknown provider is treated as an OpenAI-compatible
                // endpoint, which requires an explicit base_url.
                let url = base_url.ok_or_else(|| StepError::Fail(format!(
                    "Unknown provider '{}': set 'base_url' to treat it as an OpenAI-compatible endpoint",
                    other
                )))?;
                let builder = rig::providers::openai::CompletionsClient::builder()
                    .api_key(api_key)
                    .base_url(url);
                let client = builder.build().map_err(|e| map_build_error(other, e))?;
                send_completion!(client, model_name, prompt, messages, temperature, max_tokens, other)
            }
        }
    })
    .await
    .map_err(|_| StepError::Timeout(timeout))?
}

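/// Step executor for `chat` steps: resolves provider, model, and credentials
/// from config, renders the prompt, and maintains optional per-session history.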
pub struct ChatExecutor;

#[async_trait]
impl StepExecutor for ChatExecutor {
    async fn execute(
        &self,
        step: &StepDef,
        config: &StepConfig,
        ctx: &Context,
    ) -> Result<StepOutput, StepError> {
        let provider = config.get_str("provider").unwrap_or("anthropic");
        let model = config.get_str("model").unwrap_or(match provider {
            "openai" => "gpt-4o-mini",
            "ollama" => "llama3.2",
            "groq" => "llama-3.3-70b-versatile",
            "deepseek" => "deepseek-chat",
            "gemini" | "google" => "gemini-2.0-flash",
            _ => "claude-3-haiku-20240307",
        });
        let max_tokens = config.get_u64("max_tokens").unwrap_or(1024);
        let temperature = config
            .values
            .get("temperature")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.0);
        let timeout = config
            .get_duration("timeout")
            .unwrap_or(Duration::from_secs(120));

        // Ollama runs locally and does not need an API key.
        let api_key = if provider == "ollama" {
            String::new()
        } else {
            let api_key_env = config.get_str("api_key_env").unwrap_or(match provider {
                "openai" => "OPENAI_API_KEY",
                "groq" => "GROQ_API_KEY",
                "deepseek" => "DEEPSEEK_API_KEY",
                "gemini" | "google" => "GEMINI_API_KEY",
                "cohere" => "COHERE_API_KEY",
                "perplexity" => "PERPLEXITY_API_KEY",
                "xai" | "grok" => "XAI_API_KEY",
                "mistral" => "MISTRAL_API_KEY",
                _ => "ANTHROPIC_API_KEY",
            });
            std::env::var(api_key_env).map_err(|_| {
                StepError::Fail(format!(
                    "API key not found: environment variable '{}' is not set",
                    api_key_env
                ))
            })?
        };

        // Prefer an explicit base_url, then provider-specific overrides.
        let base_url: Option<String> = config
            .get_str("base_url")
            .map(String::from)
            .or_else(|| match provider {
                "anthropic" => config.get_str("anthropic_base_url").map(String::from),
                "openai" => config.get_str("openai_base_url").map(String::from),
                _ => None,
            });

        let prompt_template = step
            .prompt
            .as_ref()
            .ok_or_else(|| StepError::Fail("chat step missing 'prompt' field".into()))?;

        let prompt = ctx.render_template(prompt_template)?;

        // With a session configured, prior history is truncated and resent.
        let session_name = config.get_str("session");
        let truncation = TruncationStrategy::from_config(config);
        let rig_messages: Vec<Message> = if let Some(session) = session_name {
            let history = ctx.get_chat_messages(session);
            let truncated = truncate_messages(&history, &truncation);
            to_rig_messages(&truncated)
        } else {
            Vec::new()
        };

        let output = call_via_rig(
            provider,
            model,
            &api_key,
            base_url.as_deref(),
            rig_messages,
            &prompt,
            temperature,
            max_tokens,
            timeout,
        )
        .await?;

        // Record the exchange so later steps in the same session see it.
        if let Some(session) = session_name {
            let response_text = output.text().to_string();
            ctx.append_chat_messages(
                session,
                vec![
                    ChatMessage { role: "user".to_string(), content: prompt },
                    ChatMessage { role: "assistant".to_string(), content: response_text },
                ],
            );
        }

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

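    /// Build a minimal chat step definition with the given prompt.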
    fn make_step(prompt: &str) -> StepDef {
        StepDef {
            name: "test_chat".to_string(),
            step_type: crate::workflow::schema::StepType::Chat,
            run: None,
            prompt: Some(prompt.to_string()),
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

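    // A missing API key env var should fail with an error naming that variable.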
    #[tokio::test]
    async fn chat_missing_api_key_friendly_error() {
        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "api_key_env".to_string(),
            serde_json::Value::String("DEFINITELY_NOT_SET_API_KEY_XYZ123".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("DEFINITELY_NOT_SET_API_KEY_XYZ123"),
            "Error should mention env var name: {}",
            err
        );
    }

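    // A chat step without a 'prompt' field should fail before any network call.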
    #[tokio::test]
    async fn chat_missing_prompt_field_error() {
        unsafe { std::env::set_var("ANTHROPIC_API_KEY", "test-key"); }
        let step = StepDef {
            name: "test".to_string(),
            step_type: crate::workflow::schema::StepType::Chat,
            run: None,
            prompt: None,
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        };
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("prompt"), "Error should mention prompt: {}", err);
    }

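    // End-to-end happy path against a wiremock stand-in for the Anthropic API.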
    #[tokio::test]
    async fn chat_mock_anthropic_response() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "id": "msg_mock123",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Hello from mock!"}],
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "stop_reason": "end_turn",
            "stop_sequence": null
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .mount(&mock_server)
            .await;

        unsafe { std::env::set_var("ANTHROPIC_API_KEY", "test-key"); }

        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        let result = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();
        assert_eq!(result.text(), "Hello from mock!");
        if let StepOutput::Chat(o) = result {
            assert_eq!(o.model, "claude-3-haiku-20240307");
            assert_eq!(o.input_tokens, 10);
            assert_eq!(o.output_tokens, 5);
        } else {
            panic!("Expected Chat output");
        }
    }

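    /// Alternating user/assistant messages with contents "message 0", "message 1", ...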
    fn make_messages(count: usize) -> Vec<ChatMessage> {
        (0..count)
            .map(|i| ChatMessage {
                role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
                content: format!("message {}", i),
            })
            .collect()
    }

    #[test]
    fn truncation_last_keeps_n_messages() {
        let msgs = make_messages(50);
        let result = truncate_messages(&msgs, &TruncationStrategy::Last(10));
        assert_eq!(result.len(), 10);
        assert_eq!(result[0].content, "message 40");
        assert_eq!(result[9].content, "message 49");
    }

    #[test]
    fn truncation_first_last_keeps_first_and_last() {
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::FirstLast { first: 2, last: 5 });
        assert_eq!(result.len(), 7);
        assert_eq!(result[0].content, "message 0");
        assert_eq!(result[1].content, "message 1");
        assert_eq!(result[2].content, "message 45");
    }

    #[test]
    fn truncation_sliding_window_fits_within_tokens() {
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::SlidingWindow { max_tokens: 50 });
        let total: usize = result.iter().map(|m| estimate_tokens(&m.content)).sum();
        assert!(total <= 50, "Expected tokens <= 50, got {}", total);
    }

    #[test]
    fn truncation_none_returns_all() {
        let msgs = make_messages(10);
        let result = truncate_messages(&msgs, &TruncationStrategy::None);
        assert_eq!(result.len(), 10);
    }

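    // Session history should persist across calls and be resent on the second one.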
    #[tokio::test]
    async fn chat_history_stores_messages_and_resends_on_second_call() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "id": "msg_mock456",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Response text"}],
            "usage": {"input_tokens": 10, "output_tokens": 5},
            "stop_reason": "end_turn",
            "stop_sequence": null
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .expect(2)
            .mount(&mock_server)
            .await;

        unsafe { std::env::set_var("ANTHROPIC_API_KEY", "test-key"); }

        let step = make_step("First message");
        let mut config_values = HashMap::new();
        config_values.insert(
            "base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        config_values.insert(
            "session".to_string(),
            serde_json::Value::String("review".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        let _result1 = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();

        // The first call should record one user and one assistant message.
        let history = ctx.get_chat_messages("review");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "First message");
        assert_eq!(history[1].role, "assistant");

        let step2 = make_step("Second message");
        let _result2 = ChatExecutor.execute(&step2, &config, &ctx).await.unwrap();

        // The second call should append another exchange to the same session.
        let history2 = ctx.get_chat_messages("review");
        assert_eq!(history2.len(), 4);
    }

    #[test]
    fn to_rig_messages_converts_correctly() {
        let history = vec![
            ChatMessage { role: "user".to_string(), content: "Hello".to_string() },
            ChatMessage { role: "assistant".to_string(), content: "Hi!".to_string() },
            ChatMessage { role: "user".to_string(), content: "How are you?".to_string() },
        ];
        let rig_msgs = to_rig_messages(&history);
        assert_eq!(rig_msgs.len(), 3);

        match &rig_msgs[0] {
            Message::User { .. } => {}
            _ => panic!("Expected User message at index 0"),
        }

        match &rig_msgs[1] {
            Message::Assistant { .. } => {}
            _ => panic!("Expected Assistant message at index 1"),
        }
    }

    #[test]
    fn ollama_does_not_require_api_key() {
        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "provider".to_string(),
            serde_json::Value::String("ollama".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        // Plain #[test] with its own runtime. No Ollama server is expected to be
        // running, so the call should fail, but never with a missing-key error.
        let rt = tokio::runtime::Runtime::new().unwrap();
        let result = rt.block_on(ChatExecutor.execute(&step, &config, &ctx));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("API key not found"),
            "Ollama should not require API key, but got: {}",
            err
        );
    }
}
748}