use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
/// Request body for a chat-completion call.
///
/// Any top-level JSON key that does not match one of the named fields is collected into
/// `parameters` (via `#[serde(flatten)]`) on input and written back at the top level on output,
/// so unmodeled request options pass through unchanged.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    /// Model identifier.
    pub model: String,
    /// Messages of the conversation, in order.
    pub messages: Vec<ChatMessage>,
    /// Tool definitions offered to the model; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Tool selection directive; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Streaming flag; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Streaming options; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Catch-all for extra top-level fields not modeled above.
    #[serde(flatten)]
    pub parameters: HashMap<String, serde_json::Value>,
}
/// Options that accompany a streaming request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamOptions {
    /// Whether a `usage` block is requested in the stream.
    pub include_usage: bool,
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessage {
pub role: Role,
#[serde(deserialize_with = "deserialize_content")]
pub content: Content,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub phase: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
fn deserialize_content<'de, D>(deserializer: D) -> Result<Content, D::Error>
where
D: Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
if value.is_null() {
return Ok(Content::Text(String::new()));
}
serde_json::from_value(value).map_err(D::Error::custom)
}
/// Author role of a chat message.
///
/// Serialized as the lowercase variant name. On input, `"developer"` is additionally
/// accepted and mapped to [`Role::System`]; it is always written back as `"system"`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum Role {
    #[serde(rename = "system", alias = "developer")]
    System,
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
    #[serde(rename = "tool")]
    Tool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON value into a `ChatMessage`, panicking if deserialization fails.
    fn parse_message(raw: serde_json::Value) -> ChatMessage {
        serde_json::from_value(raw).expect("should deserialize")
    }

    #[test]
    fn chat_message_accepts_null_content_as_empty_text() {
        let parsed = parse_message(serde_json::json!({
            "role": "assistant",
            "content": null
        }));
        assert_eq!(parsed.role, Role::Assistant);
        assert_eq!(parsed.content, Content::Text(String::new()));
    }

    #[test]
    fn role_accepts_developer_alias() {
        let parsed = parse_message(serde_json::json!({
            "role": "developer",
            "content": "You are a helpful assistant."
        }));
        let expected_content = Content::Text("You are a helpful assistant.".to_string());
        assert_eq!(parsed.role, Role::System);
        assert_eq!(parsed.content, expected_content);
    }
}
/// Message content: either a plain string or a list of typed parts.
///
/// The enum is untagged, so serde tries `Text` first and then `Parts`; the variant order is
/// significant and must be preserved.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Content {
    /// A bare JSON string.
    Text(String),
    /// A JSON array of [`ContentPart`] objects.
    Parts(Vec<ContentPart>),
}
/// One element of multi-part message content.
///
/// Internally tagged by a `"type"` key whose value is the snake_case variant name
/// (`"text"` or `"image_url"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Text fragment.
    Text { text: String },
    /// Image reference.
    ImageUrl { image_url: ImageUrl },
}
/// Image reference carried by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Optional detail hint; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
impl From<bamboo_domain::MessagePart> for ContentPart {
fn from(part: bamboo_domain::MessagePart) -> Self {
match part {
bamboo_domain::MessagePart::Text { text } => ContentPart::Text { text },
bamboo_domain::MessagePart::ImageUrl { image_url: url_ref } => ContentPart::ImageUrl {
image_url: ImageUrl {
url: url_ref.url,
detail: url_ref.detail,
},
},
}
}
}
impl From<ContentPart> for bamboo_domain::MessagePart {
fn from(part: ContentPart) -> Self {
match part {
ContentPart::Text { text } => bamboo_domain::MessagePart::Text { text },
ContentPart::ImageUrl { image_url } => bamboo_domain::MessagePart::ImageUrl {
image_url: bamboo_domain::ImageUrlRef {
url: image_url.url,
detail: image_url.detail,
},
},
}
}
}
impl From<ImageUrl> for bamboo_domain::ImageUrlRef {
fn from(url: ImageUrl) -> Self {
bamboo_domain::ImageUrlRef {
url: url.url,
detail: url.detail,
}
}
}
/// Moves a domain `ImageUrlRef` into the wire [`ImageUrl`], copying `url` and `detail`.
impl From<bamboo_domain::ImageUrlRef> for ImageUrl {
    fn from(source: bamboo_domain::ImageUrlRef) -> Self {
        Self {
            detail: source.detail,
            url: source.url,
        }
    }
}
/// A tool definition offered to the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind, serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Function signature exposed by the tool.
    pub function: FunctionDefinition,
}
/// Signature of a function exposed through a [`Tool`].
///
/// `parameters` may be absent on input. `serde_json::Value` deserializes via
/// `deserialize_any`, so without `#[serde(default)]` a definition lacking the key would be
/// rejected as "missing field". A missing key becomes `Value::Null`, and `Null` is skipped on
/// output so the definition round-trips without emitting `"parameters": null`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional description; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as raw JSON; `Null` when absent.
    #[serde(default, skip_serializing_if = "serde_json::Value::is_null")]
    pub parameters: serde_json::Value,
}
/// Tool selection directive.
///
/// Untagged: a bare JSON string deserializes to `String`; an object with `"type"` and
/// `"function"` keys deserializes to `Object`. Variant order is significant and preserved.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Bare string selector.
    String(String),
    /// Explicit choice of a single function.
    Object {
        #[serde(rename = "type")]
        tool_type: String,
        function: FunctionChoice,
    },
}
/// Names the function selected by a [`ToolChoice::Object`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionChoice {
    /// Name of the chosen function.
    pub name: String,
}
/// A complete tool call attached to a [`ChatMessage`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of the call.
    pub id: String,
    /// Call kind, serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Function name and arguments.
    pub function: FunctionCall,
}
/// Function name plus its arguments, as carried by a [`ToolCall`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the invoked function.
    pub name: String,
    /// Arguments kept as a raw string (not parsed here).
    pub arguments: String,
}
/// Non-streaming chat-completion response body.
///
/// Only `id` is required on input; every other field falls back to `None` or an empty list
/// when absent. On output `usage` is always written (as `null` when `None`), while
/// `system_fingerprint` is omitted when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    #[serde(default)]
    pub object: Option<String>,
    #[serde(default)]
    pub created: Option<u64>,
    #[serde(default)]
    pub model: Option<String>,
    /// Completion choices; empty when the key is absent.
    #[serde(default)]
    pub choices: Vec<ResponseChoice>,
    #[serde(default)]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
