1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use http_req::{
    request::{Method, Request},
    uri::Uri,
};
use serde::Serialize;
use urlencoding::encode;

use crate::Retry;

/// Request struct for the completion.
///
/// The default model is "text-davinci-003".
///
/// For more detail about parameters, please refer to
/// [OpenAI docs](https://platform.openai.com/docs/api-reference/completions/create)
///
#[derive(Debug, Serialize)]
pub struct CompletionRequest {
    /// ID (name) of the model to use for the completion.
    pub model: String,
    /// The prompt to generate a completion for.
    pub prompt: String,
    /// Text that comes *after* the generated completion (insertion mode).
    /// `None` is serialized as `null`, i.e. no suffix.
    pub suffix: Option<String>,
    /// How many completions to generate for the prompt.
    pub n: u8,
    /// Number of completions generated server-side, of which the best are returned.
    /// Per the OpenAI docs this must be at least `n`.
    pub best_of: u8,
    /// Maximum number of tokens to generate in each completion.
    pub max_tokens: u16,
    /// Sampling temperature; higher values make the output more random.
    pub temperature: f32,
    /// Nucleus sampling: only tokens within the top `top_p` probability mass are considered.
    pub top_p: f32,
    /// Not a boolean flag: `Some(k)` asks for the log probabilities of the `k` most likely
    /// tokens at each position, `None` (serialized as `null`) asks for none.
    pub logprobs: Option<u8>,
    /// Penalizes tokens that have already appeared at all, nudging the model toward new topics.
    pub presence_penalty: f32,
    /// Penalizes tokens in proportion to how often they have already appeared,
    /// discouraging verbatim repetition.
    pub frequency_penalty: f32,
}

impl Default for CompletionRequest {
    fn default() -> CompletionRequest {
        CompletionRequest {
            model: String::from("text-davinci-003"),
            prompt: String::from("<|endoftext|>"),
            suffix: None,
            n: 1,
            best_of: 1,
            max_tokens: 16,
            temperature: 1.0,
            top_p: 1.0,
            logprobs: None,
            presence_penalty: 0.0,
            frequency_penalty: 0.0,
        }
    }
}

impl crate::OpenAIFlows {
    /// Create completion for the provided prompt and parameters.
    ///
    /// `params` is a [CompletionRequest] object. On success the generated completion
    /// texts are returned as a `Vec<String>`.
    ///
    /// If you haven't connected your OpenAI account with [Flows.network platform](https://flows.network),
    /// you will receive an error in the flow's building log or running log.
    ///
    /// # Errors
    ///
    /// Returns `Err(String)` describing the failure, for example a transport error
    /// or the body of an error response from the service. Whether and how often a
    /// request is retried is decided by `keep_trying` (defined elsewhere in the crate).
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// // Preceding code has obtained a question from the user in a String named `text`.
    /// // Create a CompletionRequest.
    /// let cr = CompletionRequest {
    ///     prompt: "I want you to act as my legal advisor. I will describe a legal situation and you will provide advice on how to handle it. My question is \"".to_owned() + &text,
    ///     max_tokens: 2048,
    ///     ..Default::default()
    /// };
    /// // Call create_completion.
    /// match openai.create_completion(cr).await {
    ///     Ok(completions) => { /* use the returned completions */ }
    ///     Err(e) => { /* your error handling */ }
    /// }
    /// ```
    pub async fn create_completion(
        &self,
        params: CompletionRequest,
    ) -> Result<Vec<String>, String> {
        self.keep_trying(|account| create_completion_inner(account, &params))
    }
}

/// Performs a single completion request for `account` against the flows.network
/// OpenAI endpoint.
///
/// Returns `Retry::Yes(body)` when the service answers 409, 429 or 503 (treated as
/// transient), and `Retry::No(result)` otherwise. On a success status the body is
/// parsed as a JSON array of completion strings; on any other failure the result
/// is an `Err` carrying a description.
fn create_completion_inner(account: &str, params: &CompletionRequest) -> Retry<Vec<String>> {
    // SAFETY: NOTE(review): `_get_flows_user` is declared elsewhere in the crate; the
    // invariants it requires are not visible from here — confirm at its definition.
    let flows_user = unsafe { crate::_get_flows_user() };

    let url = format!(
        "{}/{}/create_completion?account={}",
        crate::OPENAI_API_PREFIX.as_str(),
        flows_user,
        encode(account),
    );
    // A malformed URL will not fix itself on retry; report it instead of panicking.
    let uri = match Uri::try_from(url.as_str()) {
        Ok(uri) => uri,
        Err(e) => return Retry::No(Err(e.to_string())),
    };
    // Never send an empty body just because serialization failed.
    let body = match serde_json::to_vec(params) {
        Ok(body) => body,
        Err(e) => return Retry::No(Err(e.to_string())),
    };

    let mut writer = Vec::new();
    let res = match Request::new(&uri)
        .method(Method::POST)
        .header("Content-Type", "application/json")
        .header("Content-Length", &body.len())
        .body(&body)
        .send(&mut writer)
    {
        Ok(res) => res,
        Err(e) => return Retry::No(Err(e.to_string())),
    };

    if res.status_code().is_success() {
        return Retry::No(
            serde_json::from_slice::<Vec<String>>(&writer)
                .map_err(|_| String::from("Unexpected error")),
        );
    }

    let message = String::from_utf8_lossy(&writer).into_owned();
    match res.status_code().into() {
        // 409 TryAgain, 429 RateLimitError, 503 ServiceUnavailable
        409 | 429 | 503 => Retry::Yes(message),
        _ => Retry::No(Err(message)),
    }
}