#![deny(missing_docs)]
use std::fmt;
use std::pin::Pin;
use serde::{Deserialize, Serialize};
/// Role attached to a [`ChatMessage`].
///
/// Because of `rename_all = "lowercase"`, each variant is serialized and
/// deserialized as its lowercase name: `"system"`, `"user"`, `"assistant"`,
/// `"tool"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
/// The system role (wire name `"system"`).
System,
/// The user role (wire name `"user"`).
User,
/// The assistant role (wire name `"assistant"`).
Assistant,
/// The tool role (wire name `"tool"`).
Tool,
}
/// Formats the role as its lowercase name: `system`, `user`, `assistant` or `tool`.
///
/// The name is written with [`fmt::Formatter::pad`], so formatter flags are
/// honoured the same way they are for a plain `&str`: `{:>10}` right-aligns,
/// `{:<10}` left-aligns, and a precision such as `{:.3}` truncates. Without
/// flags (`{}` / `to_string()`) the output is exactly the bare name.
impl fmt::Display for MessageRole {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool => "tool",
        };
        // `pad` rather than `write_str`/`write!`: the latter silently drop
        // width, fill and alignment requested by the caller.
        f.pad(name)
    }
}
/// One piece of a message's content.
///
/// Serialized as an internally tagged object: the `type` field carries the
/// snake_case variant name (`text`, `image_url`, `image_base64`) next to the
/// variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// A piece of text (`"type": "text"`).
Text {
/// The text itself.
text: String,
},
/// An image referenced by URL (`"type": "image_url"`).
ImageUrl {
/// Location of the image; stored as given, not parsed or validated here.
url: String,
},
/// An image carried inline (`"type": "image_base64"`).
ImageBase64 {
/// Media type of the image.
/// NOTE(review): presumably a MIME type such as `image/png` — confirm against callers.
media_type: String,
/// The image payload.
/// NOTE(review): presumably base64 text, going by the variant name; nothing
/// in this file encodes, decodes or validates it.
data: String,
},
}
impl ContentPart {
    /// Builds a [`ContentPart::Text`] from anything convertible into a `String`.
    pub fn text(s: impl Into<String>) -> Self {
        Self::Text { text: s.into() }
    }

    /// Builds a [`ContentPart::ImageUrl`] pointing at `url`.
    ///
    /// The URL is stored as given; it is not parsed or validated.
    pub fn image_url(url: impl Into<String>) -> Self {
        Self::ImageUrl { url: url.into() }
    }

