use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::cell::Cell;
use crate::{
api_resources::{Choices, TokenUsage},
Client, Result,
};
#[skip_serializing_none]
#[derive(Builder, Clone, Debug, Default, Deserialize, Serialize)]
#[builder(default, setter(into, strip_option))]
pub struct CompletionParam {
model: String,
prompt: Option<String>,
suffix: Option<String>,
max_tokens: Option<i32>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<u32>,
#[builder(setter(skip))]
stream: Cell<bool>,
logprobs: Option<f32>,
echo: Option<bool>,
stop: Option<String>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
best_of: Option<u16>,
user: Option<String>,
}
impl CompletionParamBuilder {
    /// Creates a builder whose required `model` is already set.
    ///
    /// Every other field stays at the builder's default until its setter is
    /// called.
    pub fn new(model: impl Into<String>) -> Self {
        let model = Some(model.into());
        Self {
            model,
            ..Default::default()
        }
    }
}
/// Decoded body of a non-streaming completion response (see [`create`]).
///
/// With `#[serde(default)]`, any field missing from the response falls back to
/// its `Default` value instead of failing deserialization.
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(default)]
pub struct Completion {
    /// Identifier returned by the API for this completion.
    pub id: String,
    /// Object type string returned by the API (e.g. `"text_completion"`).
    pub object: String,
    /// Creation time as returned by the API.
    ///
    /// Presumably Unix seconds; confirm against the API docs.
    pub created: u64,
    /// Model reported by the API for this completion.
    pub model: String,
    /// Generated choices; empty if the field was absent.
    pub choices: Vec<Choices>,
    /// Token accounting, when the response includes a `usage` object.
    pub usage: Option<TokenUsage>,
}
/// Sends `param` to the `completions` endpoint and decodes the body into a
/// [`Completion`].
///
/// This path always requests a non-streaming response, so it sets
/// `param.stream` to `false` before sending. Otherwise a flag left at `true`
/// would make the request ask for a streamed body. That flag may come from
/// [`create_with_stream`] or from deserialized input. A streamed body is not a
/// single JSON `Completion`. Streaming callers should use
/// [`create_with_stream`], which returns the raw response instead.
///
/// # Errors
///
/// Propagates any error returned by the client's `post` call.
pub async fn create(client: &Client, param: &CompletionParam) -> Result<Completion> {
    param.stream.set(false);
    client.create_completion(param).await
}
pub async fn create_with_stream(
client: &Client,
param: &CompletionParam,
) -> Result<reqwest::Response> {
param.stream.set(true);
client.create_completion_with_stream(param).await
}
impl Client {
    /// POSTs `param` as JSON to `completions` and decodes the reply as a
    /// [`Completion`].
    async fn create_completion(&self, param: &CompletionParam) -> Result<Completion> {
        let body = Some(param);
        self.post::<CompletionParam, Completion>("completions", body)
            .await
    }

    /// POSTs `param` as JSON to `completions` and hands back the raw
    /// `reqwest::Response` without decoding it.
    async fn create_completion_with_stream(
        &self,
        param: &CompletionParam,
    ) -> Result<reqwest::Response> {
        let body = Some(param);
        self.post_stream("completions", body).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Example request body.
    ///
    /// The integer literals for `temperature` and `top_p` also check that
    /// integer JSON numbers are accepted for the `f32` fields.
    const PARAM_JSON: &str = r#"
{
    "model": "text-davinci-003",
    "prompt": "Say this is a test",
    "max_tokens": 7,
    "temperature": 0,
    "top_p": 1,
    "n": 1,
    "stream": false,
    "logprobs": null,
    "stop": "\n"
}
"#;

    /// Example non-streaming response body.
    const RESPONSE_JSON: &str = r#"
{
    "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
    "object": "text_completion",
    "created": 1589478378,
    "model": "text-davinci-003",
    "choices": [
        {
            "text": "\n\nThis is indeed a test",
            "index": 0,
            "logprobs": null,
            "finish_reason": "length"
        }
    ],
    "usage": {
        "prompt_tokens": 5,
        "completion_tokens": 7,
        "total_tokens": 12
    }
}
"#;

    #[test]
    fn deserializes_completion_param() {
        let param: CompletionParam =
            serde_json::from_str(PARAM_JSON).expect("PARAM_JSON is a valid request");
        assert_eq!(param.model, "text-davinci-003");
        assert_eq!(param.prompt.as_deref(), Some("Say this is a test"));
        assert_eq!(param.suffix, None);
        assert_eq!(param.stop.as_deref(), Some("\n"));
        assert!(!param.stream.get());
    }

    #[test]
    fn deserializes_completion_response() {
        let resp: Completion =
            serde_json::from_str(RESPONSE_JSON).expect("RESPONSE_JSON is a valid response");
        assert_eq!(resp.object, "text_completion");
        assert_eq!(resp.created, 1589478378);
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(
            resp.choices[0].text,
            Some("\n\nThis is indeed a test".to_string())
        );
        assert_eq!(resp.choices[0].logprobs, None);
        assert_eq!(
            resp.usage.expect("usage is present").prompt_tokens,
            5
        );
    }

    #[test]
    fn builder_serializes_only_set_fields() {
        let param = CompletionParamBuilder::new("text-davinci-003")
            .prompt("Say this is a test")
            .max_tokens(7)
            .build()
            .expect("every field has a builder default");
        let json = serde_json::to_value(&param).expect("CompletionParam serializes");
        assert_eq!(json["model"], "text-davinci-003");
        assert_eq!(json["prompt"], "Say this is a test");
        assert_eq!(json["max_tokens"], 7);
        // `stream` is not an `Option`, so it is always emitted.
        assert_eq!(json["stream"], false);
        // `#[skip_serializing_none]` drops unset optional fields entirely.
        assert!(json.get("suffix").is_none());
        assert!(json.get("user").is_none());
    }
}