use anyhow::{anyhow, Result};
use log::{error, info, warn};
use reqwest::{header, Client};
use schemars::{schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;

use crate::{domain::OpenAIDataResponse, models::OpenAIModels, utils::get_tokenizer};

/// [Chat Completions API](https://platform.openai.com/docs/guides/text-generation/chat-completions-api)
///
/// Chat models take a list of messages as input and return a model-generated message as output.
/// Although the chat format is designed to make multi-turn conversations easy,
/// it’s just as useful for single-turn tasks without any conversation.
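///
/// # Example
///
/// A minimal sketch of the intended end-to-end flow. The output type and the
/// model variant name below are illustrative assumptions, not guaranteed parts
/// of this crate's API:
///
/// ```ignore
/// #[derive(JsonSchema, Deserialize)]
/// struct Sentiment {
///     label: String,
///     score: f32,
/// }
///
/// //`Gpt4` is a placeholder variant name
/// let sentiment: Sentiment = OpenAI::new(&api_key, OpenAIModels::Gpt4, None, None)
///     .set_context("review", &"Great product, fast shipping!")?
///     .get_answer("Classify the sentiment of the review.")
///     .await?;
/// ```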
pub struct OpenAI {
    model: OpenAIModels,
    //Total token budget shared between the prompt and the response
    max_tokens: usize,
    temperature: u32,
    input_json: Option<String>,
    debug: bool,
    function_call: bool,
    api_key: String,
}

impl OpenAI {
    /// Creates a new `OpenAI` client for the given model.
    ///
    /// If `max_tokens` is not provided, it defaults to the maximum allowed for the model.
    /// If `temperature` is not provided, it defaults to `0` for more deterministic output.
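    ///
    /// # Example
    ///
    /// A sketch of constructing a client with an explicit token limit and the
    /// builder-style toggles (`Gpt4` is an illustrative placeholder variant):
    ///
    /// ```ignore
    /// let client = OpenAI::new(&api_key, OpenAIModels::Gpt4, Some(4096), Some(0))
    ///     .debug()
    ///     .function_calling(true);
    /// ```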
    pub fn new(
        open_ai_key: &str,
        model: OpenAIModels,
        max_tokens: Option<usize>,
        temperature: Option<u32>,
    ) -> Self {
        OpenAI {
            //If no max tokens limit is provided, we default to the max allowed for the model
            max_tokens: max_tokens.unwrap_or_else(|| model.default_max_tokens()),
            function_call: model.function_call_default(),
            model,
            temperature: temperature.unwrap_or(0u32), //A low value makes the output less random and more deterministic
            input_json: None,
            debug: false,
            api_key: open_ai_key.to_string(),
        }
    }

    /// Turns on debug mode, which logs the prompt and API responses via `info!` during execution.
    pub fn debug(mut self) -> Self {
        self.debug = true;
        self
    }

    /// Turns function calling on or off for interactions with the OpenAI API.
    pub fn function_calling(mut self, function_call: bool) -> Self {
        self.function_call = function_call;
        self
    }

    /// Provides a value to be used as context for the prompt.
    ///
    /// You can provide multiple input values by calling this method multiple times;
    /// each new value is appended under the provided `input_name` label.
    /// It accepts any instance that implements the `Serialize` trait.
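    ///
    /// # Example
    ///
    /// A sketch of chaining multiple context values; `customer` and `order` are
    /// illustrative `Serialize`-able values:
    ///
    /// ```ignore
    /// let client = OpenAI::new(&api_key, model, None, None)
    ///     .set_context("customer", &customer)?
    ///     .set_context("order", &order)?;
    /// ```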
    pub fn set_context<T: Serialize>(mut self, input_name: &str, input_data: &T) -> Result<Self> {
        let input_json = serde_json::to_string(&input_data)
            .map_err(|_| anyhow!("Unable to serialize provided input data."))?;
        //Separate subsequent context entries with a blank line
        let line_break = match self.input_json {
            Some(_) => "\n\n".to_string(),
            None => "".to_string(),
        };
        let new_json = format!(
            "{}{}{}: {}",
            self.input_json.unwrap_or_default(),
            line_break,
            input_name,
            input_json,
        );
        self.input_json = Some(new_json);
        Ok(self)
    }

    /// Estimates how many tokens the prompt will consume, and therefore how many
    /// would most likely remain for the response.
    ///
    /// This is accomplished by estimating the number of tokens needed for the system/base
    /// instructions, the user prompt, and the function components, including the schema definition.
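    ///
    /// # Example
    ///
    /// A sketch of validating the token budget before submitting a prompt
    /// (`Sentiment` is an illustrative type deriving `JsonSchema` and `Deserialize`):
    ///
    /// ```ignore
    /// let client = OpenAI::new(&api_key, model, Some(4096), None);
    /// let prompt_tokens = client.check_prompt_tokens::<Sentiment>("Classify the review.")?;
    /// assert!(prompt_tokens < 4096, "prompt would leave no room for the response");
    /// ```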
    pub fn check_prompt_tokens<T: JsonSchema + DeserializeOwned>(
        &self,
        instructions: &str,
    ) -> Result<usize> {
        //Output schema is extracted from the type parameter
        let schema = schema_for!(T);
        let json_value: Value = serde_json::to_value(&schema)?;

        let prompt = format!(
            "Instructions:
            {instructions}

            Input data:
            {input_json}
            
            Respond ONLY with the data portion of a valid JSON object. No schema definition required. No other words.",
            instructions = instructions,
            input_json = self.input_json.clone().unwrap_or_default(),
        );

        let full_prompt = format!(
            "{}{}{}",
            //Base (system) instructions
            self.model.get_base_instructions(Some(self.function_call)),
            //Instructions & context data
            prompt,
            //Output schema
            serde_json::to_string(&json_value).unwrap_or_default()
        );

        //Check how many tokens are required for prompt
        let bpe = get_tokenizer(&self.model)?;
        let prompt_tokens = bpe.encode_with_special_tokens(&full_prompt).len();

        //Assuming another 5% overhead for json formatting
        Ok((prompt_tokens as f64 * 1.05) as usize)
    }

    /// Calls the OpenAI API with the provided request body.
    ///
    /// It returns the raw response body as a `String`, which then needs to be parsed
    /// according to `self.model`.
    async fn call_openai_api(&self, body: &serde_json::Value) -> Result<String> {
        //Get the API url
        let model_url = self.model.get_endpoint();

        //Make the API call
        let client = Client::new();

        let response = client
            .post(model_url)
            .header(header::CONTENT_TYPE, "application/json")
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await?;

        let response_status = response.status();
        let response_text = response.text().await?;

        if self.debug {
            info!(
                "[debug] OpenAI API response: [{}] {:#?}",
                &response_status, &response_text
            );
        }

        Ok(response_text)
    }

    /// Submits a prompt to OpenAI and processes the response.
    ///
    /// When calling this method you need to specify the type parameter `T`, as the response
    /// will be deserialized to match the schema of that type. The prompt is written to
    /// instruct OpenAI to behave like a function that computes an output from the provided
    /// input using its language model.
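    ///
    /// # Example
    ///
    /// A sketch of requesting a typed answer; `TranslationResult` is an
    /// illustrative type:
    ///
    /// ```ignore
    /// #[derive(JsonSchema, Deserialize)]
    /// struct TranslationResult {
    ///     translated_text: String,
    /// }
    ///
    /// let result: TranslationResult = OpenAI::new(&api_key, model, None, None)
    ///     .set_context("text", &"Bonjour le monde")?
    ///     .get_answer("Translate the input text to English.")
    ///     .await?;
    /// println!("{}", result.translated_text);
    /// ```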
    pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
        self,
        instructions: &str,
    ) -> Result<T> {
        //Output schema is extracted from the type parameter
        let schema = schema_for!(T);
        let json_value: Value = serde_json::to_value(&schema)?;

        let prompt = format!(
            "Instructions:
            {instructions}

            Input data:
            {input_json}
            
            Respond ONLY with the data portion of a valid JSON object. No schema definition required. No other words.",
            instructions = instructions,
            input_json = self.input_json.clone().unwrap_or_default(),
        );

        //Validate how many tokens remain for the response (and how many are used for prompt)
        let prompt_tokens = self
            .check_prompt_tokens::<T>(instructions)
            .unwrap_or_default();

        if prompt_tokens >= self.max_tokens {
            return Err(anyhow!(
                "The provided prompt requires more tokens than allocated."
            ));
        }
        let response_tokens = self.max_tokens - prompt_tokens;

        //Log a warning if, after processing the prompt, there might not be enough tokens left for the response
        //This assumes the response will be of similar size to the input; because that is not always the case, this is a warning rather than an error
        if prompt_tokens * 2 >= self.max_tokens {
            warn!(
                "{} tokens remaining for response: {} allocated, {} used for prompt",
                response_tokens, self.max_tokens, prompt_tokens,
            );
        }

        //Build the API body depending on the used model
        let model_body = self.model.get_body(
            &prompt,
            &json_value,
            self.function_call,
            &response_tokens,
            &self.temperature,
        );

        //Display debug info if requested
        if self.debug {
            info!("[debug] Model body: {:#?}", model_body);
            info!(
                "[debug] Prompt accounts for approx {} tokens, leaving {} tokens for answer.",
                prompt_tokens, response_tokens,
            );
        }

        let response_text = self.call_openai_api(&model_body).await?;

        //Extract data from the returned response text based on the used model
        let response_string = self.model.get_data(&response_text, self.function_call)?;

        if self.debug {
            info!("[debug] OpenAI response data: {}", response_string);
        }
        //Deserialize the string response into the expected output type
        match serde_json::from_str::<T>(&response_string) {
            Ok(response) => Ok(response),
            Err(error) => {
                error!("[OpenAI] Response deserialization error: {}", &error);
                //Sometimes OpenAI responds with a JSON object that wraps the answer in a "data" property.
                //If that's the case, we extract the "data" property and deserialize that instead.
                let wrapped: OpenAIDataResponse<T> = serde_json::from_str(&response_string)
                    .map_err(|error| {
                        error!("[OpenAI] Response deserialization error: {}", &error);
                        anyhow!("Error: {}", error)
                    })?;
                Ok(wrapped.data)
            }
        }
    }
}