/// One choice within a [`ChatCompletionResponse`].
///
/// `index` defaults to `0` and `finish_reason` to `None` when absent; `message` is required.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResponseChoice {
    #[serde(default)]
    pub index: u32,
    pub message: ChatMessage,
    #[serde(default)]
    pub finish_reason: Option<String>,
}
/// Token accounting reported by the server; each count defaults to `0` when absent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens in the prompt.
    #[serde(default)]
    pub prompt_tokens: u32,
    /// Tokens in the completion.
    #[serde(default)]
    pub completion_tokens: u32,
    /// Total tokens as reported (not recomputed here).
    #[serde(default)]
    pub total_tokens: u32,
}
/// One chunk of a streamed chat completion.
///
/// Only `id` is required on input. `created` falls back to `0` and `choices` to an empty list
/// when absent, which matches the leniency of the non-streaming `ChatCompletionResponse`.
/// Without these defaults, a chunk that omits them (e.g. a usage-only trailing chunk) is
/// rejected outright. `usage` is omitted from output when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChunk {
    pub id: String,
    #[serde(default)]
    pub object: Option<String>,
    /// Creation timestamp as sent by the server; `0` when absent.
    #[serde(default)]
    pub created: u64,
    #[serde(default)]
    pub model: Option<String>,
    /// Choice deltas; empty when absent.
    #[serde(default)]
    pub choices: Vec<StreamChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// One choice within a [`ChatCompletionStreamChunk`].
///
/// `index` defaults to `0` when absent, consistent with `ResponseChoice`. `delta` defaults to
/// an empty [`StreamDelta`] (which derives `Default`), so a chunk carrying only a
/// `finish_reason` still parses. Previously both keys were required.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct StreamChoice {
    #[serde(default)]
    pub index: u32,
    #[serde(default)]
    pub delta: StreamDelta,
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
/// Partial tool call carried inside a [`StreamDelta`].
///
/// Only `index` is required; every other field is optional and omitted from output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamToolCall {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Serialized under the JSON key `"type"`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub tool_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<StreamFunctionCall>,
}
/// Partial function name/arguments inside a [`StreamToolCall`]; both fields are optional
/// and omitted from output when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StreamFunctionCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
/// Incremental message update inside a [`StreamChoice`].
///
/// Every field is optional and omitted from output when `None`; `Default` yields a delta with
/// all fields `None`.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct StreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<StreamToolCall>>,
}