1use std::time::Duration;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6
7use crate::config::StepConfig;
8use crate::engine::context::{ChatMessage, Context};
9use crate::error::StepError;
10use crate::workflow::schema::StepDef;
11
12use super::{ChatOutput, StepExecutor, StepOutput};
13
/// Strategy for shrinking a chat session's message history before it is
/// sent back to the model, keeping request sizes bounded.
#[derive(Debug, Clone)]
pub enum TruncationStrategy {
    /// Keep the full history unchanged.
    None,
    /// Keep only the last `n` messages.
    Last(usize),
    /// Keep only the first `n` messages.
    First(usize),
    /// Keep the first `first` and the last `last` messages, dropping the middle.
    FirstLast { first: usize, last: usize },
    /// Drop oldest messages until the estimated token total fits `max_tokens`.
    SlidingWindow { max_tokens: usize },
}
28
29impl TruncationStrategy {
30 pub fn from_config(config: &crate::config::StepConfig) -> Self {
33 match config.get_str("truncation_strategy") {
34 Some("last") => {
35 let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
36 TruncationStrategy::Last(n)
37 }
38 Some("first") => {
39 let n = config.get_u64("truncation_count").unwrap_or(10) as usize;
40 TruncationStrategy::First(n)
41 }
42 Some("first_last") => {
43 let first = config.get_u64("truncation_first").unwrap_or(2) as usize;
44 let last = config.get_u64("truncation_last").unwrap_or(5) as usize;
45 TruncationStrategy::FirstLast { first, last }
46 }
47 Some("sliding_window") => {
48 let max_tokens =
49 config.get_u64("truncation_max_tokens").unwrap_or(50_000) as usize;
50 TruncationStrategy::SlidingWindow { max_tokens }
51 }
52 _ => TruncationStrategy::None,
53 }
54 }
55}
56
/// Rough token-count estimate for `text`.
///
/// Uses the heuristic of ~1.3 tokens per whitespace-separated word, rounded
/// up; exact tokenization is provider-specific and not needed here, since
/// this only drives best-effort history truncation.
fn estimate_tokens(text: &str) -> usize {
    let word_count = text.split_whitespace().count() as f64;
    (word_count * 1.3).ceil() as usize
}
62
63pub fn truncate_messages(
65 messages: &[ChatMessage],
66 strategy: &TruncationStrategy,
67) -> Vec<ChatMessage> {
68 match strategy {
69 TruncationStrategy::None => messages.to_vec(),
70 TruncationStrategy::Last(n) => {
71 let start = messages.len().saturating_sub(*n);
72 messages[start..].to_vec()
73 }
74 TruncationStrategy::First(n) => {
75 messages[..messages.len().min(*n)].to_vec()
76 }
77 TruncationStrategy::FirstLast { first, last } => {
78 let len = messages.len();
79 let first_end = (*first).min(len);
80 let last_start = len.saturating_sub(*last);
81 if first_end >= last_start {
82 messages.to_vec()
84 } else {
85 let mut result = messages[..first_end].to_vec();
86 result.extend_from_slice(&messages[last_start..]);
87 result
88 }
89 }
90 TruncationStrategy::SlidingWindow { max_tokens } => {
91 let total_tokens: usize =
94 messages.iter().map(|m| estimate_tokens(&m.content)).sum();
95 if total_tokens <= *max_tokens {
96 return messages.to_vec();
97 }
98 let mut tokens_used = total_tokens;
99 let mut drop_count = 0;
100 for msg in messages.iter() {
101 if tokens_used <= *max_tokens {
102 break;
103 }
104 tokens_used -= estimate_tokens(&msg.content);
105 drop_count += 1;
106 }
107 messages[drop_count..].to_vec()
108 }
109 }
110}
111
/// Step executor that renders a prompt template, calls an LLM chat API
/// (Anthropic by default, or OpenAI), and returns the model's response.
pub struct ChatExecutor;
113
#[async_trait]
impl StepExecutor for ChatExecutor {
    /// Executes a chat step: resolves provider settings from `config`,
    /// renders the step's prompt template, optionally prepends (truncated)
    /// session history, calls the provider's HTTP API, and — when a session
    /// is configured — appends the new user/assistant exchange to it.
    ///
    /// # Errors
    /// Returns `StepError::Fail` when the API-key env var is unset, the step
    /// has no `prompt` field, the HTTP client cannot be built, or the
    /// provider call fails (propagated from `call_openai`/`call_anthropic`).
    async fn execute(
        &self,
        step: &StepDef,
        config: &StepConfig,
        ctx: &Context,
    ) -> Result<StepOutput, StepError> {
        // Provider selection is stringly-typed: any value other than
        // "openai" (including typos) falls through to Anthropic.
        let provider = config.get_str("provider").unwrap_or("anthropic");
        let model = config.get_str("model").unwrap_or(match provider {
            "openai" => "gpt-4o-mini",
            _ => "claude-3-haiku-20240307",
        });
        let max_tokens = config.get_u64("max_tokens").unwrap_or(1024);
        // Read temperature straight from the raw values map; defaults to a
        // deterministic 0.0 when absent or non-numeric.
        let temperature = config
            .values
            .get("temperature")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.0);
        // Name of the env var holding the key — overridable so tests and
        // multi-account setups can point elsewhere.
        let api_key_env = config.get_str("api_key_env").unwrap_or(match provider {
            "openai" => "OPENAI_API_KEY",
            _ => "ANTHROPIC_API_KEY",
        });
        let timeout = config
            .get_duration("timeout")
            .unwrap_or(Duration::from_secs(120));

