use std::collections::HashMap;
use crate::requests::Requests;
use crate::*;
use serde::{Deserialize, Serialize};
use super::{Usage, COMPLETION_CREATE};
/// Response body of a completion request.
///
/// Deserialized from the JSON reply in `completion_create`. The `Option` fields
/// become `None` when the server omits them. `choices` and `usage` are required:
/// if either is missing, deserialization fails.
#[derive(Debug, Serialize, Deserialize)]
pub struct Completion {
    /// The `id` field of the response, if present.
    pub id: Option<String>,
    /// The `object` field of the response, if present.
    pub object: Option<String>,
    /// The `created` field of the response, if present
    /// (presumably a Unix timestamp in seconds — confirm against the API reference).
    pub created: Option<u64>,
    /// The model reported by the server, if present.
    pub model: Option<String>,
    /// The returned choices, in server order.
    pub choices: Vec<Choice>,
    /// Usage information reported by the server.
    pub usage: Usage,
}
/// Request body for creating a completion.
///
/// Field names are serialized unchanged as the JSON request parameters; see the
/// API reference for what each one means. `model` is always sent. Every `Option`
/// field that is `None` is left out of the serialized JSON entirely.
///
/// `Default` is derived so callers can set only the fields they care about.
/// Note that the default `model` is an empty string, so always set it explicitly:
///
/// ```ignore
/// let body = CompletionsBody {
///     model: "babbage-002".to_string(),
///     prompt: Some(vec!["Say this is a test".to_string()]),
///     max_tokens: Some(7),
///     ..Default::default()
/// };
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompletionsBody {
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub echo: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<i32>,
    /// Token-to-bias map.
    ///
    /// NOTE(review): values are `String`, so they are sent as JSON strings
    /// (e.g. `{"50256": "-100"}`). The API reference documents numeric bias values,
    /// so confirm the server accepts strings. The type is kept as-is for compatibility.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
}
/// Operations on the completions endpoint.
///
/// In this module it is implemented for `OpenAI`, which POSTs the request body to
/// `COMPLETION_CREATE`.
pub trait CompletionsApi {
    /// Creates a completion described by `completions_body` and returns the decoded
    /// [`Completion`], or the error produced while issuing the request.
    fn completion_create(&self, completions_body: &CompletionsBody) -> ApiResult<Completion>;
}
impl CompletionsApi for OpenAI {
    /// Serializes `completions_body`, POSTs it to `COMPLETION_CREATE`, and decodes
    /// the JSON reply into a [`Completion`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `self.post`.
    ///
    /// # Panics
    ///
    /// Panics if `completions_body` cannot be converted to a JSON value, or if the
    /// response JSON does not deserialize into [`Completion`] (for example, when a
    /// required field such as `choices` or `usage` is missing).
    fn completion_create(&self, completions_body: &CompletionsBody) -> ApiResult<Completion> {
        let request_body = serde_json::to_value(completions_body)
            .expect("CompletionsBody should always serialize to a JSON value");
        let res = self.post(COMPLETION_CREATE, request_body)?;
        // `res` is not used after this point, so move it into serde instead of
        // deep-cloning the entire response tree.
        // NOTE(review): a shape mismatch still panics here; mapping the serde error
        // into the crate's error type would be preferable — confirm which variant fits.
        let completion: Completion = serde_json::from_value(res)
            .expect("completion response did not match the expected `Completion` shape");
        Ok(completion)
    }
}
#[cfg(test)]
mod tests {
    use super::{CompletionsApi, CompletionsBody};
    use crate::openai::new_test_openai;

    /// Smoke test: asks `babbage-002` to continue a fixed prompt and checks that the
    /// first returned choice contains "this".
    ///
    /// NOTE(review): `new_test_openai` is defined elsewhere. Presumably it targets the
    /// real API and needs credentials/network — confirm before enabling in CI.
    #[test]
    fn test_completions() {
        let client = new_test_openai();
        let request = CompletionsBody {
            model: "babbage-002".to_string(),
            prompt: Some(vec!["Say this is a test".to_string()]),
            suffix: None,
            max_tokens: Some(7),
            temperature: Some(0_f32),
            top_p: Some(0_f32),
            n: Some(2),
            stream: Some(false),
            logprobs: None,
            echo: None,
            stop: Some(vec!["\n".to_string()]),
            presence_penalty: None,
            frequency_penalty: None,
            best_of: None,
            logit_bias: None,
            user: None,
        };
        let result = client.completion_create(&request);
        let choices = result.unwrap().choices;
        let first_text = choices[0].text.as_ref().unwrap();
        assert!(first_text.contains("this"));
    }
}