use serde::{Deserialize, Serialize};
/// Author of a chat message.
///
/// Serialized and deserialized as the lowercase variant name:
/// `"system"`, `"user"`, `"assistant"` or `"tool"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
/// One message of a conversation, as carried in [`ChatCompletionRequest::messages`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who authored the message.
pub role: Role,
/// Message text.
pub content: String,
/// Optional participant name; left out of the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl Message {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: content.into(),
name: None,
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: content.into(),
name: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: content.into(),
name: None,
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
}
/// Body of a chat completion request.
///
/// Every `Option` field is marked `skip_serializing_if = "Option::is_none"`, so unset
/// options are omitted from the serialized JSON rather than sent as `null`.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionRequest {
/// Model identifier, serialized as `model`.
pub model: String,
/// Conversation messages, serialized as `messages`.
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Stop sequences, serialized as a JSON array under `stop`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
/// Streaming flag. `ChatCompletionRequest::new` initializes it to `Some(false)`, so it
/// is serialized as `"stream": false` unless reset to `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
}
impl ChatCompletionRequest {
    /// Creates a request for `model` over `messages` with every optional parameter unset.
    ///
    /// `stream` starts as `Some(false)`, so the serialized body carries `"stream": false`
    /// unless [`ChatCompletionRequest::with_stream`] (or direct assignment) changes it.
    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
        Self {
            model: model.into(),
            messages,
            temperature: None,
            max_tokens: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            stream: Some(false),
        }
    }

    /// Sets `temperature`.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets `max_tokens`.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets `top_p`.
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets `frequency_penalty`.
    pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
        self.frequency_penalty = Some(frequency_penalty);
        self
    }

    /// Sets `presence_penalty`.
    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.presence_penalty = Some(presence_penalty);
        self
    }

    /// Sets the stop sequences (replacing any previously set).
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }

    /// Sets the `stream` flag.
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = Some(stream);
        self
    }
}
/// Deserialized body of a chat completion response.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
/// Returned choices; the convenience accessors on this type look only at the first one.
pub choices: Vec<Choice>,
/// Token usage; `None` when the field is absent.
#[serde(default)]
pub usage: Option<Usage>,
}
impl ChatCompletionResponse {
    /// Returns the text of the first choice, or `None` if there are no choices or the
    /// first choice carries neither a `message` nor a `delta`.
    ///
    /// `message` is preferred. `delta` is used when `message` is absent, so the accessor
    /// also works on delta-shaped (streamed) payloads, matching `finish_reason`, which
    /// already works for both. The text comes from [`ResponseMessage::effective_content`].
    pub fn content(&self) -> Option<&str> {
        let choice = self.choices.first()?;
        choice
            .message
            .as_ref()
            .or(choice.delta.as_ref())
            .map(|message| message.effective_content())
    }

    /// Returns the `finish_reason` of the first choice, if any.
    pub fn finish_reason(&self) -> Option<&str> {
        self.choices
            .first()
            .and_then(|c| c.finish_reason.as_deref())
    }
}
/// One element of [`ChatCompletionResponse::choices`].
///
/// Both `message` and `delta` are optional, so either payload shape deserializes.
/// NOTE(review): presumably `message` is filled for complete responses and `delta` for
/// streamed chunks — confirm against the server in use.
#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
/// Position of this choice as reported by the server.
pub index: u32,
pub message: Option<ResponseMessage>,
pub delta: Option<ResponseMessage>,
/// Reported finish reason (e.g. `"stop"`), if present.
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseMessage {
pub role: Option<Role>,
#[serde(default)]
pub content: String,
#[serde(default)]
pub reasoning: Option<String>,
}
impl ResponseMessage {
    /// Returns the text callers should treat as the message body.
    ///
    /// A non-empty `content` always wins. Otherwise the `reasoning` text is returned,
    /// and when that is absent too, the result is the empty string.
    pub fn effective_content(&self) -> &str {
        if !self.content.is_empty() {
            return &self.content;
        }
        self.reasoning.as_deref().unwrap_or_default()
    }
}
/// Token counts attached to a response.
///
/// Each count defaults to `0` when its field is absent from the JSON.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct Usage {
/// Count reported as `prompt_tokens`.
#[serde(default)]
pub prompt_tokens: u32,
/// Count reported as `completion_tokens`.
#[serde(default)]
pub completion_tokens: u32,
/// Count reported as `total_tokens` (taken as-is, not recomputed).
#[serde(default)]
pub total_tokens: u32,
}
/// Error envelope of the form `{"error": { ... }}`; the inner object is [`ApiErrorDetail`].
#[derive(Debug, Clone, Deserialize)]
pub struct ApiErrorResponse {
pub error: ApiErrorDetail,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: Option<String>,
pub code: Option<String>,
}
#[cfg(test)]
mod tests {
// Unit tests for the message constructors, the request builder, and response parsing.
use super::*;
// Constructors set the role and content; `with_name` attaches an optional name.
#[test]
fn test_message_creation() {
let system = Message::system("You are helpful");
assert_eq!(system.role, Role::System);
assert_eq!(system.content, "You are helpful");
let user = Message::user("Hello").with_name("John");
assert_eq!(user.role, Role::User);
assert_eq!(user.name, Some("John".to_string()));
}
// Builder methods chain and store the given options.
#[test]
fn test_request_builder() {
let request = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")])
.with_temperature(0.7)
.with_max_tokens(1000);
assert_eq!(request.model, "gpt-4");
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.max_tokens, Some(1000));
}
// Parses a representative non-streaming response and checks the convenience accessors.
#[test]
fn test_response_parsing() {
let json = r#"{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-4",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you?"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}"#;
let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.content(), Some("Hello! How can I help you?"));
assert_eq!(response.finish_reason(), Some("stop"));
assert_eq!(response.usage.as_ref().unwrap().total_tokens, 30);
}
}