use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A message in the client → server direction.
///
/// Use [`ClientMessage::validate`] to check the payload before acting on it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientMessage {
    /// Message body; must be non-blank and within the size limit checked by `validate`.
    pub content: String,
    /// Optional session identifier; omitted from serialized output when `None`
    /// (and, as an `Option`, treated as `None` when absent on input).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}
/// A message in the server → client direction.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ServerMessage {
    /// Message identifier; the constructors fill it with a freshly generated v4 UUID string.
    pub id: String,
    /// Kind of message; serialized under the key `type` rather than the Rust field name.
    #[serde(rename = "type")]
    pub message_type: MessageType,
    /// Message body.
    pub content: String,
    /// Optional free-form metadata; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub meta: Option<Value>,
    /// Optional session identifier; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}
/// Discriminator for [`ServerMessage`], serialized as a `snake_case` string tag.
///
/// Variant order is kept stable on purpose: index-based serde formats encode the
/// variant position, so reordering would change their wire representation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"tool_use"`.
    ToolUse,
    /// Serialized as `"tool_result"`.
    ToolResult,
    /// Serialized as `"cost"`.
    Cost,
    /// Serialized as `"error"`; produced by [`ServerMessage::error`].
    Error,
    /// Serialized as `"status"`; produced by [`ServerMessage::status`].
    Status,
    /// Serialized as `"thinking"`.
    Thinking,
}
impl ClientMessage {
    /// Maximum accepted length of `content`, in bytes (10 MiB).
    ///
    /// Measured with [`String::len`], i.e. UTF-8 bytes, not characters.
    pub const MAX_CONTENT_BYTES: usize = 10 * 1024 * 1024;

    /// Checks that the message is acceptable for processing.
    ///
    /// # Errors
    ///
    /// Returns a human-readable error message when:
    /// - `content` is empty or consists only of whitespace, or
    /// - `content` is longer than [`Self::MAX_CONTENT_BYTES`] bytes.
    pub fn validate(&self) -> Result<(), String> {
        if self.content.trim().is_empty() {
            return Err("Message content cannot be empty".to_string());
        }
        if self.content.len() > Self::MAX_CONTENT_BYTES {
            // Derive the "(10MB)" label from the constant so the message cannot
            // drift out of sync with the actual limit.
            return Err(format!(
                "Message content exceeds maximum size ({}MB)",
                Self::MAX_CONTENT_BYTES / (1024 * 1024)
            ));
        }
        Ok(())
    }
}
impl ServerMessage {
    /// Maximum accepted length of `content`, in bytes (10 MiB).
    ///
    /// Measured with [`String::len`], i.e. UTF-8 bytes, not characters.
    pub const MAX_CONTENT_BYTES: usize = 10 * 1024 * 1024;

    /// Creates a message with a freshly generated random (v4) UUID as its `id`
    /// and no metadata.
    pub fn new(message_type: MessageType, content: String, session_id: Option<String>) -> Self {
        Self {
            id: uuid::Uuid::new_v4().to_string(),
            message_type,
            content,
            meta: None,
            session_id,
        }
    }

    /// Like [`ServerMessage::new`], but attaches `meta` as the message metadata.
    pub fn with_metadata(
        message_type: MessageType,
        content: String,
        meta: Value,
        session_id: Option<String>,
    ) -> Self {
        // Reuse `new` so id generation lives in exactly one place.
        Self {
            meta: Some(meta),
            ..Self::new(message_type, content, session_id)
        }
    }

    /// Creates a [`MessageType::Error`] message carrying `message`, with no session id.
    pub fn error(message: String) -> Self {
        Self::new(MessageType::Error, message, None)
    }

    /// Creates a [`MessageType::Status`] message carrying `message` for `session_id`.
    pub fn status(message: String, session_id: Option<String>) -> Self {
        Self::new(MessageType::Status, message, session_id)
    }

