use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A single chat message: the sender's [`MessageRole`] plus its [`MessageContent`].
///
/// With the derived serde impls this is (de)serialized as an object with `role` and
/// `content` fields, written in that order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
/// Who sent the message.
pub role: MessageRole,
/// What the message carries: text, tool calls, a tool result, or multi-part content.
pub content: MessageContent,
}
impl Message {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: MessageContent::Text(content.into()),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: MessageContent::Text(content.into()),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: MessageContent::Text(content.into()),
}
}
#[must_use]
pub fn assistant_with_tools(tool_calls: Vec<ToolCall>) -> Self {
Self {
role: MessageRole::Assistant,
content: MessageContent::ToolCalls(tool_calls),
}
}
#[must_use]
pub fn tool_result(result: ToolCallResult) -> Self {
Self {
role: MessageRole::Tool,
content: MessageContent::ToolResult(result),
}
}
#[must_use]
pub fn text(&self) -> Option<&str> {
match &self.content {
MessageContent::Text(s) => Some(s),
_ => None,
}
}
#[must_use]
pub fn tool_calls(&self) -> Option<&[ToolCall]> {
match &self.content {
MessageContent::ToolCalls(calls) => Some(calls),
_ => None,
}
}
}
/// The author role of a [`Message`].
///
/// Serialized as a lowercase string (`"system"`, `"user"`, `"assistant"`, `"tool"`)
/// because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
/// Role given to messages built by `Message::tool_result`, which carry a [`ToolCallResult`].
Tool,
}
/// The payload of a [`Message`].
///
/// `#[serde(untagged)]`: no discriminator is written. When deserializing, serde tries
/// the variants in declaration order and keeps the first one that matches, so the
/// order of the variants below is part of the wire format.
///
/// NOTE(review): an empty JSON array `[]` matches `ToolCalls` first, so
/// `MultiPart(vec![])` does not round-trip; it comes back as `ToolCalls(vec![])`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain text (a JSON string on the wire).
Text(String),
/// Tool calls, as built by `Message::assistant_with_tools`.
ToolCalls(Vec<ToolCall>),
/// A single tool result, as built by `Message::tool_result`.
ToolResult(ToolCallResult),
/// A sequence of text and/or image parts.
MultiPart(Vec<ContentPart>),
}
/// One element of [`MessageContent::MultiPart`].
///
/// Internally tagged: serialized as an object whose `type` field is the snake_case
/// variant name (`"text"` or `"image"`) alongside the variant's own fields, e.g.
/// `{"type": "text", "text": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// A text part.
Text {
text: String,
},
/// An image part.
Image {
// NOTE(review): presumably base64-encoded image bytes; confirm with producers.
data: String,
// NOTE(review): presumably a MIME type such as `image/png`; confirm.
media_type: String,
},
}
/// A tool invocation carried by [`MessageContent::ToolCalls`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Identifier of this call; presumably echoed back as [`ToolCallResult::call_id`] (confirm).
pub id: String,
/// Tool name; may be namespaced as `prefix:tool` (see `ToolCall::parse_name`).
pub name: String,
/// Call arguments as arbitrary JSON (`ToolCall::new` initializes this to `{}`).
pub arguments: Value,
}
impl ToolCall {
    /// Creates a tool call with the given `id` and `name` and an empty JSON object
    /// (`{}`) as arguments. Use [`ToolCall::with_arguments`] to replace them.
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            arguments: Value::Object(serde_json::Map::new()),
        }
    }

    /// Replaces the call's arguments with `args` (builder style).
    #[must_use]
    pub fn with_arguments(mut self, args: Value) -> Self {
        self.arguments = args;
        self
    }

    /// Splits the tool name at the first `':'` into `(prefix, rest)`.
    ///
    /// For example, `"filesystem:read_file"` yields `Some(("filesystem", "read_file"))`.
    /// Only the first colon is significant (`"a:b:c"` yields `("a", "b:c")`), either
    /// half may be empty, and a name without any `':'` yields `None`.
    #[must_use]
    pub fn parse_name(&self) -> Option<(&str, &str)> {
        self.name.split_once(':')
    }
}
/// Result reported for a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResult {
/// Id of the call this result answers; presumably matches [`ToolCall::id`] (confirm).
pub call_id: String,
/// Output text on success, or the error message when `is_error` is `true`.
pub content: String,
/// Whether this result reports a failure. Defaults to `false` when the field is
/// absent from serialized input.
#[serde(default)]
pub is_error: bool,
}
impl ToolCallResult {
    /// Creates a successful result (`is_error == false`) for the call `call_id`,
    /// with `content` as its output.
    #[must_use]
    pub fn success(call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            call_id: call_id.into(),
            content: content.into(),
            is_error: false,
        }
    }

    /// Creates a failed result (`is_error == true`) for the call `call_id`.
    ///
    /// The error message is stored in the `content` field.
    #[must_use]
    pub fn error(call_id: impl Into<String>, error: impl Into<String>) -> Self {
        Self {
            call_id: call_id.into(),
            content: error.into(),
            is_error: true,
        }
    }
}
/// A tool definition (name, optional description, input schema); presumably sent to
/// the LLM so it can request calls to this tool (confirm against callers).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LlmToolDefinition {
/// Tool name.
pub name: String,
/// Optional human-readable description; `None` unless set.
pub description: Option<String>,
/// JSON describing the tool's input. `LlmToolDefinition::new` defaults it to
/// `{"type": "object"}`, which suggests a JSON Schema (confirm).
pub input_schema: Value,
}
impl LlmToolDefinition {
    /// Creates a definition named `name` with no description and the input schema
    /// `{"type": "object"}`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            input_schema: serde_json::json!({"type": "object"}),
        }
    }

    /// Sets the description (builder style).
    #[must_use]
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Replaces the input schema with `schema` (builder style).
    #[must_use]
    pub fn with_schema(mut self, schema: Value) -> Self {
        self.input_schema = schema;
        self
    }
}
/// An event in a streamed response.
///
/// There is no `#[serde(...)]` attribute, so serde uses its default externally tagged
/// form, e.g. `{"TextDelta": "..."}`, `{"ToolCallEnd": {"id": "..."}}`, and the bare
/// string `"Done"` for the unit variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum StreamEvent {
/// A chunk of response text.
TextDelta(String),
/// A tool call with the given `id` and `name` has started.
ToolCallStart {
id: String,
name: String,
},
/// A fragment of the arguments for the tool call identified by `id`.
// NOTE(review): presumably fragments are concatenated in order to form the full
// JSON arguments; confirm with the producer.
ToolCallDelta {
id: String,
args_delta: String,
},
/// The tool call identified by `id` has ended.
ToolCallEnd {
id: String,
},
/// A chunk of reasoning text (presumably model "thinking" output; confirm).
ReasoningDelta(String),
/// Token counts reported by the stream.
Usage {
input_tokens: usize,
output_tokens: usize,
},
/// Presumably marks the end of the stream (confirm with the producer).
Done,
/// An error, carried as a message string.
Error(String),
}
/// A model response together with its stop reason and token usage.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LlmResponse {
/// The returned message.
pub message: Message,
/// Whether the response includes tool calls.
// NOTE(review): stored independently of `message.content`; nothing in this type keeps
// the two consistent.
pub has_tool_calls: bool,
/// Why generation stopped.
pub stop_reason: StopReason,
/// Token counts for this response.
pub usage: Usage,
}
/// Why the model stopped generating.
///
/// Unlike [`MessageRole`], there is no `rename_all`, so serde writes the variant names
/// verbatim (`"EndTurn"`, `"MaxTokens"`, `"ToolUse"`, `"StopSequence"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
EndTurn,
MaxTokens,
ToolUse,
StopSequence,
}
/// Token counts; `Usage::default()` has both counts at zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct Usage {
/// Number of input tokens.
pub input_tokens: usize,
/// Number of output tokens.
pub output_tokens: usize,
}
impl Usage {
    /// Combined token count, `input_tokens + output_tokens`.
    ///
    /// Never overflows: if the true sum exceeds `usize::MAX`, `usize::MAX` is returned.
    #[must_use]
    pub fn total(&self) -> usize {
        let Self {
            input_tokens,
            output_tokens,
        } = *self;
        input_tokens
            .checked_add(output_tokens)
            .unwrap_or(usize::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let user = Message::user("Hello");
        assert_eq!(user.role, MessageRole::User);
        assert_eq!(user.text(), Some("Hello"));
        assert_eq!(user.tool_calls(), None);

        let assistant = Message::assistant("Hi there!");
        assert_eq!(assistant.role, MessageRole::Assistant);
        assert_eq!(assistant.text(), Some("Hi there!"));

        let system = Message::system("Be brief.");
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(system.text(), Some("Be brief."));
    }

    #[test]
    fn test_tool_messages() {
        let call = ToolCall::new("1", "fs:read");
        let msg = Message::assistant_with_tools(vec![call.clone()]);
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.tool_calls(), Some(&[call][..]));
        assert_eq!(msg.text(), None);

        let result = Message::tool_result(ToolCallResult::success("1", "ok"));
        assert_eq!(result.role, MessageRole::Tool);
        assert_eq!(result.text(), None);
        assert_eq!(result.tool_calls(), None);
    }

