use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "internal")]
pub mod internal;
#[cfg(feature = "internal")]
pub use internal::{
ConversionError, FromInternal, InternalTool, InternalToolCall, InternalToolResult, ToInternal,
TryFromInternal, TryToInternal,
};
#[cfg(feature = "mcp")]
pub mod mcp;
#[cfg(feature = "mcp")]
pub use mcp::{
McpContent, McpInputSchema, McpTool, McpToolAnnotations, McpToolCall, McpToolResult,
};
pub mod chatml;
pub use chatml::{ChatMLFormatter, ChatMLMessage, MessageRole as ChatMLMessageRole};
pub use chatml::count_tokens_for_text;
#[cfg(feature = "streaming")]
pub mod streaming;
#[cfg(feature = "streaming")]
pub use streaming::{AccumulatedResponse, StreamChunk, StreamingAccumulator};
pub mod events;
pub use events::{
Event, EventEnvelope, EventType, McpContext, MessageEvent, ModelInfo, ToolCall as EventToolCall,
ToolCallEvent, ToolCallStatus, ToolResult, ToolResultEvent,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InternalMessage {
pub role: MessageRole,
pub content: MessageContent,
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl InternalMessage {
pub fn system(text: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: MessageContent::Text(text.into()),
metadata: HashMap::new(),
tool_call_id: None,
name: None,
}
}
pub fn user(text: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: MessageContent::Text(text.into()),
metadata: HashMap::new(),
tool_call_id: None,
name: None,
}
}
pub fn assistant(text: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: MessageContent::Text(text.into()),
metadata: HashMap::new(),
tool_call_id: None,
name: None,
}
}
pub fn tool(content: MessageContent) -> Self {
Self {
role: MessageRole::Tool,
content,
metadata: HashMap::new(),
tool_call_id: None,
name: None,
}
}
pub fn tool_result(
tool_call_id: impl Into<String>,
name: impl Into<String>,
content: impl Into<String>,
) -> Self {
Self {
role: MessageRole::Tool,
content: MessageContent::Text(content.into()),
metadata: HashMap::new(),
tool_call_id: Some(tool_call_id.into()),
name: Some(name.into()),
}
}
pub fn assistant_with_tools(content: impl Into<String>, tool_calls: Vec<ContentBlock>) -> Self {
let mut blocks = vec![ContentBlock::text(content.into())];
blocks.extend(tool_calls);
Self {
role: MessageRole::Assistant,
content: MessageContent::Blocks(blocks),
metadata: HashMap::new(),
tool_call_id: None,
name: None,
}
}
pub fn text(&self) -> Option<&str> {
match &self.content {
MessageContent::Text(text) => Some(text),
_ => None,
}
}
pub fn blocks(&self) -> Option<&[ContentBlock]> {
match &self.content {
MessageContent::Blocks(blocks) => Some(blocks),
_ => None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
Tool,
}
impl MessageRole {
pub fn as_str(&self) -> &str {
match self {
Self::System => "system",
Self::User => "user",
Self::Assistant => "assistant",
Self::Tool => "tool",
}
}
}
impl std::fmt::Display for MessageRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text(String),
Blocks(Vec<ContentBlock>),
}
impl MessageContent {
pub fn text(text: impl Into<String>) -> Self {
Self::Text(text.into())
}
pub fn blocks(blocks: Vec<ContentBlock>) -> Self {
Self::Blocks(blocks)
}
pub fn is_text(&self) -> bool {
matches!(self, Self::Text(_))
}
pub fn is_blocks(&self) -> bool {
matches!(self, Self::Blocks(_))
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
Base64 {
media_type: String,
data: String,
},
Url {
url: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text {
text: String,
},
Image {
source: ImageSource,
},
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
ToolResult {
tool_use_id: String,
content: String,
},
}
impl ContentBlock {
pub fn text(text: impl Into<String>) -> Self {
Self::Text { text: text.into() }
}
pub fn image(source: ImageSource) -> Self {
Self::Image { source }
}
pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: serde_json::Value) -> Self {
Self::ToolUse {
id: id.into(),
name: name.into(),
input,
}
}
pub fn tool_result(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
Self::ToolResult {
tool_use_id: tool_use_id.into(),
content: content.into(),
}
}
pub fn as_text(&self) -> Option<&str> {
match self {
Self::Text { text } => Some(text),
_ => None,
}
}
pub fn as_tool_use(&self) -> Option<(&str, &str, &serde_json::Value)> {
match self {
Self::ToolUse { id, name, input } => Some((id, name, input)),
_ => None,
}
}
pub fn as_tool_result(&self) -> Option<(&str, &str)> {
match self {
Self::ToolResult { tool_use_id, content } => Some((tool_use_id, content)),
_ => None,
}
}
pub fn as_image(&self) -> Option<&ImageSource> {
match self {
Self::Image { source } => Some(source),
_ => None,
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub r#type: String,
pub function: FunctionCall,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
#[serde(rename = "type")]
pub r#type: String,
pub function: Function,
}
#[derive(Debug)]
pub enum GenerateResult {
Content {
text: String,
reasoning: Option<String>,
},
ToolCalls {
calls: Vec<ToolCall>,
content: Option<String>,
reasoning: Option<String>,
},
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_creation() {
let msg = InternalMessage::system("You are a helpful assistant");
assert_eq!(msg.role, MessageRole::System);
assert_eq!(msg.text(), Some("You are a helpful assistant"));
let msg = InternalMessage::user("Hello");
assert_eq!(msg.role, MessageRole::User);
assert_eq!(msg.text(), Some("Hello"));
let msg = InternalMessage::assistant("Hi there!");
assert_eq!(msg.role, MessageRole::Assistant);
assert_eq!(msg.text(), Some("Hi there!"));
}
#[test]
fn test_content_blocks() {
let block = ContentBlock::text("Hello world");
assert_eq!(block.as_text(), Some("Hello world"));
let block = ContentBlock::tool_use(
"tool_123",
"get_weather",
serde_json::json!({"location": "SF"}),
);
let (id, name, input) = block.as_tool_use().unwrap();
assert_eq!(id, "tool_123");
assert_eq!(name, "get_weather");
assert_eq!(input["location"], "SF");
let block = ContentBlock::tool_result("tool_123", "72°F, sunny");
let (tool_use_id, content) = block.as_tool_result().unwrap();
assert_eq!(tool_use_id, "tool_123");
assert_eq!(content, "72°F, sunny");
}
#[test]
fn test_message_serialization() {
let msg = InternalMessage::user("Test message");
let json = serde_json::to_string(&msg).unwrap();
let deserialized: InternalMessage = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.role, MessageRole::User);
assert_eq!(deserialized.text(), Some("Test message"));
}
#[test]
fn test_role_string_conversion() {
assert_eq!(MessageRole::System.as_str(), "system");
assert_eq!(MessageRole::User.as_str(), "user");
assert_eq!(MessageRole::Assistant.as_str(), "assistant");
assert_eq!(MessageRole::Tool.as_str(), "tool");
}
#[test]
fn test_text_block_matches_spec() {
let block = ContentBlock::text("Hello world");
let json = serde_json::to_value(&block).unwrap();
assert_eq!(json["type"], "text");
assert_eq!(json["text"], "Hello world");
let obj = json.as_object().unwrap();
assert_eq!(obj.len(), 2);
}
#[test]
fn test_tool_use_block_matches_spec() {
let block = ContentBlock::tool_use(
"call_123",
"search",
serde_json::json!({"query": "weather"}),
);
let json = serde_json::to_value(&block).unwrap();
assert_eq!(json["type"], "tool_use");
assert_eq!(json["id"], "call_123");
assert_eq!(json["name"], "search");
assert_eq!(json["input"]["query"], "weather");
let obj = json.as_object().unwrap();
assert_eq!(obj.len(), 4);
}
#[test]
fn test_tool_result_block_matches_spec() {
let block = ContentBlock::tool_result("call_123", "Result text");
let json = serde_json::to_value(&block).unwrap();
assert_eq!(json["type"], "tool_result");
assert_eq!(json["tool_use_id"], "call_123");
assert_eq!(json["content"], "Result text");
let obj = json.as_object().unwrap();
assert_eq!(obj.len(), 3);
}
#[test]
fn test_message_with_tool_call_id() {
let msg = InternalMessage::tool_result("call_123", "search", "Weather is sunny");
let json = serde_json::to_value(&msg).unwrap();
assert_eq!(json["role"], "tool");
assert_eq!(json["tool_call_id"], "call_123");
assert_eq!(json["name"], "search");
assert_eq!(json["content"], "Weather is sunny");
}
#[test]
fn test_full_message_roundtrip() {
let blocks = vec![
ContentBlock::text("I'll search for you"),
ContentBlock::tool_use("call_123", "search", serde_json::json!({"q": "test"})),
];
let msg = InternalMessage {
role: MessageRole::Assistant,
content: MessageContent::Blocks(blocks),
metadata: std::collections::HashMap::new(),
tool_call_id: None,
name: None,
};
let json = serde_json::to_string(&msg).unwrap();
let deserialized: InternalMessage = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.role, MessageRole::Assistant);
if let MessageContent::Blocks(blocks) = deserialized.content {
assert_eq!(blocks.len(), 2);
assert!(matches!(blocks[0], ContentBlock::Text { .. }));
assert!(matches!(blocks[1], ContentBlock::ToolUse { .. }));
} else {
panic!("Expected blocks content");
}
}
#[test]
fn test_spec_compliance_full_example() {
let blocks = vec![
ContentBlock::text("I'll help you search"),
ContentBlock::tool_use(
"call_abc123",
"search",
serde_json::json!({"query": "weather"}),
),
];
let msg = InternalMessage {
role: MessageRole::Assistant,
content: MessageContent::Blocks(blocks),
metadata: std::collections::HashMap::new(),
tool_call_id: None,
name: None,
};
let json = serde_json::to_value(&msg).unwrap();
assert_eq!(json["role"], "assistant");
let content = json["content"].as_array().unwrap();
assert_eq!(content.len(), 2);
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "I'll help you search");
assert_eq!(content[1]["type"], "tool_use");
assert_eq!(content[1]["id"], "call_abc123");
assert_eq!(content[1]["name"], "search");
assert_eq!(content[1]["input"]["query"], "weather");
}
#[test]
fn test_wasm_provider_can_parse() {
let msg = InternalMessage::tool_result("call_123", "search", "Result");
let json_str = serde_json::to_string(&msg).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
assert_eq!(parsed["role"].as_str(), Some("tool"));
assert_eq!(parsed["tool_call_id"].as_str(), Some("call_123"));
assert_eq!(parsed["name"].as_str(), Some("search"));
assert_eq!(parsed["content"].as_str(), Some("Result"));
}
}