    /// Checks that the message is well-formed before it is sent.
    ///
    /// # Errors
    ///
    /// Returns a human-readable error message when:
    /// - `id` is empty or consists only of whitespace, or
    /// - `content` is longer than [`Self::MAX_CONTENT_BYTES`] bytes.
    pub fn validate(&self) -> Result<(), String> {
        if self.id.trim().is_empty() {
            return Err("Message ID cannot be empty".to_string());
        }
        if self.content.len() > Self::MAX_CONTENT_BYTES {
            // Derive the "(10MB)" label from the constant so the message cannot
            // drift out of sync with the actual limit.
            return Err(format!(
                "Message content exceeds maximum size ({}MB)",
                Self::MAX_CONTENT_BYTES / (1024 * 1024)
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_client_message_serialization() {
        let msg = ClientMessage {
            content: "Hello".to_string(),
            session_id: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("Hello"));
        // `session_id` is skipped entirely when `None`.
        assert!(!json.contains("session_id"));
    }

    #[test]
    fn test_client_message_deserialization() {
        let json = r#"{"content":"Hello","session_id":"sess_123"}"#;
        let msg: ClientMessage = serde_json::from_str(json).unwrap();
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.session_id, Some("sess_123".to_string()));
    }

    #[test]
    fn test_client_message_deserialization_without_session_id() {
        let msg: ClientMessage = serde_json::from_str(r#"{"content":"Hello"}"#).unwrap();
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.session_id, None);
    }

    #[test]
    fn test_client_message_validation() {
        let msg = ClientMessage {
            content: "Hello".to_string(),
            session_id: None,
        };
        assert!(msg.validate().is_ok());

        let msg = ClientMessage {
            content: "".to_string(),
            session_id: None,
        };
        assert_eq!(
            msg.validate(),
            Err("Message content cannot be empty".to_string())
        );

        // Whitespace-only content counts as empty.
        let msg = ClientMessage {
            content: " \t\n ".to_string(),
            session_id: None,
        };
        assert!(msg.validate().is_err());

        // Exactly at the limit is still accepted.
        let msg = ClientMessage {
            content: "x".repeat(10 * 1024 * 1024),
            session_id: None,
        };
        assert!(msg.validate().is_ok());

        let msg = ClientMessage {
            content: "x".repeat(11 * 1024 * 1024),
            session_id: None,
        };
        assert_eq!(
            msg.validate(),
            Err("Message content exceeds maximum size (10MB)".to_string())
        );
    }

    #[test]
    fn test_server_message_serialization() {
        let msg = ServerMessage::new(
            MessageType::Assistant,
            "Response".to_string(),
            Some("sess_123".to_string()),
        );
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("assistant"));
        assert!(json.contains("Response"));
        assert!(json.contains("sess_123"));
    }

    #[test]
    fn test_server_message_json_shape() {
        let msg = ServerMessage::error("boom".to_string());
        let value = serde_json::to_value(&msg).unwrap();
        // The enum is exposed under the `type` key, not the Rust field name.
        assert_eq!(value["type"], "error");
        assert_eq!(value["content"], "boom");
        assert!(value.get("message_type").is_none());
        // Optional fields are omitted when `None`.
        assert!(value.get("meta").is_none());
        assert!(value.get("session_id").is_none());
    }

    #[test]
    fn test_server_message_with_metadata() {
        let meta = json!({
            "tokens": 100,
            "cost": 0.002
        });
        let msg = ServerMessage::with_metadata(
            MessageType::Cost,
            "Cost info".to_string(),
            meta.clone(),
            Some("sess_123".to_string()),
        );
        assert_eq!(msg.message_type, MessageType::Cost);
        assert_eq!(msg.content, "Cost info");
        assert_eq!(msg.meta, Some(meta));
        assert_eq!(msg.session_id.as_deref(), Some("sess_123"));
    }

    #[test]
    fn test_server_message_helpers() {
        let err = ServerMessage::error("bad".to_string());
        assert_eq!(err.message_type, MessageType::Error);
        assert_eq!(err.content, "bad");
        assert_eq!(err.session_id, None);
        assert!(err.meta.is_none());

        let status = ServerMessage::status("ready".to_string(), Some("sess_1".to_string()));
        assert_eq!(status.message_type, MessageType::Status);
        assert_eq!(status.content, "ready");
        assert_eq!(status.session_id.as_deref(), Some("sess_1"));
        assert!(status.meta.is_none());
    }

    #[test]
    fn test_server_message_ids_are_unique() {
        let a = ServerMessage::status("same".to_string(), None);
        let b = ServerMessage::status("same".to_string(), None);
        assert!(!a.id.is_empty());
        assert_ne!(a.id, b.id);
    }

    #[test]
    fn test_server_message_validation() {
        let mut msg = ServerMessage::new(MessageType::Assistant, "ok".to_string(), None);
        assert!(msg.validate().is_ok());

        msg.id = "   ".to_string();
        assert_eq!(msg.validate(), Err("Message ID cannot be empty".to_string()));

        let oversized = ServerMessage::new(
            MessageType::Assistant,
            "x".repeat(11 * 1024 * 1024),
            None,
        );
        assert_eq!(
            oversized.validate(),
            Err("Message content exceeds maximum size (10MB)".to_string())
        );
    }

    #[test]
    fn test_message_type_serialization() {
        let cases = [
            (MessageType::Assistant, "\"assistant\""),
            (MessageType::ToolUse, "\"tool_use\""),
            (MessageType::ToolResult, "\"tool_result\""),
            (MessageType::Cost, "\"cost\""),
            (MessageType::Error, "\"error\""),
            (MessageType::Status, "\"status\""),
            (MessageType::Thinking, "\"thinking\""),
        ];
        for (ty, expected) in &cases {
            assert_eq!(serde_json::to_string(ty).unwrap(), *expected);
            let back: MessageType = serde_json::from_str(expected).unwrap();
            assert_eq!(&back, ty);
        }
    }

    #[test]
    fn test_round_trip_serialization() {
        let original = ClientMessage {
            content: "Test message".to_string(),
            session_id: Some("sess_123".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ClientMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(original.content, deserialized.content);
        assert_eq!(original.session_id, deserialized.session_id);
    }

    #[test]
    fn test_server_message_round_trip() {
        let original = ServerMessage::with_metadata(
            MessageType::ToolResult,
            "done".to_string(),
            json!({ "ok": true }),
            Some("sess_123".to_string()),
        );
        let json = serde_json::to_string(&original).unwrap();
        let back: ServerMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(back.id, original.id);
        assert_eq!(back.message_type, original.message_type);
        assert_eq!(back.content, original.content);
        assert_eq!(back.meta, original.meta);
        assert_eq!(back.session_id, original.session_id);
    }
}