use std::collections::HashMap;
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use crate::chat::chunk_stream::ChunkStream;
use crate::chat::Bias;
use crate::chat::ChatApiError;
use crate::chat::ChatApiResult;
use crate::chat::ChatChunkResult;
use crate::chat::ChatCompletionObject;
use crate::chat::ChatModel;
use crate::chat::LogprobsOption;
use crate::chat::MaxTokens;
use crate::chat::Message;
use crate::chat::Penalty;
use crate::chat::ResponseFormat;
use crate::chat::StopOption;
use crate::chat::StreamOption;
use crate::chat::Tool;
use crate::chat::ToolChoice;
use crate::chat::TopLogprobs;
use crate::chat::TopP;
use crate::ApiError;
use crate::Client;
use crate::ClientError;
use crate::Temperature;
/// JSON request body for the chat completions endpoint.
///
/// `complete` and `complete_stream` serialize this struct with `.json(..)` and
/// POST it to `https://api.openai.com/v1/chat/completions`. Fields are
/// serialized under their own names (no serde renaming is applied here), and
/// every `Option` field is omitted from the JSON entirely when it is `None`,
/// so only explicitly set parameters are sent.
///
/// The `stream` field must agree with the function used to send the request:
/// `complete` accepts `None` or `StreamOption::ReturnOnce`, while
/// `complete_stream` requires `Some(StreamOption::ReturnStream)`. Any other
/// combination is rejected with `ChatApiError::StreamOptionMismatch` before a
/// request is made.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionsRequestBody {
    /// Conversation messages; always serialized (not optional).
    pub messages: Vec<Message>,
    /// Target model; always serialized. `Default` sets this to `ChatModel::Gpt35Turbo`.
    pub model: ChatModel,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<Penalty>,
    /// Map from token key (as a string) to a bias value; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, Bias>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogprobsOption>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<TopLogprobs>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<MaxTokens>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<Penalty>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopOption>,
    /// Selects single-response vs. streamed delivery; validated by
    /// `complete` / `complete_stream` before sending (see the type-level docs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<StreamOption>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<Temperature>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<TopP>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
impl Default for CompletionsRequestBody {
    /// Builds an empty request that targets `ChatModel::Gpt35Turbo`.
    ///
    /// `messages` starts out empty and every optional parameter is `None`,
    /// so none of them appear in the serialized JSON until the caller sets them.
    fn default() -> Self {
        let model = ChatModel::Gpt35Turbo;
        let messages = vec![];
        Self {
            model,
            messages,
            // All tuning knobs are unset by default.
            user: None,
            tool_choice: None,
            tools: None,
            top_p: None,
            temperature: None,
            stream: None,
            stop: None,
            seed: None,
            response_format: None,
            presence_penalty: None,
            n: None,
            max_tokens: None,
            top_logprobs: None,
            logprobs: None,
            logit_bias: None,
            frequency_penalty: None,
        }
    }
}
pub(crate) async fn complete(
client: &Client,
request_body: CompletionsRequestBody,
) -> ChatApiResult<ChatCompletionObject> {
if let Some(stream) = request_body.stream {
if stream != StreamOption::ReturnOnce {
return Err(ChatApiError::StreamOptionMismatch);
}
}
let response = client
.post("https://api.openai.com/v1/chat/completions")
.json(&request_body)
.send()
.await
.map_err(ClientError::HttpRequestError)?;
let status_code = response.status();
let response_text = response
.text()
.await
.map_err(ClientError::ReadResponseTextFailed)?;
if status_code.is_success() {
serde_json::from_str(&response_text).map_err(|error| {
{
ClientError::ResponseDeserializationFailed {
error,
text: response_text,
}
}
.into()
})
}
else {
let error_response =
serde_json::from_str(&response_text).map_err(|error| {
ClientError::ErrorResponseDeserializationFailed {
error,
text: response_text,
}
})?;
Err(ApiError {
status_code,
error_response,
}
.into())
}
}
pub(crate) async fn complete_stream(
client: &Client,
request_body: CompletionsRequestBody,
) -> ChatApiResult<impl Stream<Item = ChatChunkResult>> {
if request_body.stream.is_none() {
return Err(ChatApiError::StreamOptionMismatch);
}
if let Some(stream) = request_body.stream {
if stream != StreamOption::ReturnStream {
return Err(ChatApiError::StreamOptionMismatch);
}
}
let response = client
.post("https://api.openai.com/v1/chat/completions")
.json(&request_body)
.send()
.await
.map_err(ClientError::HttpRequestError)?;
let status_code = response.status();
if status_code.is_success() {
Ok(ChunkStream::new(
response.bytes_stream(),
))
}
else {
let response_text = response
.text()
.await
.map_err(ClientError::ReadResponseTextFailed)?;
let error_response =
serde_json::from_str(&response_text).map_err(|error| {
ClientError::ErrorResponseDeserializationFailed {
error,
text: response_text,
}
})?;
Err(ApiError {
status_code,
error_response,
}
.into())
}
}