    /// Builds a [`ContentPart::ImageBase64`] from a media type and a payload.
    ///
    /// Both values are stored as given; no encoding or validation is performed.
    /// This mirrors [`ContentPart::text`] and [`ContentPart::image_url`] so
    /// that every variant has a constructor.
    pub fn image_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
        Self::ImageBase64 {
            media_type: media_type.into(),
            data: data.into(),
        }
    }

    /// Returns the text of a [`ContentPart::Text`], or `None` for image parts.
    pub fn as_text(&self) -> Option<&str> {
        // Exhaustive on purpose: adding a variant must force a decision here
        // instead of silently falling into a catch-all.
        match self {
            Self::Text { text } => Some(text.as_str()),
            Self::ImageUrl { .. } | Self::ImageBase64 { .. } => None,
        }
    }
}
mod content_vec_serde {
use super::ContentPart;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer>(parts: &[ContentPart], s: S) -> Result<S::Ok, S::Error> {
if parts.len() == 1
&& let ContentPart::Text { text } = &parts[0]
{
return s.serialize_str(text);
}
parts.serialize(s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<ContentPart>, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum StringOrParts {
S(String),
P(Vec<ContentPart>),
}
match StringOrParts::deserialize(d)? {
StringOrParts::S(s) => Ok(vec![ContentPart::text(s)]),
StringOrParts::P(v) => Ok(v),
}
}
}
/// A chat message: a [`MessageRole`] together with a list of [`ContentPart`]s.
///
/// `content` is (de)serialized through `content_vec_serde`: when it holds
/// exactly one text part it is written as a bare string, otherwise as an
/// array of parts; both shapes are accepted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// Role this message is attributed to.
pub role: MessageRole,
/// Content parts, in order.
#[serde(with = "content_vec_serde")]
pub content: Vec<ContentPart>,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: MessageRole::System,
content: vec![ContentPart::text(content)],
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: MessageRole::User,
content: vec![ContentPart::text(content)],
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: MessageRole::Assistant,
content: vec![ContentPart::text(content)],
}
}
pub fn tool(content: impl Into<String>) -> Self {
Self {
role: MessageRole::Tool,
content: vec![ContentPart::text(content)],
}
}
pub fn user_multimodal(parts: Vec<ContentPart>) -> Self {
Self {
role: MessageRole::User,
content: parts,
}
}
pub fn text_content(&self) -> String {
self.content
.iter()
.filter_map(|p| p.as_text())
.collect::<Vec<_>>()
.join("")
}
}
/// Settings naming a provider and model plus basic generation parameters.
///
/// The [`Default`] implementation in this file yields provider `openai`,
/// model `gpt-4o`, `api_key_env` `OPENAI_API_KEY`, no base URL, temperature
/// `0.7` and `max_tokens` of `4096`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Provider identifier (default `"openai"`).
pub provider: String,
/// Model name (default `"gpt-4o"`).
pub model: String,
/// Default `"OPENAI_API_KEY"`.
/// NOTE(review): presumably the name of the environment variable that holds
/// the API key — inferred from the field name; this file never reads it.
pub api_key_env: String,
/// Optional base URL; `None` by default.
/// NOTE(review): how `None` is resolved is decided outside this file.
pub base_url: Option<String>,
/// Sampling temperature (default `0.7`).
pub temperature: f32,
/// Optional token limit (default `Some(4096)`).
pub max_tokens: Option<u32>,
}
/// Default configuration: provider `openai`, model `gpt-4o`, key variable
/// `OPENAI_API_KEY`, no base URL, temperature `0.7`, `max_tokens` `4096`.
impl Default for ModelConfig {
    fn default() -> Self {
        let provider = String::from("openai");
        let model = String::from("gpt-4o");
        let api_key_env = String::from("OPENAI_API_KEY");

        Self {
            provider,
            model,
            api_key_env,
            base_url: None,
            temperature: 0.7,
            max_tokens: Some(4096),
        }
    }
}
/// Output format requested for a response.
///
/// Serialized as an internally tagged object whose `type` field is the
/// snake_case variant name: `text`, `json` or `json_schema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
/// Text output (`{"type":"text"}`).
Text,
/// JSON output (`{"type":"json"}`).
Json,
/// JSON output accompanied by a schema (`"type": "json_schema"`).
JsonSchema {
/// The schema, held as an arbitrary JSON value; not validated in this file.
schema: serde_json::Value,
},
}
/// A request to a language model, usually assembled with [`LLMRequest::builder`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
/// Optional system prompt, kept separate from `messages`.
/// `into_openai_messages` emits it as a leading `("system", …)` pair;
/// `into_anthropic_messages` leaves it out.
pub system: Option<String>,
/// Conversation messages, in order.
pub messages: Vec<ChatMessage>,
/// Sampling temperature; `LLMRequestBuilder::build` uses `0.7` when none was set.
pub temperature: f32,
/// Optional token limit.
pub max_tokens: Option<u32>,
/// Optional model name.
/// NOTE(review): presumably overrides a configured default model — confirm with the client code.
pub model: Option<String>,
/// Optional response format; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// Optional tool definitions; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<crate::llm::ToolDefinition>>,
}
impl LLMRequest {
pub fn builder() -> LLMRequestBuilder {
LLMRequestBuilder::default()
}
pub fn into_openai_messages(self) -> Vec<(String, String)> {
let mut out = Vec::with_capacity(self.messages.len() + 1);
if let Some(system) = self.system {
out.push(("system".into(), system));
}
for msg in self.messages {
out.push((msg.role.to_string(), msg.text_content()));
}
out
}
pub fn into_anthropic_messages(self) -> Vec<(String, String)> {
self.messages
.into_iter()
.map(|m| (m.role.to_string(), m.text_content()))
.collect()
}
}
/// Builder for [`LLMRequest`].
///
/// Starts out empty (via [`Default`]); each setter consumes and returns the
/// builder, and `build` produces the request.
#[derive(Debug, Clone, Default)]
pub struct LLMRequestBuilder {
// Optional system prompt; last call to `system` wins.
system: Option<String>,
// Messages in the order they were added.
messages: Vec<ChatMessage>,
// `None` until `temperature` is called; `build` then falls back to 0.7.
temperature: Option<f32>,
// Optional token limit.
max_tokens: Option<u32>,
// Optional model name.
model: Option<String>,
// Optional response format.
response_format: Option<ResponseFormat>,
// Optional tool definitions.
tools: Option<Vec<crate::llm::ToolDefinition>>,
}
impl LLMRequestBuilder {
    /// Sets the system prompt, replacing any previously set one.
    pub fn system(self, prompt: impl Into<String>) -> Self {
        Self {
            system: Some(prompt.into()),
            ..self
        }
    }