        // Base URLs are configurable so tests can target a mock server.
        let anthropic_base = config
            .get_str("anthropic_base_url")
            .unwrap_or("https://api.anthropic.com");
        let openai_base = config
            .get_str("openai_base_url")
            .unwrap_or("https://api.openai.com");

        // Fail early, with the env-var name in the message, before doing any
        // template rendering or network setup.
        let api_key = std::env::var(api_key_env).map_err(|_| {
            StepError::Fail(format!(
                "API key not found: environment variable '{}' is not set",
                api_key_env
            ))
        })?;

        let prompt_template = step
            .prompt
            .as_ref()
            .ok_or_else(|| StepError::Fail("chat step missing 'prompt' field".into()))?;

        let prompt = ctx.render_template(prompt_template)?;

        // When a session is named, replay its (truncated) history ahead of
        // the new user message so the model sees the conversation so far.
        let session_name = config.get_str("session");
        let truncation = TruncationStrategy::from_config(config);
        let mut messages: Vec<serde_json::Value> = if let Some(session) = session_name {
            let history = ctx.get_chat_messages(session);
            let truncated = truncate_messages(&history, &truncation);
            truncated
                .into_iter()
                .map(|m| serde_json::json!({"role": m.role, "content": m.content}))
                .collect()
        } else {
            Vec::new()
        };
        messages.push(serde_json::json!({"role": "user", "content": prompt}));

