use serde_json::json;
use crate::chain::ChainError;
use crate::prompt::PromptArgs;
use crate::schemas::messages::Message;
use crate::schemas::MessageType;
/// Converts a map carrying a `messages` list into prompt arguments with
/// `input` and `chat_history` entries.
///
/// The `messages` entry is deserialized into a `Vec<Message>`, then:
/// - `input` is the content of the most recent human message. If there is none,
///   it is the content of the last message of any type, or `""` when the list
///   is empty.
/// - `chat_history` is the caller's own `chat_history` entry when one is present.
///   Otherwise it is the parsed message list serialized back to JSON.
/// - every other key is copied through unchanged. The `messages` key itself is
///   not included in the result.
///
/// NOTE(review): pass-through keys are inserted after the derived ones, so a
/// caller-supplied `input` key overrides the derived `input`. Confirm this is
/// intended.
///
/// # Errors
///
/// Returns [`ChainError::OtherError`] if the `messages` key is missing or its
/// value cannot be deserialized into `Vec<Message>`.
pub fn convert_messages_to_prompt_args(
    input_variables: PromptArgs,
) -> Result<PromptArgs, ChainError> {
    // We own the map, so move entries out of it instead of cloning them.
    let mut remaining = input_variables;

    let messages_value = remaining
        .remove("messages")
        .ok_or_else(|| ChainError::OtherError("Missing 'messages' key".to_string()))?;
    let messages: Vec<Message> = serde_json::from_value(messages_value)
        .map_err(|e| ChainError::OtherError(format!("Failed to parse messages: {}", e)))?;

    // Prefer the latest human turn; otherwise fall back to the last message.
    let input = messages
        .iter()
        .rev()
        .find(|m| matches!(m.message_type, MessageType::HumanMessage))
        .or_else(|| messages.last())
        .map(|m| m.content.clone())
        .unwrap_or_default();

    // A caller-provided history takes precedence over the raw message list.
    let chat_history = remaining
        .remove("chat_history")
        .unwrap_or_else(|| json!(messages));

    let mut prompt_args = PromptArgs::new();
    prompt_args.insert("input".to_string(), json!(input));
    prompt_args.insert("chat_history".to_string(), chat_history);
    // `messages` and `chat_history` were removed above. Everything left passes
    // through, inserted last so it overwrites any derived key of the same name.
    prompt_args.extend(remaining);
    Ok(prompt_args)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt_args;
    use crate::schemas::Message;

    #[test]
    fn test_convert_messages_with_human_message() {
        let messages = vec![
            Message::new_system_message("You are helpful"),
            Message::new_human_message("Hello"),
            Message::new_ai_message("Hi there!"),
        ];
        let input = prompt_args! {
            "messages" => messages
        };
        let args = convert_messages_to_prompt_args(input).expect("conversion should succeed");
        assert_eq!(args["input"], json!("Hello"));
        // Without a caller-supplied history, the parsed messages become the history.
        assert_eq!(
            args["chat_history"].as_array().map(Vec::len),
            Some(3),
            "chat_history should hold every parsed message"
        );
        assert!(!args.contains_key("messages"));
    }

    #[test]
    fn test_convert_messages_uses_most_recent_human_message() {
        let messages = vec![
            Message::new_human_message("first"),
            Message::new_ai_message("reply"),
            Message::new_human_message("second"),
        ];
        let input = prompt_args! {
            "messages" => messages
        };
        let args = convert_messages_to_prompt_args(input).expect("conversion should succeed");
        assert_eq!(args["input"], json!("second"));
    }

    #[test]
    fn test_convert_messages_without_human_message() {
        let messages = vec![
            Message::new_system_message("You are helpful"),
            Message::new_ai_message("I am an AI"),
        ];
        let input = prompt_args! {
            "messages" => messages
        };
        let args = convert_messages_to_prompt_args(input).expect("conversion should succeed");
        assert_eq!(args["input"], json!("I am an AI"));
    }

    #[test]
    fn test_convert_messages_empty_list() {
        let messages: Vec<Message> = Vec::new();
        let input = prompt_args! {
            "messages" => messages
        };
        let args = convert_messages_to_prompt_args(input).expect("conversion should succeed");
        assert_eq!(args["input"], json!(""));
        assert_eq!(args["chat_history"], json!([]));
    }

    #[test]
    fn test_convert_messages_preserves_other_keys() {
        let messages = vec![Message::new_human_message("Hello")];
        let input = prompt_args! {
            "messages" => messages,
            "custom_key" => "custom_value"
        };
        let args = convert_messages_to_prompt_args(input).expect("conversion should succeed");
        assert_eq!(args["custom_key"], json!("custom_value"));
    }

    #[test]
    fn test_convert_messages_keeps_existing_chat_history() {
        let messages = vec![Message::new_human_message("Hello")];
        let input = prompt_args! {
            "messages" => messages,
            "chat_history" => "previous conversation"
        };
        let args = convert_messages_to_prompt_args(input).expect("conversion should succeed");
        assert_eq!(args["chat_history"], json!("previous conversation"));
        assert_eq!(args["input"], json!("Hello"));
    }

    #[test]
    fn test_convert_messages_missing_key() {
        let input = prompt_args! {
            "other_key" => "value"
        };
        let result = convert_messages_to_prompt_args(input);
        assert!(result.is_err());
    }

    #[test]
    fn test_convert_messages_malformed_messages() {
        let input = prompt_args! {
            "messages" => "not a list of messages"
        };
        let result = convert_messages_to_prompt_args(input);
        assert!(result.is_err());
    }
}