use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Author role of a [`ChatMessage`].
///
/// Variants are serialized in lowercase (`"system"`, `"user"`, `"assistant"`, `"function"`).
/// The enum is fieldless, so it is `Copy` and can be used as a map/set key.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"function"`.
    Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: MessageRole,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}
impl ChatMessage {
    /// Creates a message with the given `role` and `content`, with no `name` and no
    /// `function_call`.
    #[must_use]
    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
        Self {
            role,
            content: content.into(),
            name: None,
            function_call: None,
        }
    }

    /// Creates a [`MessageRole::System`] message.
    #[must_use]
    pub fn system(content: impl Into<String>) -> Self {
        Self::new(MessageRole::System, content)
    }

    /// Creates a [`MessageRole::User`] message.
    #[must_use]
    pub fn user(content: impl Into<String>) -> Self {
        Self::new(MessageRole::User, content)
    }

    /// Creates a [`MessageRole::Assistant`] message.
    #[must_use]
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(MessageRole::Assistant, content)
    }

    /// Creates a [`MessageRole::Function`] message with its `name` field set.
    ///
    /// Function-role messages identify the function via `name`, so this constructor requires it.
    #[must_use]
    pub fn function(name: impl Into<String>, content: impl Into<String>) -> Self {
        Self::new(MessageRole::Function, content).with_name(name)
    }

    /// Returns the message with its optional `name` field set (replacing any previous value).
    #[must_use]
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }
}
/// A function invocation attached to a [`ChatMessage`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Raw argument text, kept as an unparsed string (presumably JSON-encoded — not checked here).
    pub arguments: String,
}
/// Request body for a chat-completion call.
///
/// Every `Option` parameter is left out of the serialized JSON while it is `None`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatRequest {
    /// Conversation so far, in order.
    pub messages: Vec<ChatMessage>,
    /// Model identifier.
    pub model: String,
    /// Optional cap on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling threshold.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Optional streaming flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Additional top-level fields. `flatten` writes these keys directly into the request
    /// object and, when deserializing, gathers any keys not matched by the fields above.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
impl ChatRequest {
pub fn new(messages: Vec<ChatMessage>, model: impl Into<String>) -> Self {
Self {
messages,
model: model.into(),
max_tokens: None,
temperature: None,
top_p: None,
stream: None,
extra: HashMap::new(),
}
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn with_stream(mut self, stream: bool) -> Self {
self.stream = Some(stream);
self
}
}
/// Token counts reported alongside a [`ChatResponse`].
///
/// `Default` yields all-zero counts; the struct is plain data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported (presumably prompt + completion; not enforced here).
    pub total_tokens: u32,
}
/// One candidate completion within a [`ChatResponse`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatChoice {
    /// Position of this choice in the response's `choices` list.
    pub index: u32,
    /// The generated message.
    pub message: ChatMessage,
    /// Why generation stopped, if reported; serialized as `null` when `None`.
    pub finish_reason: Option<String>,
}
/// Response body of a chat-completion call.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ChatResponse {
    /// Response identifier.
    pub id: String,
    /// Object type tag as returned by the service.
    pub object: String,
    /// Creation time (presumably a Unix timestamp in seconds — confirm against the provider).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Candidate completions.
    pub choices: Vec<ChatChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
    /// Optional backend fingerprint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
impl ChatResponse {
    /// Returns the text of the first choice's message, or `None` if there are no choices.
    pub fn content(&self) -> Option<&str> {
        let first_message = self.message()?;
        Some(first_message.content.as_str())
    }

    /// Returns the first choice's message, or `None` if there are no choices.
    pub fn message(&self) -> Option<&ChatMessage> {
        let first_choice = self.choices.first()?;
        Some(&first_choice.message)
    }
}