use super::{
common::Usage,
error::{BuilderError, Result},
Str,
};
use crate::{error::FallibleResponse, Client, OpenAiStream};
use chrono::{DateTime, Utc};
use futures::{Stream, TryStreamExt};
use reqwest::Response;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow, collections::HashMap, future::ready, marker::PhantomData, ops::RangeInclusive,
};
/// Author of a chat [`Message`], serialized in `snake_case` (`"user"`,
/// `"system"`, `"assistant"`).
///
/// `#[non_exhaustive]`: new roles may be added upstream, so external
/// `match`es must keep a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Role {
    /// End-user input; the default role.
    #[default]
    User,
    /// Instruction-level message that steers the conversation.
    System,
    /// A message produced by the model.
    Assistant,
}
/// A single chat message: a [`Role`] plus its text content.
///
/// Borrows the content when possible (`Str<'a>`); deserialized responses use
/// `Message<'static>` (owned content).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message<'a> {
    /// Who authored this message.
    pub role: Role,
    /// The message text.
    pub content: Str<'a>,
}
/// One candidate completion returned by the API.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ChatChoice {
    /// The generated message (owned, `'static` content).
    pub message: Message<'static>,
    /// Position of this choice within the response's `choices` array.
    pub index: u64,
    /// Why generation stopped, when the API reports it; `None` when the
    /// field is absent from the payload (`#[serde(default)]`).
    #[serde(default)]
    pub finish_reason: Option<String>,
}
/// A full chat-completion response.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ChatCompletion {
    /// Server-assigned identifier of this completion.
    pub id: String,
    /// Creation time, transported as Unix seconds
    /// (via `chrono::serde::ts_seconds`).
    #[serde(with = "chrono::serde::ts_seconds")]
    pub created: DateTime<Utc>,
    /// Name of the model that produced the completion.
    pub model: String,
    /// Candidate completions; usually one unless `n > 1` was requested.
    pub choices: Vec<ChatChoice>,
    /// Token accounting, when the API includes it; `None` when absent.
    #[serde(default)]
    pub usage: Option<Usage>,
}
/// Streaming variant of a chat completion: yields [`ChatCompletion`] chunks
/// as they arrive over the wire.
pub type ChatCompletionStream = OpenAiStream<ChatCompletion>;
/// Request payload for the chat-completions endpoint.
///
/// All optional fields are omitted from the serialized JSON while `None`
/// (`skip_serializing_if`); `stream` is always serialized and is toggled by
/// the `build`/`build_stream` methods.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionBuilder<'a> {
    // Model identifier, e.g. a chat-capable model name.
    model: Str<'a>,
    // Conversation history sent with the request.
    messages: Vec<Message<'a>>,
    // Whether the server should stream the response chunk by chunk.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u64>,
    // Up to four stop sequences (enforced by `stop()`).
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<Str<'a>>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f64>,
    // Per-token bias map keyed by token id (as string) — TODO confirm key format.
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<Str<'a>, f64>>,
    // Opaque end-user identifier forwarded to the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<Str<'a>>,
}
impl<'a> Message<'a> {
    /// Creates a message with an explicit [`Role`].
    // Idiom fix: dropped the non-idiomatic trailing `return x;` statements
    // throughout this impl in favor of tail expressions.
    #[inline]
    pub fn new(role: Role, content: impl Into<Str<'a>>) -> Self {
        Self {
            role,
            content: content.into(),
        }
    }

    /// Shorthand for a [`Role::User`] message.
    #[inline]
    pub fn user(content: impl Into<Str<'a>>) -> Self {
        Self::new(Role::User, content)
    }

    /// Shorthand for a [`Role::System`] message.
    #[inline]
    pub fn system(content: impl Into<Str<'a>>) -> Self {
        Self::new(Role::System, content)
    }

    /// Shorthand for a [`Role::Assistant`] message.
    #[inline]
    pub fn assistant(content: impl Into<Str<'a>>) -> Self {
        Self::new(Role::Assistant, content)
    }
}
impl ChatCompletion {
    /// Sends a one-shot (non-streaming) chat completion request and awaits
    /// the full response.
    #[inline]
    pub async fn new<'a, I: IntoIterator<Item = Message<'a>>>(
        model: impl Into<Str<'a>>,
        messages: I,
        client: impl AsRef<Client>,
    ) -> Result<Self> {
        Self::builder(model, messages).build(client).await
    }

