#[cfg(test)]
mod tests {
    //! Unit tests for message construction, message (de)serialization
    //! validation, API-request preflight validation, and accumulation of
    //! streamed content / function-call / tool-call fragments.

    use crate::types::{FunctionCall, MessageRole, ToolCall};
    use crate::util::merge_chunk_message;
    use crate::validation::validate_api_request;
    use crate::{Agent, Instructions, Message, SwarmError};
    use serde_json::json;

    /// Builds a minimal agent used as request context for validation tests.
    fn test_agent() -> Agent {
        Agent::new(
            "message_test_agent",
            "gpt-4",
            Instructions::Text("Validate message shapes".to_string()),
        )
        .expect("Failed to create test agent")
    }

    /// Each checked constructor yields a message with the expected role and payload.
    #[test]
    fn test_message_constructors_cover_valid_roles() {
        let system = Message::system("System prompt").expect("Expected valid system message");
        let user = Message::user("Hello").expect("Expected valid user message");
        let assistant = Message::assistant("Hi").expect("Expected valid assistant message");
        let function =
            Message::function("lookup_docs", "{\"ok\":true}").expect("Expected function message");
        let function_call = FunctionCall::new("lookup_docs", "{\"query\":\"rust\"}")
            .expect("Expected valid function call");
        let assistant_call = Message::assistant_function_call(function_call.clone())
            .expect("Expected valid assistant function-call message");

        assert_eq!(system.role(), MessageRole::System);
        assert_eq!(system.content(), Some("System prompt"));
        assert_eq!(user.role(), MessageRole::User);
        assert_eq!(assistant.role(), MessageRole::Assistant);
        assert_eq!(assistant.content(), Some("Hi"));
        assert_eq!(function.role(), MessageRole::Function);
        assert_eq!(function.name(), Some("lookup_docs"));
        assert_eq!(assistant_call.function_call(), Some(&function_call));
    }

    /// Deserialization rejects mixed payloads and unknown roles; `FunctionCall::new`
    /// rejects non-JSON arguments with a `ValidationError`.
    #[test]
    fn test_message_deserialization_rejects_invalid_shapes() {
        let double_payload = serde_json::from_value::<Message>(json!({
            "role": "assistant",
            "content": "hello",
            "function_call": {
                "name": "lookup_docs",
                "arguments": "{}"
            }
        }))
        .expect_err("Assistant messages cannot carry content and function_call")
        .to_string();
        // Either wording is accepted; the test only pins that the mixed shape is rejected
        // with an explanatory message.
        assert!(
            double_payload.contains("exactly one of")
                || double_payload.contains("either content or a function call"),
            "unexpected error for content + function_call: {}",
            double_payload
        );

        let invalid_role = serde_json::from_value::<Message>(json!({
            "role": "moderator",
            "content": "hello"
        }))
        .expect_err("Unknown roles should fail deserialization")
        .to_string();
        assert!(
            invalid_role.contains("unknown variant"),
            "unexpected error for unknown role: {}",
            invalid_role
        );

        let invalid_function_call = FunctionCall::new("lookup_docs", "not-json")
            .expect_err("Function calls require JSON arguments");
        assert!(matches!(
            invalid_function_call,
            SwarmError::ValidationError(_)
        ));
    }

    /// Preflight validation refuses an empty message history.
    #[test]
    fn test_validate_api_request_rejects_empty_history() {
        let agent = test_agent();
        let error = validate_api_request(&agent, &[], &None, 1)
            .expect_err("empty history should fail preflight validation");
        assert!(matches!(error, SwarmError::ValidationError(_)));
        let message = error.to_string().to_lowercase();
        assert!(
            message.contains("empty"),
            "unexpected error for empty history: {}",
            message
        );
    }

    /// Preflight validation catches messages built via the unchecked constructors
    /// that violate per-role shape rules.
    #[test]
    fn test_validate_api_request_rejects_structurally_invalid_messages() {
        let agent = test_agent();
        // Each case is labelled so a failure identifies which shape slipped through.
        let cases = [
            (
                "assistant message with neither content nor a function call",
                Message::from_parts_unchecked(MessageRole::Assistant, None, None, None),
            ),
            (
                "function message without a name",
                Message::from_parts_unchecked(
                    MessageRole::Function,
                    Some("done".to_string()),
                    None,
                    None,
                ),
            ),
            (
                "system message carrying a function call",
                Message::from_parts_unchecked(
                    MessageRole::System,
                    Some("system".to_string()),
                    None,
                    Some(FunctionCall::from_parts_unchecked(
                        "lookup_docs".to_string(),
                        "{}".to_string(),
                    )),
                ),
            ),
        ];

        for (label, message) in cases {
            let result = validate_api_request(&agent, &[message], &None, 1);
            assert!(
                matches!(result, Err(SwarmError::ValidationError(_))),
                "{} should fail request validation with a ValidationError",
                label
            );
        }
    }

    /// A tool-result message carries content plus `tool_call_id` and serializes
    /// without any function/tool call payload.
    #[test]
    fn test_tool_result_message_valid_and_serializes_correctly() {
        let msg = Message::tool_result("call_abc123", "42").expect("tool_result should be valid");
        assert_eq!(msg.role(), MessageRole::Tool);
        assert_eq!(msg.content(), Some("42"));
        assert_eq!(msg.tool_call_id(), Some("call_abc123"));
        assert!(msg.function_call().is_none());
        assert!(msg.tool_calls().is_none());

        let json = serde_json::to_value(&msg).expect("serialize");
        assert_eq!(json["role"], "tool");
        assert_eq!(json["content"], "42");
        assert_eq!(json["tool_call_id"], "call_abc123");
        // Absent fields may be either skipped or emitted as null.
        assert!(json.get("function_call").is_none() || json["function_call"].is_null());
        assert!(json.get("tool_calls").is_none() || json["tool_calls"].is_null());
    }