    /// Appends a user message with a single text part.
    pub fn user_message(self, content: impl Into<String>) -> Self {
        self.message(ChatMessage::user(content))
    }

    /// Appends an assistant message with a single text part.
    pub fn assistant_message(self, content: impl Into<String>) -> Self {
        self.message(ChatMessage::assistant(content))
    }

    /// Appends an already constructed message.
    pub fn message(mut self, msg: ChatMessage) -> Self {
        self.messages.push(msg);
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(self, temp: f32) -> Self {
        Self {
            temperature: Some(temp),
            ..self
        }
    }

    /// Sets the token limit.
    pub fn max_tokens(self, tokens: u32) -> Self {
        Self {
            max_tokens: Some(tokens),
            ..self
        }
    }

    /// Sets the model name.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the response format.
    pub fn response_format(self, format: ResponseFormat) -> Self {
        Self {
            response_format: Some(format),
            ..self
        }
    }

    /// Sets the tool definitions, replacing any previously set list.
    pub fn tools(self, tools: Vec<crate::llm::ToolDefinition>) -> Self {
        Self {
            tools: Some(tools),
            ..self
        }
    }

    /// Consumes the builder and produces the [`LLMRequest`].
    ///
    /// The temperature is `0.7` when none was set; every other field is
    /// carried over unchanged.
    pub fn build(self) -> LLMRequest {
        let Self {
            system,
            messages,
            temperature,
            max_tokens,
            model,
            response_format,
            tools,
        } = self;

        LLMRequest {
            system,
            messages,
            temperature: temperature.unwrap_or(0.7),
            max_tokens,
            model,
            response_format,
            tools,
        }
    }
}
/// A completed response from a language model.
///
/// The three optional fields default to `None` when absent from the input and
/// are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
/// Response text.
pub content: String,
/// Model name reported for this response.
pub model: String,
/// Token counts for this response.
pub usage: TokenUsage,
/// Optional finish reason, kept as a free-form string.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
/// Optional response identifier.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Optional creation value.
/// NOTE(review): presumably a Unix timestamp in seconds — confirm against the producing client.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub created: Option<u64>,
}
/// Token counts for a request/response pair; all counts default to zero.
///
/// The three values are stored independently: nothing in this file derives
/// `total_tokens` from the other two.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
/// Tokens counted for the prompt.
pub prompt_tokens: u32,
/// Tokens counted for the completion.
pub completion_tokens: u32,
/// Total token count.
pub total_tokens: u32,
}
/// An event produced while streaming a response; the item type of `LLMStream`.
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// A chunk of response text.
Delta {
/// The text carried by this chunk.
content: String,
},
/// Token usage information.
Usage(TokenUsage),
/// Marks the end of the stream.
/// NOTE(review): whether producers always emit this last is not visible in this file.
Done,
}
/// Boxed, pinned, `Send` stream whose items are `crate::error::Result<StreamEvent>`.
///
/// Only available when the `client-async` feature is enabled.
#[cfg(feature = "client-async")]
pub type LLMStream =
Pin<Box<dyn futures_core::Stream<Item = crate::error::Result<StreamEvent>> + Send>>;
// Unit tests for the message, request and builder types in this module.
#[cfg(test)]
mod tests {
use super::*;
// `Display` prints the lowercase role name.
#[test]
fn message_role_display() {
assert_eq!(MessageRole::System.to_string(), "system");
assert_eq!(MessageRole::User.to_string(), "user");
assert_eq!(MessageRole::Assistant.to_string(), "assistant");
assert_eq!(MessageRole::Tool.to_string(), "tool");
}
// A role serializes to a lowercase JSON string and parses back to the same variant.
#[test]
fn message_role_serde_roundtrip() {
let json = serde_json::to_string(&MessageRole::User).unwrap();
assert_eq!(json, "\"user\"");
let back: MessageRole = serde_json::from_str(&json).unwrap();
assert_eq!(back, MessageRole::User);
}
// Each text constructor assigns the matching role.
#[test]
fn chat_message_constructors() {
let sys = ChatMessage::system("instructions");
assert_eq!(sys.role, MessageRole::System);
let user = ChatMessage::user("hello");
assert_eq!(user.role, MessageRole::User);
let asst = ChatMessage::assistant("hi there");
assert_eq!(asst.role, MessageRole::Assistant);
let tool = ChatMessage::tool("result");
assert_eq!(tool.role, MessageRole::Tool);
}
// Content made of exactly one text part is written as a bare string.
#[test]
fn single_text_serializes_as_string() {
let msg = ChatMessage::user("hello");
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"content\":\"hello\""), "got: {json}");
}
// Content with more than one part is written as an array.
#[test]
fn multipart_serializes_as_array() {
let msg = ChatMessage::user_multimodal(vec![
ContentPart::text("describe this"),
ContentPart::image_url("https://example.com/img.png"),
]);
let json = serde_json::to_string(&msg).unwrap();
assert!(
json.contains("\"content\":["),
"expected array serialization, got: {json}"
);
}
// A bare string is read back as a single text part.
#[test]
fn single_text_deserialize_from_string() {
let json = r#"{"role":"user","content":"hello"}"#;
let msg: ChatMessage = serde_json::from_str(json).unwrap();
assert_eq!(msg.role, MessageRole::User);
assert_eq!(msg.content.len(), 1);
assert_eq!(msg.text_content(), "hello");
}
// An array of tagged parts is read back part by part.
#[test]
fn multipart_deserialize_from_array() {
let json = r#"{"role":"user","content":[{"type":"text","text":"hi"},{"type":"image_url","url":"https://x.com/img.png"}]}"#;
let msg: ChatMessage = serde_json::from_str(json).unwrap();
assert_eq!(msg.content.len(), 2);
}
// `as_text` returns the text of a text part.
#[test]
fn content_part_text_helper() {
let p = ContentPart::text("hello");
assert_eq!(p.as_text(), Some("hello"));
}
// `ResponseFormat::Json` is tagged as `json`.
#[test]
fn response_format_json_serialization() {
let fmt = ResponseFormat::Json;
let json = serde_json::to_string(&fmt).unwrap();
assert!(json.contains("\"type\":\"json\""), "got: {json}");
}
// `ResponseFormat::Text` is tagged as `text`.
#[test]
fn response_format_text_serialization() {
let fmt = ResponseFormat::Text;
let json = serde_json::to_string(&fmt).unwrap();
assert!(json.contains("\"type\":\"text\""), "got: {json}");
}
// `ResponseFormat::JsonSchema` is tagged as `json_schema`.
#[test]
fn response_format_json_schema() {
let fmt = ResponseFormat::JsonSchema {
schema: serde_json::json!({"type": "object"}),
};
let json = serde_json::to_string(&fmt).unwrap();
assert!(json.contains("json_schema"), "got: {json}");
}
// The builder carries system prompt, messages and temperature into the request.
#[test]
fn builder_basic() {
let req = LLMRequest::builder()
.system("you are helpful")
.user_message("hello")
.temperature(0.5)
.build();
assert_eq!(req.system.as_deref(), Some("you are helpful"));
assert_eq!(req.messages.len(), 1);
assert_eq!(req.temperature, 0.5);
}
// The builder carries model, response format and token limit into the request.
#[test]
fn builder_with_model_and_format() {
let req = LLMRequest::builder()
.user_message("test")
.model("gpt-4o-mini")
.response_format(ResponseFormat::Json)
.max_tokens(100)
.build();
assert_eq!(req.model.as_deref(), Some("gpt-4o-mini"));
assert!(matches!(req.response_format, Some(ResponseFormat::Json)));
assert_eq!(req.max_tokens, Some(100));
}
// The builder carries tool definitions into the request.
#[test]
fn builder_with_tools() {
use crate::llm::ToolDefinition;
let req = LLMRequest::builder()
.user_message("what's the weather?")
.tools(vec![ToolDefinition {
name: "get_weather".into(),
description: "Get weather".into(),
input_schema: serde_json::json!({"type": "object"}),
}])
.build();
assert!(req.tools.is_some());
assert_eq!(req.tools.unwrap().len(), 1);
}
// The system prompt becomes the first pair, followed by the messages in order.
#[test]
fn into_openai_messages_with_system() {
let req = LLMRequest::builder()
.system("be helpful")
.user_message("hi")
.assistant_message("hello")
.build();
let msgs = req.into_openai_messages();
assert_eq!(msgs.len(), 3);
assert_eq!(msgs[0].0, "system");
assert_eq!(msgs[1].0, "user");
assert_eq!(msgs[2].0, "assistant");
}
// The system prompt is left out of the Anthropic-style pairs.
#[test]
fn into_anthropic_messages_excludes_system() {
let req = LLMRequest::builder()
.system("be helpful")
.user_message("hi")
.build();
let msgs = req.into_anthropic_messages();
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].0, "user");
}
// Text parts are concatenated without a separator; image parts are skipped.
#[test]
fn text_content_extracts_text() {
let msg = ChatMessage::user_multimodal(vec![
ContentPart::text("hello "),
ContentPart::image_url("http://x.com/i.png"),
ContentPart::text("world"),
]);
assert_eq!(msg.text_content(), "hello world");
}
}