    /// Sends a streaming chat completion request; chunks arrive via the
    /// returned [`ChatCompletionStream`].
    #[inline]
    pub async fn new_stream<'a, I: IntoIterator<Item = Message<'a>>>(
        model: impl Into<Str<'a>>,
        messages: I,
        client: impl AsRef<Client>,
    ) -> Result<ChatCompletionStream> {
        ChatCompletionStream::new(model, messages, client).await
    }

    /// Starts a [`ChatCompletionBuilder`] for fine-grained request
    /// configuration before sending.
    #[inline]
    pub fn builder<'a, I: IntoIterator<Item = Message<'a>>>(
        model: impl Into<Str<'a>>,
        messages: I,
    ) -> ChatCompletionBuilder<'a> {
        ChatCompletionBuilder::new(model, messages)
    }
}
impl ChatCompletion {
    /// Borrows the first choice, or `None` when `choices` is empty.
    #[inline]
    pub fn first(&self) -> Option<&ChatChoice> {
        self.choices.first()
    }

    /// Consumes the completion and yields its first choice, or `None` when
    /// `choices` is empty.
    #[inline]
    pub fn into_first(self) -> Option<ChatChoice> {
        self.choices.into_iter().next()
    }
}
impl<'a> ChatCompletionBuilder<'a> {
    /// Chat-completions endpoint, shared by [`build`](Self::build) and
    /// [`build_stream`](Self::build_stream) (previously duplicated inline).
    const ENDPOINT: &'static str = "https://api.openai.com/v1/chat/completions";