        let client = Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| StepError::Fail(format!("Failed to create HTTP client: {e}")))?;

        let output = match provider {
            "openai" => {
                let url = format!("{}/v1/chat/completions", openai_base);
                call_openai(&client, &api_key, model, &messages, max_tokens, temperature, &url).await?
            }
            _ => {
                let url = format!("{}/v1/messages", anthropic_base);
                call_anthropic(&client, &api_key, model, &messages, max_tokens, temperature, &url).await?
            }
        };

        // Persist the exchange only after a successful call, so failed
        // requests leave the session history untouched.
        if let Some(session) = session_name {
            let response_text = output.text().to_string();
            ctx.append_chat_messages(
                session,
                vec![
                    ChatMessage { role: "user".to_string(), content: prompt },
                    ChatMessage { role: "assistant".to_string(), content: response_text },
                ],
            );
        }

        Ok(output)
    }
}
209
210async fn call_anthropic(
211 client: &Client,
212 api_key: &str,
213 model: &str,
214 messages: &[serde_json::Value],
215 max_tokens: u64,
216 temperature: f64,
217 url: &str,
218) -> Result<StepOutput, StepError> {
219 let body = serde_json::json!({
220 "model": model,
221 "max_tokens": max_tokens,
222 "temperature": temperature,
223 "messages": messages,
224 });
225
226 let response = client
227 .post(url)
228 .header("x-api-key", api_key)
229 .header("anthropic-version", "2023-06-01")
230 .header("content-type", "application/json")
231 .json(&body)
232 .send()
233 .await
234 .map_err(|e| StepError::Fail(format!("Anthropic API request failed: {e}")))?;
235
236 if !response.status().is_success() {
237 let status = response.status();
238 let text = response.text().await.unwrap_or_default();
239 return Err(StepError::Fail(format!(
240 "Anthropic API error ({}): {}",
241 status, text
242 )));
243 }
244
245 #[derive(Deserialize)]
246 struct AnthropicResponse {
247 model: String,
248 content: Vec<AnthropicContent>,
249 usage: AnthropicUsage,
250 }
251 #[derive(Deserialize)]
252 struct AnthropicContent {
253 text: String,
254 }
255 #[derive(Deserialize)]
256 struct AnthropicUsage {
257 input_tokens: u64,
258 output_tokens: u64,
259 }
260
261 let resp: AnthropicResponse = response
262 .json()
263 .await
264 .map_err(|e| StepError::Fail(format!("Failed to parse Anthropic response: {e}")))?;
265
266 let text = resp
267 .content
268 .into_iter()
269 .map(|c| c.text)
270 .collect::<Vec<_>>()
271 .join("\n");
272
273 Ok(StepOutput::Chat(ChatOutput {
274 response: text,
275 model: resp.model,
276 input_tokens: resp.usage.input_tokens,
277 output_tokens: resp.usage.output_tokens,
278 }))
279}
280
281async fn call_openai(
282 client: &Client,
283 api_key: &str,
284 model: &str,
285 messages: &[serde_json::Value],
286 max_tokens: u64,
287 temperature: f64,
288 url: &str,
289) -> Result<StepOutput, StepError> {
290 let body = serde_json::json!({
291 "model": model,
292 "max_tokens": max_tokens,
293 "temperature": temperature,
294 "messages": messages,
295 });
296
297 let response = client
298 .post(url)
299 .header("Authorization", format!("Bearer {}", api_key))
300 .header("content-type", "application/json")
301 .json(&body)
302 .send()
303 .await
304 .map_err(|e| StepError::Fail(format!("OpenAI API request failed: {e}")))?;
305
306 if !response.status().is_success() {
307 let status = response.status();
308 let text = response.text().await.unwrap_or_default();
309 return Err(StepError::Fail(format!(
310 "OpenAI API error ({}): {}",
311 status, text
312 )));
313 }
314
315 #[derive(Deserialize)]
316 struct OpenAIResponse {
317 model: String,
318 choices: Vec<OpenAIChoice>,
319 usage: OpenAIUsage,
320 }
321 #[derive(Deserialize)]
322 struct OpenAIChoice {
323 message: OpenAIMessage,
324 }
325 #[derive(Deserialize)]
326 struct OpenAIMessage {
327 content: String,
328 }
329 #[derive(Deserialize)]
330 struct OpenAIUsage {
331 prompt_tokens: u64,
332 completion_tokens: u64,
333 }
334
335 let resp: OpenAIResponse = response
336 .json()
337 .await
338 .map_err(|e| StepError::Fail(format!("Failed to parse OpenAI response: {e}")))?;
339
340 let text = resp
341 .choices
342 .into_iter()
343 .map(|c| c.message.content)
344 .collect::<Vec<_>>()
345 .join("\n");
346
347 Ok(StepOutput::Chat(ChatOutput {
348 response: text,
349 model: resp.model,
350 input_tokens: resp.usage.prompt_tokens,
351 output_tokens: resp.usage.completion_tokens,
352 }))
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a minimal chat `StepDef` with the given prompt and every other
    /// field empty. Tests vary individual fields via struct-update syntax.
    fn make_step(prompt: &str) -> StepDef {
        StepDef {
            name: "test_chat".to_string(),
            step_type: crate::workflow::schema::StepType::Chat,
            run: None,
            prompt: Some(prompt.to_string()),
            condition: None,
            on_pass: None,
            on_fail: None,
            message: None,
            scope: None,
            max_iterations: None,
            initial_value: None,
            items: None,
            parallel: None,
            steps: None,
            config: HashMap::new(),
            outputs: None,
            output_type: None,
            async_exec: None,
        }
    }

    /// Sets a process-wide environment variable for a test.
    ///
    /// `std::env::set_var` is `unsafe` in edition 2024; previously some call
    /// sites wrapped it and some did not. Routing every call through this
    /// helper keeps them consistent and compiling on both editions.
    /// NOTE(review): env mutation is process-global — tests calling this can
    /// race with concurrently running tests that read the same variable.
    fn set_env(key: &str, value: &str) {
        // SAFETY: tests are the only writers of these variables, and the
        // values are plain ASCII; no other thread relies on a prior value.
        unsafe { std::env::set_var(key, value) };
    }

    #[tokio::test]
    async fn chat_missing_api_key_friendly_error() {
        // Identical to the former inline StepDef literal — reuse the helper.
        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "api_key_env".to_string(),
            serde_json::Value::String("DEFINITELY_NOT_SET_API_KEY_XYZ123".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("DEFINITELY_NOT_SET_API_KEY_XYZ123"),
            "Error should mention env var name: {}", err
        );
    }

    #[tokio::test]
    async fn chat_missing_prompt_field_error() {
        set_env("ANTHROPIC_API_KEY", "test-key");
        // Same shape as make_step() but with no prompt and the original name.
        let step = StepDef {
            name: "test".to_string(),
            prompt: None,
            ..make_step("")
        };
        let config = StepConfig::default();
        let ctx = Context::new(String::new(), HashMap::new());
        let result = ChatExecutor.execute(&step, &config, &ctx).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("prompt"), "Error should mention prompt: {}", err);
    }

    #[tokio::test]
    async fn chat_mock_anthropic_response() {
        use wiremock::{MockServer, Mock, ResponseTemplate};
        use wiremock::matchers::{method, path};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Hello from mock!"}],
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .mount(&mock_server)
            .await;

        set_env("ANTHROPIC_API_KEY", "test-key");

        let step = make_step("Hello");
        let mut config_values = HashMap::new();
        config_values.insert(
            "anthropic_base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        let result = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();
        assert_eq!(result.text(), "Hello from mock!");
        if let StepOutput::Chat(o) = result {
            assert_eq!(o.model, "claude-3-haiku-20240307");
            assert_eq!(o.input_tokens, 10);
            assert_eq!(o.output_tokens, 5);
        } else {
            panic!("Expected Chat output");
        }
    }

    /// Builds `count` alternating user/assistant messages with predictable
    /// contents ("message 0", "message 1", …) for truncation tests.
    fn make_messages(count: usize) -> Vec<ChatMessage> {
        (0..count)
            .map(|i| ChatMessage {
                role: if i % 2 == 0 { "user".to_string() } else { "assistant".to_string() },
                content: format!("message {}", i),
            })
            .collect()
    }

    #[test]
    fn truncation_last_keeps_n_messages() {
        let msgs = make_messages(50);
        let result = truncate_messages(&msgs, &TruncationStrategy::Last(10));
        assert_eq!(result.len(), 10);
        assert_eq!(result[0].content, "message 40");
        assert_eq!(result[9].content, "message 49");
    }

    #[test]
    fn truncation_first_last_keeps_first_and_last() {
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::FirstLast { first: 2, last: 5 });
        assert_eq!(result.len(), 7);
        assert_eq!(result[0].content, "message 0");
        assert_eq!(result[1].content, "message 1");
        assert_eq!(result[2].content, "message 45");
    }

    #[test]
    fn truncation_sliding_window_fits_within_tokens() {
        let msgs = make_messages(50);
        let result =
            truncate_messages(&msgs, &TruncationStrategy::SlidingWindow { max_tokens: 50 });
        let total: usize = result.iter().map(|m| estimate_tokens(&m.content)).sum();
        assert!(total <= 50, "Expected tokens <= 50, got {}", total);
    }

    #[test]
    fn truncation_none_returns_all() {
        let msgs = make_messages(10);
        let result = truncate_messages(&msgs, &TruncationStrategy::None);
        assert_eq!(result.len(), 10);
    }

    #[tokio::test]
    async fn chat_history_stores_messages_and_resends_on_second_call() {
        use wiremock::{Mock, MockServer, ResponseTemplate};
        use wiremock::matchers::{method, path};

        let mock_server = MockServer::start().await;
        let response_body = serde_json::json!({
            "model": "claude-3-haiku-20240307",
            "content": [{"type": "text", "text": "Response text"}],
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&response_body))
            .expect(2)
            .mount(&mock_server)
            .await;

        set_env("ANTHROPIC_API_KEY", "test-key");

        let step = make_step("First message");
        let mut config_values = HashMap::new();
        config_values.insert(
            "anthropic_base_url".to_string(),
            serde_json::Value::String(mock_server.uri()),
        );
        config_values.insert(
            "session".to_string(),
            serde_json::Value::String("review".to_string()),
        );
        let config = StepConfig { values: config_values };
        let ctx = Context::new(String::new(), HashMap::new());

        let _result1 = ChatExecutor.execute(&step, &config, &ctx).await.unwrap();

        // First call should store the user message plus the assistant reply.
        let history = ctx.get_chat_messages("review");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, "user");
        assert_eq!(history[0].content, "First message");
        assert_eq!(history[1].role, "assistant");

        let step2 = make_step("Second message");
        let _result2 = ChatExecutor.execute(&step2, &config, &ctx).await.unwrap();

        // Second call appends another user/assistant pair.
        let history2 = ctx.get_chat_messages("review");
        assert_eq!(history2.len(), 4);
    }
}