use http_req::{
    request::{Method, Request},
    uri::Uri,
};
use serde::Serialize;
use urlencoding::encode;

use crate::Retry;
/// Request struct for creating a completion.
///
/// The default model is "text-davinci-003".
///
/// For more detail about parameters, please refer to the
/// [OpenAI docs](https://platform.openai.com/docs/api-reference/completions/create).
///
#[derive(Debug, Serialize)]
pub struct CompletionRequest {
    /// The ID or name of the model to use for completion.
    pub model: String,
    /// The text to be used as the prompt for completion.
    pub prompt: String,
    /// An optional suffix that comes after the completion of inserted text.
    pub suffix: Option<String>,
    /// The number of completions to generate.
    pub n: u8,
    /// The number of candidate completions to generate server-side, from which
    /// the best (highest log probability per token) are returned. When used
    /// with `n`, `best_of` sets the number of candidates and `n` sets how many
    /// are returned.
    pub best_of: u8,
    /// The maximum number of tokens to generate in each completion.
    pub max_tokens: u16,
    /// Sampling temperature (0 to 2): higher values make the output more
    /// random, lower values make it more focused and deterministic.
    pub temperature: f32,
    /// Nucleus sampling: only tokens within the top `top_p` probability mass
    /// are considered.
    pub top_p: f32,
    /// If set, include the log probabilities of this many most likely tokens
    /// at each position, as well as of the chosen token.
    pub logprobs: Option<u8>,
    /// Penalizes tokens that have already appeared in the text so far,
    /// encouraging the model to move on to new topics.
    pub presence_penalty: f32,
    /// Penalizes tokens in proportion to how often they have already appeared
    /// in the text so far, discouraging verbatim repetition.
    pub frequency_penalty: f32,
}
impl Default for CompletionRequest {
    fn default() -> CompletionRequest {
        CompletionRequest {
            model: String::from("text-davinci-003"),
            prompt: String::from("<|endoftext|>"),
            suffix: None,
            n: 1,
            best_of: 1,
            max_tokens: 16,
            temperature: 1.0,
            top_p: 1.0,
            logprobs: None,
            presence_penalty: 0.0,
            frequency_penalty: 0.0,
        }
    }
}
impl crate::OpenAIFlows {
    /// Create a completion for the provided prompt and parameters.
    ///
    /// `params` is a [CompletionRequest] object.
    ///
    /// If you haven't connected your OpenAI account with the
    /// [Flows.network platform](https://flows.network), you will receive an
    /// error in the flow's build log or run log.
    ///
    /// ```rust,no_run
    /// // Preceding code has obtained a question from the user in a String named `text`.
    /// // Create a CompletionRequest.
    /// let cr = CompletionRequest {
    ///     prompt: "I want you to act as my legal advisor. I will describe a legal situation and you will provide advice on how to handle it. My question is \"".to_owned() + &text + "\"",
    ///     max_tokens: 2048,
    ///     ..Default::default()
    /// };
    /// // Call create_completion.
    /// let completions = match openai.create_completion(cr).await {
    ///     Ok(completions) => completions,
    ///     Err(e) => {
    ///         // Your error handling goes here.
    ///         eprintln!("completion failed: {}", e);
    ///         return;
    ///     }
    /// };
    /// ```
    pub async fn create_completion(
        &self,
        params: CompletionRequest,
    ) -> Result<Vec<String>, String> {
        self.keep_trying(|account| create_completion_inner(account, &params))
    }
}
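
// Illustrative sketch only: `keep_trying` and `Retry` are defined elsewhere in
// this crate and are not shown here, so their exact behavior is an assumption.
// The hypothetical `RetrySketch` and `keep_trying_sketch` below (not part of
// the original code) show the bounded-retry pattern `create_completion` relies
// on: a `Yes` result means "transient failure, try again", while a `No` result
// carries the final outcome and stops the loop.
#[allow(dead_code)]
enum RetrySketch<T> {
    /// Transient failure; the message is kept in case retries run out.
    Yes(String),
    /// Final outcome (success or non-retryable error); stop retrying.
    No(Result<T, String>),
}

/// Calls `attempt` up to `max_attempts` times, returning at the first
/// `RetrySketch::No`. If every attempt asks for a retry, returns the last
/// transient error message. A real implementation would likely also wait or
/// back off between attempts; this sketch deliberately does not.
#[allow(dead_code)]
fn keep_trying_sketch<T>(
    max_attempts: usize,
    mut attempt: impl FnMut() -> RetrySketch<T>,
) -> Result<T, String> {
    let mut last_error = String::from("no attempt was made");
    for _ in 0..max_attempts {
        match attempt() {
            RetrySketch::No(result) => return result,
            RetrySketch::Yes(message) => last_error = message,
        }
    }
    Err(last_error)
}

// Design note: keeping "retry or stop" as data returned by each attempt lets
// the request code stay free of loop and sleep logic; only the caller decides
// how many times to try.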
fn create_completion_inner(account: &str, params: &CompletionRequest) -> Retry<Vec<String>> {
    let flows_user = unsafe { crate::_get_flows_user() };
    let mut writer = Vec::new();
    let uri = format!(
        "{}/{}/create_completion?account={}",
        crate::OPENAI_API_PREFIX.as_str(),
        flows_user,
        encode(account),
    );
    let uri = Uri::try_from(uri.as_str()).unwrap();
    let body = serde_json::to_vec(params).unwrap_or_default();
    match Request::new(&uri)
        .method(Method::POST)
        .header("Content-Type", "application/json")
        .header("Content-Length", &body.len())
        .body(&body)
        .send(&mut writer)
    {
        Ok(res) => {
            if res.status_code().is_success() {
                Retry::No(
                    serde_json::from_slice::<Vec<String>>(&writer)
                        .or(Err(String::from("Unexpected error"))),
                )
            } else {
                match res.status_code().into() {
                    // 409 TryAgain, 429 RateLimitError, 503 ServiceUnavailable:
                    // transient failures, so ask the caller to retry.
                    409 | 429 | 503 => {
                        Retry::Yes(String::from_utf8_lossy(&writer).into_owned())
                    }
                    // Any other status is treated as a final error.
                    _ => Retry::No(Err(String::from_utf8_lossy(&writer).into_owned())),
                }
            }
        }
        Err(e) => Retry::No(Err(e.to_string())),
    }
}
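
// Illustrative sketch: the retry decision in `create_completion_inner` depends
// only on the HTTP status code. Pulling it into a pure helper makes that
// policy easy to unit-test without a network call. `StatusOutcome` and
// `classify_status` are hypothetical names, not part of the original code;
// they mirror the match above: 2xx is a success, 409, 429 and 503 are
// treated as transient, and anything else is a final error.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum StatusOutcome {
    /// 2xx: parse the body as the completion result.
    Success,
    /// 409, 429 or 503: worth retrying.
    Transient,
    /// Any other status: report the body as an error.
    Fatal,
}

#[allow(dead_code)]
fn classify_status(code: u16) -> StatusOutcome {
    match code {
        200..=299 => StatusOutcome::Success,
        409 | 429 | 503 => StatusOutcome::Transient,
        _ => StatusOutcome::Fatal,
    }
}

// Usage: `classify_status(429)` yields `StatusOutcome::Transient`, which would
// map to `Retry::Yes` in `create_completion_inner`.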