    /// Creates a builder for `model` with the given conversation history.
    /// All tunables start unset and are omitted from the request JSON.
    pub fn new<I: IntoIterator<Item = Message<'a>>>(
        model: impl Into<Cow<'a, str>>,
        messages: I,
    ) -> Self {
        Self {
            model: model.into(),
            messages: messages.into_iter().collect(),
            stream: false,
            max_tokens: None,
            temperature: None,
            top_p: None,
            n: None,
            stop: None,
            frequency_penalty: None,
            presence_penalty: None,
            logit_bias: None,
            user: None,
        }
    }

    /// Caps the number of tokens generated for the completion.
    pub fn max_tokens(mut self, max_tokens: u64) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the sampling temperature.
    ///
    /// # Errors
    /// Returns the builder wrapped in a [`BuilderError`] when `temperature`
    /// is outside `0.0..=2.0`.
    pub fn temperature(mut self, temperature: f64) -> Result<Self, BuilderError<Self>> {
        const RANGE: RangeInclusive<f64> = 0f64..=2f64;
        if RANGE.contains(&temperature) {
            self.temperature = Some(temperature);
            Ok(self)
        } else {
            Err(BuilderError::msg(
                self,
                format!("temperature out of range ({RANGE:?})"),
            ))
        }
    }

    /// Sets the nucleus-sampling probability mass.
    pub fn top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets how many choices to generate per request.
    pub fn n(mut self, n: u64) -> Self {
        self.n = Some(n);
        self
    }

    /// Sets up to four stop sequences.
    ///
    /// # Errors
    /// Returns a [`BuilderError`] when the iterator yields more than four
    /// items.
    pub fn stop<I: IntoIterator>(mut self, stop: I) -> Result<Self, BuilderError<Self>>
    where
        I::Item: Into<Str<'a>>,
    {
        const MAX_SIZE: usize = 4;
        let mut result = Vec::with_capacity(MAX_SIZE);
        for next in stop {
            // BUGFIX: compare against MAX_SIZE, not `result.capacity()`.
            // `Vec::with_capacity` guarantees *at least* the requested
            // capacity and may over-allocate, which previously could admit
            // more than MAX_SIZE stop sequences. (Also fixed the
            // "Interator" typo in the error message.)
            if result.len() == MAX_SIZE {
                return Err(BuilderError::msg(
                    self,
                    format!("Iterator exceeds size limit of {MAX_SIZE}"),
                ));
            }
            result.push(next.into());
        }
        self.stop = Some(result);
        Ok(self)
    }

    /// Sets the presence penalty.
    ///
    /// # Errors
    /// Returns a [`BuilderError`] when `presence_penalty` is outside
    /// `-2.0..=2.0`.
    pub fn presence_penalty(mut self, presence_penalty: f64) -> Result<Self, BuilderError<Self>> {
        const RANGE: RangeInclusive<f64> = -2f64..=2f64;
        if RANGE.contains(&presence_penalty) {
            self.presence_penalty = Some(presence_penalty);
            Ok(self)
        } else {
            Err(BuilderError::msg(
                self,
                format!("presence_penalty out of range ({RANGE:?})"),
            ))
        }
    }

    /// Sets the frequency penalty.
    ///
    /// # Errors
    /// Returns a [`BuilderError`] when `frequency_penalty` is outside
    /// `-2.0..=2.0`.
    pub fn frequency_penalty(mut self, frequency_penalty: f64) -> Result<Self, BuilderError<Self>> {
        const RANGE: RangeInclusive<f64> = -2f64..=2f64;
        if RANGE.contains(&frequency_penalty) {
            self.frequency_penalty = Some(frequency_penalty);
            Ok(self)
        } else {
            Err(BuilderError::msg(
                self,
                format!("frequency_penalty out of range ({RANGE:?})"),
            ))
        }
    }

    /// Sets per-token logit biases from any `(key, bias)` iterator.
    pub fn logit_bias<K, I>(mut self, logit_bias: I) -> Self
    where
        K: Into<Str<'a>>,
        I: IntoIterator<Item = (K, f64)>,
    {
        self.logit_bias = Some(logit_bias.into_iter().map(|(k, v)| (k.into(), v)).collect());
        self
    }

    /// Sets an opaque end-user identifier forwarded to the API.
    pub fn user(mut self, user: impl Into<Str<'a>>) -> Self {
        self.user = Some(user.into());
        self
    }

    /// Sends the request and awaits the full (non-streaming) completion.
    ///
    /// # Errors
    /// Propagates transport errors and any error payload the API returns
    /// (unwrapped via [`FallibleResponse::into_result`]).
    pub async fn build(self, client: impl AsRef<Client>) -> Result<ChatCompletion> {
        let resp = client
            .as_ref()
            .post(Self::ENDPOINT)
            .json(&self)
            .send()
            .await?
            .json::<FallibleResponse<ChatCompletion>>()
            .await?
            .into_result()?;
        Ok(resp)
    }

    /// Forces `stream = true`, sends the request, and wraps the raw response
    /// in a [`ChatCompletionStream`].
    ///
    /// # Errors
    /// Propagates transport errors from sending the request.
    pub async fn build_stream(
        mut self,
        client: impl AsRef<Client>,
    ) -> Result<ChatCompletionStream> {
        self.stream = true;
        let resp = client
            .as_ref()
            .post(Self::ENDPOINT)
            .json(&self)
            .send()
            .await?;
        Ok(ChatCompletionStream::create(resp))
    }
}
impl ChatCompletionStream {
    /// Builds and sends a streaming chat-completion request in one call.
    #[inline]
    pub async fn new<'a, I: IntoIterator<Item = Message<'a>>>(
        model: impl Into<Str<'a>>,
        messages: I,
        client: impl AsRef<Client>,
    ) -> Result<Self> {
        ChatCompletion::builder(model, messages)
            .build_stream(client)
            .await
    }

    /// Wraps a raw HTTP response's byte stream into a typed stream.
    #[inline]
    fn create(resp: Response) -> Self {
        Self {
            inner: Box::pin(resp.bytes_stream()),
            current_chunk: None,
            _phtm: PhantomData,
        }
    }
}
impl ChatCompletionStream {
    /// Narrows each chunk to its first choice's [`Message`]; chunks whose
    /// `choices` array is empty are filtered out of the stream.
    pub fn into_message_stream(self) -> impl Stream<Item = Result<Message<'static>>> {
        self.try_filter_map(|completion| ready(Ok(completion.into_first())))
            .map_ok(|choice| choice.message)
    }

    /// Narrows each chunk to its first choice's text content; chunks whose
    /// `choices` array is empty are filtered out of the stream.
    pub fn into_text_stream(self) -> impl Stream<Item = Result<Str<'static>>> {
        self.try_filter_map(|completion| ready(Ok(completion.into_first())))
            .map_ok(|choice| choice.message.content)
    }
}