#[cfg(feature = "functions")]
use crate::functions::FunctionCall;
use serde::{Deserialize, Deserializer, Serialize};
/// Author of a chat message.
///
/// Serialized in lowercase (`"system"`, `"assistant"`, `"user"`, `"function"`).
/// The derived `PartialOrd`/`Ord` ordering follows the declaration order below.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize, Deserialize, Eq, Ord)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Serialized as `"system"`.
System,
/// Serialized as `"assistant"`.
Assistant,
/// Serialized as `"user"`.
User,
/// Serialized as `"function"`.
Function,
}
/// A single chat message, used both in outgoing requests and in parsed responses.
#[derive(Debug, Clone, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: Role,
    /// Text of the message.
    ///
    /// When deserializing, a `null` value becomes an empty string, and so does a
    /// payload that omits the `content` key entirely.
    // `default` is needed alongside `deserialize_with`: without it serde treats the
    // field as mandatory and rejects messages that have no `content` key at all.
    #[serde(default, deserialize_with = "deserialize_maybe_null")]
    pub content: String,
    /// Function call attached to the message, if any; omitted from the serialized
    /// form when `None`.
    #[cfg(feature = "functions")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,
}
/// Deserializes a string field that may arrive as `null`, mapping `null` to an empty
/// [`String`].
///
/// Used through `#[serde(deserialize_with = "deserialize_maybe_null")]` on
/// `ChatMessage::content`.
///
/// # Errors
///
/// Returns the deserializer's error if the value is neither `null` nor a string.
fn deserialize_maybe_null<'de, D>(deserializer: D) -> Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    Ok(Option::<String>::deserialize(deserializer)?.unwrap_or_default())
}
impl ChatMessage {
    /// Reassembles whole messages from a sequence of streamed [`ResponseChunk`]s.
    ///
    /// A `BeginResponse` chunk opens the message for its `response_index` with the
    /// announced role. Each `Content` chunk appends its delta to the message with the
    /// same `response_index`. `CloseResponse` and `Done` chunks carry no text and are
    /// skipped.
    ///
    /// Messages are keyed by `response_index` rather than by arrival order. Replies
    /// whose `BeginResponse` chunks arrive out of order, or with gaps in the indices,
    /// are therefore still reassembled into the right message. The returned messages
    /// are sorted by ascending `response_index`. A repeated `BeginResponse` for an
    /// index that is already open is ignored, which keeps the content accumulated so
    /// far.
    ///
    /// # Panics
    ///
    /// Panics if a `Content` chunk refers to a `response_index` whose `BeginResponse`
    /// has not been seen yet.
    #[cfg(feature = "streams")]
    pub fn from_response_chunks(chunks: Vec<ResponseChunk>) -> Vec<Self> {
        // A map (instead of a Vec sized by the largest index) keeps memory proportional
        // to the number of replies even if a chunk reports a very large index.
        let mut messages: std::collections::BTreeMap<usize, Self> =
            std::collections::BTreeMap::new();
        for chunk in chunks {
            match chunk {
                ResponseChunk::Content {
                    delta,
                    response_index,
                } => {
                    messages
                        .get_mut(&response_index)
                        .expect("Invalid response chunk sequence!")
                        .content
                        .push_str(&delta);
                }
                ResponseChunk::BeginResponse {
                    role,
                    response_index,
                } => {
                    messages.entry(response_index).or_insert_with(|| ChatMessage {
                        role,
                        content: String::new(),
                        #[cfg(feature = "functions")]
                        function_call: None,
                    });
                }
                // Listed explicitly (no `_` arm) so a new variant forces a decision here.
                ResponseChunk::CloseResponse { .. } | ResponseChunk::Done => {}
            }
        }
        messages.into_values().collect()
    }
}
/// Request body for a chat completion, borrowing the model name and message history
/// from the caller.
///
/// Some fields are left out of the serialized JSON when they carry no value:
/// `max_tokens` is omitted when `None`, and `functions` is omitted when empty.
// NOTE(review): `&'a [ChatMessage]` / `&'a [serde_json::Value]` would be the more
// idiomatic borrowed types, but changing them alters this public struct's API.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct CompletionRequest<'a> {
/// Model identifier, serialized under the key `model`.
pub model: &'a str,
/// Messages sent with the request.
pub messages: &'a Vec<ChatMessage>,
/// Value of the `stream` flag.
pub stream: bool,
/// `temperature` sampling parameter, serialized as-is.
pub temperature: f32,
/// `top_p` sampling parameter, serialized as-is.
pub top_p: f32,
/// `max_tokens` limit; omitted from the request when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// `frequency_penalty` parameter, serialized as-is.
pub frequency_penalty: f32,
/// `presence_penalty` parameter, serialized as-is.
pub presence_penalty: f32,
/// Number of replies requested; serialized under the key `n`.
#[serde(rename = "n")]
pub reply_count: u32,
/// Function definitions as raw JSON values; omitted from the request when empty.
#[cfg(feature = "functions")]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub functions: &'a Vec<serde_json::Value>,
}
/// Top-level body returned by the server: either an error envelope or a completion.
///
/// The enum is deserialized untagged, so serde tries the variants in declaration order.
/// `Error` is attempted first, and `Completion` is used only when the body does not
/// match `{ "error": ... }`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Deserialize)]
#[serde(untagged)]
pub enum ServerResponse {
/// Body of the form `{ "error": { ... } }`.
Error {
error: CompletionError,
},
/// Successful completion body.
Completion(CompletionResponse),
}
/// Error details carried in a [`ServerResponse::Error`] body.
#[derive(Debug, Clone, PartialEq, PartialOrd, Deserialize)]
pub struct CompletionError {
/// Error message text.
pub message: String,
/// Error category, read from the JSON key `type`.
#[serde(rename = "type")]
pub error_type: String,
}
/// Successful, non-error completion body.
#[derive(Debug, Clone, PartialEq, PartialOrd, Deserialize)]
pub struct CompletionResponse {
/// Identifier read from the JSON key `id`; `None` if absent or `null`.
#[serde(rename = "id")]
pub message_id: Option<String>,
/// Creation time read from the JSON key `created`; `None` if absent or `null`.
// NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
#[serde(rename = "created")]
pub created_timestamp: Option<u64>,
/// Model name reported by the server.
pub model: String,
/// Token counts reported for this completion.
pub usage: TokenUsage,
/// Generated replies, read from the JSON key `choices`.
#[serde(rename = "choices")]
pub message_choices: Vec<MessageChoice>,
}
impl CompletionResponse {
    /// Returns the message of the first choice in this response.
    ///
    /// To inspect every returned choice, use `message_choices` directly.
    ///
    /// # Panics
    ///
    /// Panics if the response contains no choices.
    pub fn message(&self) -> &ChatMessage {
        // `choices` is parsed from the server's reply, so an empty array is possible
        // with a malformed response; state the failed invariant instead of a bare unwrap.
        &self
            .message_choices
            .first()
            .expect("CompletionResponse has no message choices")
            .message
    }
}
/// One generated reply within a [`CompletionResponse`].
#[derive(Debug, Clone, PartialEq, PartialOrd, Deserialize)]
pub struct MessageChoice {
/// The reply message.
pub message: ChatMessage,
/// Finish reason string reported by the server.
// NOTE(review): a `null` finish_reason would fail to deserialize into this `String`
// field — confirm the server never sends one for these responses.
pub finish_reason: String,
/// Index reported by the server for this choice.
pub index: u32,
}
/// Token counts reported by the server for a completion.
#[derive(Debug, Clone, PartialEq, PartialOrd, Deserialize)]
pub struct TokenUsage {
/// Tokens attributed to the prompt.
pub prompt_tokens: u32,
/// Tokens attributed to the generated completion.
pub completion_tokens: u32,
/// Total tokens, as reported by the server.
pub total_tokens: u32,
}
/// A decoded event from a streamed completion.
///
/// `ChatMessage::from_response_chunks` folds a sequence of these back into whole
/// messages.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)]
#[cfg(feature = "streams")]
pub enum ResponseChunk {
/// Text fragment to append to the reply identified by `response_index`.
Content {
delta: String,
response_index: usize,
},
/// Opens the reply identified by `response_index`, announcing its author.
BeginResponse {
role: Role,
response_index: usize,
},
/// Closing marker for the reply identified by `response_index`; carries no text.
CloseResponse {
response_index: usize,
},
/// Marker carrying no data; presumably signals the end of the whole stream — confirm.
Done,
}
/// Raw deserialized form of one streamed chunk.
// NOTE(review): presumably converted into `ResponseChunk` values elsewhere — confirm.
#[derive(Debug, Clone, Deserialize)]
#[cfg(feature = "streams")]
pub struct InboundResponseChunk {
/// Per-reply deltas contained in this chunk.
pub choices: Vec<InboundChunkChoice>,
}
/// One entry of a streamed chunk's `choices` array.
#[derive(Debug, Clone, Deserialize)]
#[cfg(feature = "streams")]
pub struct InboundChunkChoice {
/// Payload of the JSON `delta` object.
pub delta: InboundChunkPayload,
/// Index of the reply this delta belongs to, as reported by the server.
pub index: usize,
}
/// Contents of a streamed `delta` object.
///
/// The enum is deserialized untagged, so serde tries the variants in declaration order
/// and the first one that fits wins. Reordering the variants changes parsing.
// NOTE(review): a delta carrying both `role` and `content` matches `AnnounceRoles`,
// so its content is dropped. `Close {}` accepts any object that matched neither
// earlier variant, including `{}` and objects holding only other keys
// (e.g. a `null` content). Confirm this is intended.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
#[cfg(feature = "streams")]
pub enum InboundChunkPayload {
/// Delta containing a `role` key.
AnnounceRoles {
role: Role,
},
/// Delta containing a string `content` key.
StreamContent {
content: String,
},
/// Fallback for any other delta object.
Close {},
}