use std::pin::Pin;
use async_trait::async_trait;
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::Result;
use crate::tool::Tool;
/// Settings describing which provider/model to use, plus two numeric request parameters.
///
/// Derives `Serialize`/`Deserialize`, so it round-trips through serde formats such as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
/// Provider identifier; the tests in this file use `"openai"` and `"anthropic"`.
pub provider: String,
/// Model identifier, e.g. `"gpt-4"`.
pub model: String,
/// Temperature value; [`LlmConfig::new`] initialises it to `0.7`.
// NOTE(review): presumably the sampling temperature forwarded to the provider — confirm.
pub temperature: f32,
/// Token limit; [`LlmConfig::new`] initialises it to `4096`.
// NOTE(review): presumably caps generated (output) tokens — confirm against provider impls.
pub max_tokens: u32,
}
impl LlmConfig {
pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
Self {
provider: provider.into(),
model: model.into(),
temperature: 0.7,
max_tokens: 4096,
}
}
}
/// One entry of a chat conversation.
///
/// Serialized as an internally tagged object: the variant is written to a `"role"`
/// field using the snake_case variant name (`"system"`, `"user"`, `"assistant"`),
/// except [`Message::ToolResult`], which is explicitly renamed to `"tool"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "snake_case")]
pub enum Message {
/// Message with role `"system"`.
System { content: String },
/// Message with role `"user"`.
User { content: String },
/// Message with role `"assistant"`: optional text and/or a list of tool calls.
Assistant {
/// Text of the message; `None` when built via [`Message::assistant_with_tool_calls`].
content: Option<String>,
/// Tool calls attached to this message. A missing field deserializes to an empty
/// vector (`default`), and an empty vector is omitted when serializing.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
/// Result of a tool invocation; serialized with role `"tool"`.
#[serde(rename = "tool")]
ToolResult {
/// Identifier of the tool call this result belongs to.
// NOTE(review): presumably matches `ToolCall::id` of an earlier assistant message — confirm.
tool_call_id: String,
/// Textual output of the tool.
content: String,
},
}
impl Message {
pub fn system(content: impl Into<String>) -> Self {
Self::System {
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self::User {
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::Assistant {
content: Some(content.into()),
tool_calls: vec![],
}
}
pub fn assistant_with_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
Self::Assistant {
content: None,
tool_calls,
}
}
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self::ToolResult {
tool_call_id: tool_call_id.into(),
content: content.into(),
}
}
}
/// A single tool invocation: an identifier, the tool's name, and its arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call (the tests use values like `"call_abc"`).
pub id: String,
/// Name of the tool to invoke (e.g. `"search"`).
pub name: String,
/// Arguments as an arbitrary JSON value; the tests pass JSON objects such as
/// `{"query": "rust"}`.
pub arguments: Value,
}
/// Description of a tool: its name, a human-readable description, and its parameters.
///
/// Can be built from any [`Tool`] via [`ToolDefinition::from_tool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Human-readable description of what the tool does.
pub description: String,
/// Parameter description as a JSON value.
// NOTE(review): the tests use JSON-Schema-shaped objects (`"type": "object"`,
// `"properties"`, `"required"`); presumably a JSON Schema is expected — confirm.
pub parameters: Value,
}
impl ToolDefinition {
pub fn from_tool(tool: &dyn Tool) -> Self {
Self {
name: tool.name().to_string(),
description: tool.description().to_string(),
parameters: tool.parameters(),
}
}
}
/// Pair of token counters; the derived `Default` yields zero for both.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Input token count.
// NOTE(review): presumably the prompt tokens billed for the request — confirm.
pub input_tokens: u32,
/// Output token count.
// NOTE(review): presumably the tokens generated by the model — confirm.
pub output_tokens: u32,
}
/// Value returned by [`LlmProvider::chat`]: optional text, tool calls, and token usage.
#[derive(Debug, Clone)]
pub struct LlmResponse {
/// Text of the response, if any.
pub content: Option<String>,
/// Tool calls contained in the response; may be empty.
pub tool_calls: Vec<ToolCall>,
/// Token counters associated with this response.
pub usage: TokenUsage,
}
/// One item yielded by the stream returned from `LlmProvider::chat_stream`.
///
/// Implements `PartialEq`/`Eq` so chunks can be compared directly (for example with
/// `assert_eq!` in tests) instead of only through pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmChunk {
    /// A piece of text.
    Text(String),
    /// Announces a tool call with the given `id` and tool `name`.
    ToolCallStart { id: String, name: String },
    /// A fragment of the arguments for the tool call identified by `id`.
    // NOTE(review): presumably fragments are partial JSON text to be concatenated in
    // arrival order (the tests use `{"q` as a delta) — confirm with provider impls.
    ToolCallDelta { id: String, arguments_delta: String },
    /// Terminal marker.
    // NOTE(review): presumably the last item of a stream — confirm.
    Done,
}
/// Abstraction over a chat-capable LLM backend.
///
/// Implementors must be `Send + Sync`. The trait goes through `#[async_trait]`, and the
/// test module checks that it can be used as `dyn LlmProvider`.
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Sends `messages` together with the available `tools` using the settings in `config`
/// and resolves to a complete [`LlmResponse`], or to the crate's error type on failure.
async fn chat(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
config: &LlmConfig,
) -> Result<LlmResponse>;
/// Streaming variant of [`LlmProvider::chat`] with the same inputs.
///
/// On success it resolves to a pinned, boxed, `Send` stream whose items are
/// `Result<LlmChunk>`, so an error can surface either up front or as a stream item.
async fn chat_stream(
&self,
messages: Vec<Message>,
tools: Vec<ToolDefinition>,
config: &LlmConfig,
) -> Result<Pin<Box<dyn Stream<Item = Result<LlmChunk>> + Send>>>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a message to a `serde_json::Value`, panicking on failure.
    fn as_json(msg: &Message) -> Value {
        serde_json::to_value(msg).unwrap()
    }

    /// Compile-time check that `LlmProvider` can be used as a trait object.
    #[test]
    fn test_llm_provider_is_object_safe() {
        fn _assert_object_safe(_: &dyn LlmProvider) {}
    }

    /// `LlmConfig::new` stores its arguments and applies the fixed defaults.
    #[test]
    fn test_llm_config_new_defaults() {
        let cfg = LlmConfig::new("openai", "gpt-4");

        assert_eq!(cfg.provider, "openai");
        assert_eq!(cfg.model, "gpt-4");
        let drift = (cfg.temperature - 0.7).abs();
        assert!(drift < f32::EPSILON);
        assert_eq!(cfg.max_tokens, 4096);
    }

    /// `LlmConfig` survives a JSON round trip.
    #[test]
    fn test_llm_config_serialization() {
        let original = LlmConfig::new("anthropic", "claude-sonnet-4-6");
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: LlmConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.provider, "anthropic");
        assert_eq!(decoded.model, "claude-sonnet-4-6");
    }

    /// A system message serializes to exactly `role` + `content`.
    #[test]
    fn test_message_system_openai_format() {
        let json = as_json(&Message::system("You are helpful"));

        assert_eq!(json["role"], "system");
        assert_eq!(json["content"], "You are helpful");
        let field_count = json.as_object().unwrap().len();
        assert_eq!(field_count, 2);
    }

    /// A user message carries the `user` role and its content.
    #[test]
    fn test_message_user_openai_format() {
        let json = as_json(&Message::user("Hello"));

        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "Hello");
    }

    /// A text-only assistant message omits the `tool_calls` key entirely.
    #[test]
    fn test_message_assistant_text_only_format() {
        let json = as_json(&Message::assistant("The answer is 42"));

        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"], "The answer is 42");
        assert!(json.get("tool_calls").is_none());
    }

    /// An assistant message with tool calls has null content and lists the calls.
    #[test]
    fn test_message_assistant_with_tool_calls_format() {
        let call = ToolCall {
            id: "call_abc".into(),
            name: "search".into(),
            arguments: serde_json::json!({"query": "rust"}),
        };
        let json = as_json(&Message::assistant_with_tool_calls(vec![call]));

        assert_eq!(json["role"], "assistant");
        assert!(json["content"].is_null());
        let first_call = &json["tool_calls"][0];
        assert_eq!(first_call["id"], "call_abc");
        assert_eq!(first_call["name"], "search");
    }

    /// A tool result serializes with role `tool` and its call id.
    #[test]
    fn test_message_tool_result_openai_format() {
        let json = as_json(&Message::tool_result("call_abc", "Search results here"));

        assert_eq!(json["role"], "tool");
        assert_eq!(json["tool_call_id"], "call_abc");
        assert_eq!(json["content"], "Search results here");
    }

    /// Every variant re-serializes to the same JSON after a round trip.
    #[test]
    fn test_message_serde_roundtrip_all_variants() {
        let read_call = ToolCall {
            id: "c1".into(),
            name: "read".into(),
            arguments: serde_json::json!({}),
        };
        let samples = [
            Message::system("Be helpful"),
            Message::user("Hi"),
            Message::assistant("Hello!"),
            Message::assistant_with_tool_calls(vec![read_call]),
            Message::tool_result("c1", "file contents"),
        ];

        for original in samples.iter() {
            let first_pass = serde_json::to_string(original).unwrap();
            let decoded: Message = serde_json::from_str(&first_pass).unwrap();
            let second_pass = serde_json::to_string(&decoded).unwrap();
            assert_eq!(first_pass, second_pass);
        }
    }

    /// An assistant payload with an explicit empty `tool_calls` array parses.
    #[test]
    fn test_message_deserialize_from_openai_response() {
        let raw = r#"{"role": "assistant", "content": "Hello!", "tool_calls": []}"#;
        let parsed: Message = serde_json::from_str(raw).unwrap();

        let is_expected =
            matches!(parsed, Message::Assistant { content: Some(text), .. } if text == "Hello!");
        assert!(is_expected);
    }

    /// A missing `tool_calls` field deserializes to an empty list.
    #[test]
    fn test_message_deserialize_assistant_without_tool_calls() {
        let raw = r#"{"role": "assistant", "content": "Hello!"}"#;
        let parsed: Message = serde_json::from_str(raw).unwrap();

        if let Message::Assistant {
            content,
            tool_calls,
        } = parsed
        {
            assert_eq!(content.as_deref(), Some("Hello!"));
            assert!(tool_calls.is_empty());
        } else {
            panic!("Expected Assistant");
        }
    }

    /// Each convenience constructor yields the matching variant and payload.
    #[test]
    fn test_message_convenience_constructors() {
        let system_ok = matches!(
            Message::system("x"),
            Message::System { content } if content == "x"
        );
        assert!(system_ok);

        let user_ok = matches!(
            Message::user("y"),
            Message::User { content } if content == "y"
        );
        assert!(user_ok);

        let assistant_ok = matches!(
            Message::assistant("z"),
            Message::Assistant { content: Some(c), tool_calls }
                if c == "z" && tool_calls.is_empty()
        );
        assert!(assistant_ok);

        let tool_calls_ok = matches!(
            Message::assistant_with_tool_calls(vec![]),
            Message::Assistant { content: None, tool_calls } if tool_calls.is_empty()
        );
        assert!(tool_calls_ok);

        let tool_result_ok = matches!(
            Message::tool_result("id", "res"),
            Message::ToolResult { tool_call_id, content }
                if tool_call_id == "id" && content == "res"
        );
        assert!(tool_result_ok);
    }

    /// `ToolDefinition::from_tool` copies name, description and parameters.
    #[test]
    fn test_tool_definition_from_tool() {
        use crate::error::PulseHiveError;
        use crate::tool::{ToolContext, ToolResult};

        // Minimal `Tool` implementation with fixed metadata.
        struct MockTool;

        #[async_trait]
        impl Tool for MockTool {
            fn name(&self) -> &str {
                "mock_tool"
            }

            fn description(&self) -> &str {
                "A mock tool for testing"
            }

            fn parameters(&self) -> Value {
                serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}})
            }

            async fn execute(
                &self,
                _params: Value,
                _ctx: &ToolContext,
            ) -> std::result::Result<ToolResult, PulseHiveError> {
                Ok(ToolResult::text("ok"))
            }
        }

        let definition = ToolDefinition::from_tool(&MockTool);

        assert_eq!(definition.name, "mock_tool");
        assert_eq!(definition.description, "A mock tool for testing");
        assert_eq!(definition.parameters["type"], "object");
    }

    /// Every message of a realistic multi-turn exchange serializes with a role.
    #[test]
    fn test_multi_turn_conversation_serialization() {
        let read_config = ToolCall {
            id: "call_1".into(),
            name: "read_file".into(),
            arguments: serde_json::json!({"path": "config.toml"}),
        };
        let conversation = [
            Message::system("You are a code assistant."),
            Message::user("Read the config file."),
            Message::assistant_with_tool_calls(vec![read_config]),
            Message::tool_result("call_1", "[package]\nname = \"test\""),
            Message::assistant("The config file defines a package named 'test'."),
        ];

        for json in conversation.iter().map(as_json) {
            assert!(json.get("role").is_some(), "Missing role field");
        }
        assert_eq!(conversation.len(), 5);
    }

    /// `ToolDefinition` can be built directly from a struct literal.
    #[test]
    fn test_tool_definition_construction() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string"}
            },
            "required": ["query"]
        });
        let definition = ToolDefinition {
            name: "search".into(),
            description: "Search the codebase".into(),
            parameters: schema,
        };

        assert_eq!(definition.name, "search");
    }

    /// The default `TokenUsage` has both counters at zero.
    #[test]
    fn test_token_usage_default() {
        let TokenUsage {
            input_tokens,
            output_tokens,
        } = TokenUsage::default();

        assert_eq!(input_tokens, 0);
        assert_eq!(output_tokens, 0);
    }

    /// Every `LlmChunk` variant can be constructed and pattern-matched.
    #[test]
    fn test_llm_chunk_variants() {
        let text_chunk = LlmChunk::Text("hello".into());
        assert!(matches!(text_chunk, LlmChunk::Text(s) if s == "hello"));

        let start_chunk = LlmChunk::ToolCallStart {
            id: "1".into(),
            name: "search".into(),
        };
        assert!(matches!(start_chunk, LlmChunk::ToolCallStart { .. }));

        let delta_chunk = LlmChunk::ToolCallDelta {
            id: "1".into(),
            arguments_delta: "{\"q".into(),
        };
        assert!(matches!(delta_chunk, LlmChunk::ToolCallDelta { .. }));

        let done_chunk = LlmChunk::Done;
        assert!(matches!(done_chunk, LlmChunk::Done));
    }
}