    #[test]
    fn test_tool_call() {
        let call = ToolCall::new("123", "filesystem:read_file")
            .with_arguments(serde_json::json!({"path": "/tmp/test.txt"}));
        assert_eq!(call.parse_name(), Some(("filesystem", "read_file")));
        assert_eq!(call.arguments, serde_json::json!({"path": "/tmp/test.txt"}));

        // Default arguments are an empty JSON object.
        assert_eq!(ToolCall::new("1", "x").arguments, serde_json::json!({}));
        // No separator yields None; only the first ':' splits.
        assert_eq!(ToolCall::new("1", "read_file").parse_name(), None);
        assert_eq!(ToolCall::new("1", "a:b:c").parse_name(), Some(("a", "b:c")));
    }

    #[test]
    fn test_tool_result() {
        let success = ToolCallResult::success("123", "file contents");
        assert!(!success.is_error);
        assert_eq!(success.call_id, "123");
        assert_eq!(success.content, "file contents");

        let error = ToolCallResult::error("123", "file not found");
        assert!(error.is_error);
        assert_eq!(error.content, "file not found");
    }

    #[test]
    fn test_usage_total_saturates() {
        let usage = Usage {
            input_tokens: 3,
            output_tokens: 4,
        };
        assert_eq!(usage.total(), 7);

        let huge = Usage {
            input_tokens: usize::MAX,
            output_tokens: 1,
        };
        assert_eq!(huge.total(), usize::MAX);
    }

    #[test]
    fn test_serde_round_trip() {
        let user = Message::user("Hi");
        let json = serde_json::to_value(&user).unwrap();
        assert_eq!(json, serde_json::json!({"role": "user", "content": "Hi"}));
        assert_eq!(serde_json::from_value::<Message>(json).unwrap(), user);

        let tool = Message::tool_result(ToolCallResult::error("7", "boom"));
        let json = serde_json::to_value(&tool).unwrap();
        assert_eq!(serde_json::from_value::<Message>(json).unwrap(), tool);

        // `is_error` defaults to false when absent.
        let parsed: ToolCallResult =
            serde_json::from_value(serde_json::json!({"call_id": "1", "content": "x"})).unwrap();
        assert!(!parsed.is_error);
    }
}