    /// An assistant tool-calls message carries only the tool calls it was built with.
    #[test]
    fn test_assistant_tool_calls_message_valid() {
        let fc = FunctionCall::new("my_tool", "{\"x\":1}").expect("fc");
        let tc = ToolCall::new("call_xyz", fc).expect("tc");
        let msg =
            Message::assistant_tool_calls(vec![tc]).expect("assistant_tool_calls should be valid");

        assert_eq!(msg.role(), MessageRole::Assistant);
        assert!(msg.content().is_none());
        assert!(msg.function_call().is_none());
        let calls = msg.tool_calls().expect("should have tool_calls");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id(), "call_xyz");
        assert_eq!(calls[0].function().name(), "my_tool");
    }

    /// `validate` rejects a tool message lacking `tool_call_id`.
    #[test]
    fn test_tool_message_requires_tool_call_id() {
        let msg = Message::from_parts_unchecked(
            MessageRole::Tool,
            Some("some result".to_string()),
            None,
            None,
        );
        let err = msg
            .validate()
            .expect_err("Tool message without tool_call_id must fail")
            .to_string();
        assert!(
            err.contains("tool_call_id"),
            "unexpected error for missing tool_call_id: {}",
            err
        );
    }

    /// Deserialization rejects a tool message lacking content.
    #[test]
    fn test_tool_message_requires_content() {
        let err = serde_json::from_value::<Message>(json!({
            "role": "tool",
            "tool_call_id": "call_abc"
        }))
        .expect_err("Tool message without content must fail validation")
        .to_string();
        assert!(
            err.contains("content"),
            "unexpected error for missing tool content: {}",
            err
        );
    }

    /// The tool-calls constructor refuses an empty call list.
    #[test]
    fn test_assistant_tool_calls_rejects_empty_vec() {
        Message::assistant_tool_calls(vec![])
            .expect_err("Empty tool_calls vec must fail validation");
    }

    /// Deserialization rejects an assistant message carrying both content and tool calls.
    #[test]
    fn test_assistant_cannot_mix_tool_calls_and_content() {
        let err = serde_json::from_value::<Message>(json!({
            "role": "assistant",
            "content": "hello",
            "tool_calls": [{
                "id": "c1",
                "type": "function",
                "function": {"name": "f", "arguments": "{}"}
            }]
        }))
        .expect_err("Assistant message with both content and tool_calls must fail")
        .to_string();
        assert!(
            err.contains("exactly one"),
            "unexpected error for content + tool_calls: {}",
            err
        );
    }

    /// Tool-call deltas are accumulated per index: the first delta for an index sets
    /// id/name, later deltas append argument fragments, and `finalize_tool_calls`
    /// exposes the assembled calls.
    #[test]
    fn test_tool_call_delta_streaming_accumulation() {
        let mut msg = Message::from_parts_unchecked(MessageRole::Assistant, None, None, None);
        let chunk1 = json!({"index": 0, "id": "call_aaa", "type": "function",
            "function": {"name": "weather", "arguments": ""}});
        let chunk2 = json!({"index": 0, "function": {"arguments": "{\"city\":\""}});
        let chunk3 = json!({"index": 0, "function": {"arguments": "London\"}"}});
        let chunk4 = json!({"index": 1, "id": "call_bbb", "type": "function",
            "function": {"name": "stock", "arguments": "{}"}});

        msg.merge_tool_call_delta(0, &chunk1);
        msg.merge_tool_call_delta(0, &chunk2);
        msg.merge_tool_call_delta(0, &chunk3);
        msg.merge_tool_call_delta(1, &chunk4);
        msg.finalize_tool_calls();

        let calls = msg
            .tool_calls()
            .expect("should have 2 tool calls after finalization");
        // assert_eq! already reports both sides on failure.
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].id(), "call_aaa");
        assert_eq!(calls[0].function().name(), "weather");
        assert_eq!(calls[0].function().arguments(), "{\"city\":\"London\"}");
        assert_eq!(calls[1].id(), "call_bbb");
        assert_eq!(calls[1].function().name(), "stock");
        assert_eq!(calls[1].function().arguments(), "{}");
    }

    /// `merge_chunk_message` concatenates streamed content and function-call
    /// name/argument fragments onto the target message.
    #[test]
    fn test_merge_chunk_message_accumulates_streamed_content_and_function_calls() {
        let mut message = Message::from_parts_unchecked(MessageRole::Assistant, None, None, None);
        let first_content = json!({ "content": "Hello" });
        let second_content = json!({ "content": " world" });
        let function_name = json!({ "function_call": { "name": "lookup_docs" } });
        let function_args_1 = json!({ "function_call": { "arguments": "{\"query\":\"ru" } });
        let function_args_2 = json!({ "function_call": { "arguments": "st\"}" } });

        merge_chunk_message(&mut message, first_content.as_object().unwrap());
        merge_chunk_message(&mut message, second_content.as_object().unwrap());
        assert_eq!(message.content(), Some("Hello world"));

        merge_chunk_message(&mut message, function_name.as_object().unwrap());
        merge_chunk_message(&mut message, function_args_1.as_object().unwrap());
        merge_chunk_message(&mut message, function_args_2.as_object().unwrap());
        let function_call = message
            .function_call()
            .expect("Expected streamed function call fragments to accumulate");
        assert_eq!(function_call.name(), "lookup_docs");
        assert_eq!(function_call.arguments(), "{\"query\":\"rust\"}");
    }
}