chatgpt/client.rs

use std::path::Path;

use reqwest::header::AUTHORIZATION;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{self, Proxy};
use tokio::fs::File;
use tokio::io::AsyncReadExt;

#[cfg(feature = "streams")]
use reqwest::Response;
#[cfg(feature = "streams")]
use {
    crate::types::InboundChunkPayload, crate::types::InboundResponseChunk,
    crate::types::ResponseChunk, futures_util::Stream,
};

use crate::config::ModelConfiguration;
use crate::converse::Conversation;
use crate::types::{ChatMessage, CompletionRequest, CompletionResponse, Role, ServerResponse};

#[cfg(feature = "functions")]
use crate::functions::{FunctionArgument, FunctionDescriptor};

/// The client that operates the ChatGPT API
#[derive(Debug, Clone)]
pub struct ChatGPT {
    client: reqwest::Client,
    /// The configuration for this ChatGPT client
    pub config: ModelConfiguration,
}

impl ChatGPT {
    /// Constructs a new ChatGPT API client with provided API key and default configuration
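    ///
    /// # Example
    ///
    /// A minimal sketch; the key is a placeholder, and `chatgpt::prelude` is assumed
    /// to re-export the client type:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # fn example() -> chatgpt::Result<()> {
    /// let client = ChatGPT::new("sk-your-api-key")?;
    /// # Ok(())
    /// # }
    /// ```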
    pub fn new<S: Into<String>>(api_key: S) -> crate::Result<Self> {
        Self::new_with_config(api_key, ModelConfiguration::default())
    }

    /// Constructs a new ChatGPT API client with provided API key, default configuration and a reqwest proxy
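    ///
    /// # Example
    ///
    /// A sketch routing all requests through a proxy; the proxy URL is a placeholder:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # fn example() -> chatgpt::Result<()> {
    /// let proxy = reqwest::Proxy::all("http://127.0.0.1:8080")?;
    /// let client = ChatGPT::new_with_proxy("sk-your-api-key", proxy)?;
    /// # Ok(())
    /// # }
    /// ```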
    pub fn new_with_proxy<S: Into<String>>(api_key: S, proxy: Proxy) -> crate::Result<Self> {
        Self::new_with_config_proxy(api_key, ModelConfiguration::default(), proxy)
    }

    /// Constructs a new ChatGPT API client with provided API key and configuration
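    ///
    /// # Example
    ///
    /// A sketch that tweaks the default configuration; it assumes the configuration
    /// fields are public, as the field accesses elsewhere in this file suggest:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # fn example() -> chatgpt::Result<()> {
    /// let mut config = ModelConfiguration::default();
    /// config.temperature = 0.8;
    /// let client = ChatGPT::new_with_config("sk-your-api-key", config)?;
    /// # Ok(())
    /// # }
    /// ```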
    pub fn new_with_config<S: Into<String>>(
        api_key: S,
        config: ModelConfiguration,
    ) -> crate::Result<Self> {
        let api_key = api_key.into();
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
        );
        let client = reqwest::ClientBuilder::new()
            .default_headers(headers)
            .timeout(config.timeout)
            .build()?;
        Ok(Self { client, config })
    }

    /// Constructs a new ChatGPT API client with provided API key, configuration and reqwest proxy
    pub fn new_with_config_proxy<S: Into<String>>(
        api_key: S,
        config: ModelConfiguration,
        proxy: Proxy,
    ) -> crate::Result<Self> {
        let api_key = api_key.into();
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
        );

        let client = reqwest::ClientBuilder::new()
            .default_headers(headers)
            .timeout(config.timeout)
            .proxy(proxy)
            .build()?;
        Ok(Self { client, config })
    }

    /// Restores a conversation from a local JSON history file.
    /// Such a file can be created with [`Conversation::save_history_json()`].
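    ///
    /// # Example
    ///
    /// A sketch; `history.json` is a placeholder path:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let conversation = client.restore_conversation_json("history.json").await?;
    /// # Ok(())
    /// # }
    /// ```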
    #[cfg(feature = "json")]
    pub async fn restore_conversation_json<P: AsRef<Path>>(
        &self,
        file: P,
    ) -> crate::Result<Conversation> {
        let path = file.as_ref();
        if !path.exists() {
            return Err(crate::err::Error::ParsingError(
                "Conversation history JSON file does not exist".to_string(),
            ));
        }
        let mut file = File::open(path).await?;
        let mut buf = String::new();
        file.read_to_string(&mut buf).await?;
        Ok(Conversation::new_with_history(
            self.clone(),
            serde_json::from_str(&buf)?,
        ))
    }

    /// Restores a conversation from a local Postcard history file.
    /// Such a file can be created with [`Conversation::save_history_postcard()`].
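    ///
    /// # Example
    ///
    /// A sketch; `history.bin` is a placeholder path:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let conversation = client.restore_conversation_postcard("history.bin").await?;
    /// # Ok(())
    /// # }
    /// ```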
    #[cfg(feature = "postcard")]
    pub async fn restore_conversation_postcard<P: AsRef<Path>>(
        &self,
        file: P,
    ) -> crate::Result<Conversation> {
        let path = file.as_ref();
        if !path.exists() {
            return Err(crate::err::Error::ParsingError(
                "Conversation history Postcard file does not exist".to_string(),
            ));
        }
        let mut file = File::open(path).await?;
        let mut buf = Vec::new();
        file.read_to_end(&mut buf).await?;
        Ok(Conversation::new_with_history(
            self.clone(),
            postcard::from_bytes(&buf)?,
        ))
    }

    /// Starts a new conversation with a default starting message.
    ///
    /// Conversations record message history.
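    ///
    /// # Example
    ///
    /// A sketch of a short exchange; it assumes the conversation type exposes a
    /// `send_message()` method, as used elsewhere in this crate:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let mut conversation = client.new_conversation();
    /// let response = conversation.send_message("What is Rust?").await?;
    /// # Ok(())
    /// # }
    /// ```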
    pub fn new_conversation(&self) -> Conversation {
        self.new_conversation_directed(
            "You are ChatGPT, an AI model developed by OpenAI. Answer as concisely as possible."
                .to_string(),
        )
    }

    /// Starts a new conversation with a specified starting message.
    ///
    /// Conversations record message history.
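    ///
    /// # Example
    ///
    /// A sketch with a custom system directive:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # fn example(client: ChatGPT) {
    /// let conversation = client
    ///     .new_conversation_directed("You are a terse assistant. Answer in one sentence.");
    /// # }
    /// ```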
    pub fn new_conversation_directed<S: Into<String>>(&self, direction_message: S) -> Conversation {
        Conversation::new(self.clone(), direction_message.into())
    }

    /// Explicitly sends the whole message history to the API.
    ///
    /// In most cases, if you want to preserve message history, you should use the [`Conversation`] struct
    /// via [`Self::new_conversation()`] or [`Self::new_conversation_directed()`].
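    ///
    /// # Example
    ///
    /// A sketch that builds a history by hand; it assumes the `functions` feature is
    /// disabled (with it enabled, each message also needs `function_call: None`):
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # use chatgpt::types::{ChatMessage, Role};
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let history = vec![
    ///     ChatMessage {
    ///         role: Role::System,
    ///         content: "You are a helpful assistant.".to_string(),
    ///     },
    ///     ChatMessage {
    ///         role: Role::User,
    ///         content: "Summarize the Rust borrow checker.".to_string(),
    ///     },
    /// ];
    /// let response = client.send_history(&history).await?;
    /// # Ok(())
    /// # }
    /// ```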
    pub async fn send_history(
        &self,
        history: &Vec<ChatMessage>,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: history,
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?
            .json()
            .await?;
        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }

    /// Explicitly sends the whole message history to the API and returns the response as a stream.
    /// **The stream will be empty** if the server returns an error.
    ///
    /// In most cases, if you want to preserve message history, you should use the [`Conversation`] struct
    /// via [`Self::new_conversation()`] or [`Self::new_conversation_directed()`].
    ///
    /// Requires the `streams` crate feature.
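    ///
    /// # Example
    ///
    /// A sketch of draining the stream; `for_each` consumes the stream by value,
    /// which sidesteps pinning:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # use chatgpt::types::{ChatMessage, ResponseChunk};
    /// # use futures_util::StreamExt;
    /// # async fn example(client: ChatGPT, history: Vec<ChatMessage>) -> chatgpt::Result<()> {
    /// let stream = client.send_history_streaming(&history).await?;
    /// stream
    ///     .for_each(|chunk| async move {
    ///         if let Ok(ResponseChunk::Content { delta, .. }) = chunk {
    ///             print!("{delta}");
    ///         }
    ///     })
    ///     .await;
    /// # Ok(())
    /// # }
    /// ```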
    #[cfg(feature = "streams")]
    pub async fn send_history_streaming(
        &self,
        history: &Vec<ChatMessage>,
    ) -> crate::Result<impl Stream<Item = crate::Result<ResponseChunk>>> {
        let response = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                stream: true,
                messages: history,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?;

        Self::process_streaming_response(response)
    }

    /// Sends a single message to the API without preserving message history.
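    ///
    /// # Example
    ///
    /// A sketch; it assumes [`CompletionResponse`] exposes a `message()` accessor for
    /// the first choice:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let response = client.send_message("Explain lifetimes in one paragraph.").await?;
    /// println!("{}", response.message().content);
    /// # Ok(())
    /// # }
    /// ```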
    pub async fn send_message<S: Into<String>>(
        &self,
        message: S,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: &vec![ChatMessage {
                    role: Role::User,
                    content: message.into(),
                    #[cfg(feature = "functions")]
                    function_call: None,
                }],
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?
            .json()
            .await?;
        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }

    /// Sends a single message to the API and returns the response as a stream, without preserving
    /// message history. **The stream will be empty** if the server returns an error.
    ///
    /// Requires the `streams` crate feature.
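    ///
    /// # Example
    ///
    /// A sketch that prints content deltas as they arrive:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # use chatgpt::types::ResponseChunk;
    /// # use futures_util::StreamExt;
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let stream = client.send_message_streaming("Write a haiku about Rust.").await?;
    /// stream
    ///     .for_each(|chunk| async move {
    ///         if let Ok(ResponseChunk::Content { delta, .. }) = chunk {
    ///             print!("{delta}");
    ///         }
    ///     })
    ///     .await;
    /// # Ok(())
    /// # }
    /// ```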
    #[cfg(feature = "streams")]
    pub async fn send_message_streaming<S: Into<String>>(
        &self,
        message: S,
    ) -> crate::Result<impl Stream<Item = crate::Result<ResponseChunk>>> {
        let response = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: &vec![ChatMessage {
                    role: Role::User,
                    content: message.into(),
                    #[cfg(feature = "functions")]
                    function_call: None,
                }],
                stream: true,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?;

        Self::process_streaming_response(response)
    }

    #[cfg(feature = "streams")]
    fn process_streaming_response(
        response: Response,
    ) -> crate::Result<impl Stream<Item = crate::Result<ResponseChunk>>> {
        use core::str;

        use futures_util::StreamExt;

        // `error_for_status` converts an HTTP error status into an `Err` before any streaming begins
        response
            .error_for_status()
            .map(|response| response.bytes_stream())
            .map(|stream| {
                // Carries over an SSE event that was split across network chunks
                let mut unparsed = String::new();
                stream.map(move |part| {
                    let unwrapped_bytes = match part {
                        Ok(received_bytes) => received_bytes,
                        Err(err) => {
                            return vec![crate::Result::Err(
                                crate::err::Error::ClientError(err),
                            )]
                        }
                    };
                    let parsed_bytes = match str::from_utf8(&unwrapped_bytes) {
                        Ok(parsed_bytes) => parsed_bytes,
                        Err(parse_error) => {
                            return vec![crate::Result::Err(
                                crate::err::Error::ParsingError(format!("{}", parse_error)),
                            )]
                        }
                    };
                    // Prepend any partial event left over from the previous chunk
                    let mut unparsed_for_iteration = unparsed.clone();
                    let mut content_to_iterate = parsed_bytes;
                    if !unparsed.is_empty() {
                        unparsed_for_iteration += content_to_iterate;
                        content_to_iterate = &unparsed_for_iteration;
                        unparsed = String::new();
                    }
                    let mut response_chunks: Vec<ResponseChunk> = vec![];
                    // Each complete SSE event is a `data: <payload>` line terminated by a blank line
                    for chunk in content_to_iterate
                        .split_inclusive("\n\n")
                        .filter_map(|line| line.strip_prefix("data: "))
                    {
                        if chunk.is_empty() {
                            continue;
                        }
                        let parsed_chunk = if let Some(data) = chunk.strip_suffix("\n\n") {
                            if data == "[DONE]" {
                                ResponseChunk::Done
                            } else {
                                let parsed_data: InboundResponseChunk = serde_json::from_str(data)
                                    .unwrap_or_else(|_| {
                                        panic!("Invalid inbound streaming response payload: {}. Raw bytes: {:#?}", data, unwrapped_bytes)
                                    });
                                let choice = parsed_data.choices[0].to_owned();
                                match choice.delta {
                                    InboundChunkPayload::AnnounceRoles { role } => {
                                        ResponseChunk::BeginResponse {
                                            role,
                                            response_index: choice.index,
                                        }
                                    }
                                    InboundChunkPayload::StreamContent { content } => {
                                        ResponseChunk::Content {
                                            delta: content,
                                            response_index: choice.index,
                                        }
                                    }
                                    InboundChunkPayload::Close {} => ResponseChunk::CloseResponse {
                                        response_index: choice.index,
                                    },
                                }
                            }
                        } else {
                            // Incomplete event: re-attach the stripped prefix so the reassembled
                            // event still parses on the next pass
                            unparsed = format!("data: {chunk}");
                            break;
                        };
                        response_chunks.push(parsed_chunk);
                    }

                    response_chunks
                        .into_iter()
                        .map(crate::Result::Ok)
                        .collect::<Vec<crate::Result<ResponseChunk>>>()
                })
                .flat_map(|results| futures_util::stream::iter(results))
            })
            .map_err(crate::err::Error::from)
    }

    /// Sends a message with the specified function descriptors. ChatGPT is then able to call these functions.
    ///
    /// **NOTE**: Functions are processed [as tokens on the backend](https://platform.openai.com/docs/guides/gpt/function-calling),
    /// so you might want to limit the number of functions or the length of their descriptions.
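    ///
    /// # Example
    ///
    /// A sketch, assuming the crate's `gpt_function` attribute macro, which derives a
    /// [`FunctionDescriptor`] from a function's doc comment and signature:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// /// Says hello to a person
    /// ///
    /// /// * person - name of the person to greet
    /// #[gpt_function]
    /// async fn say_hello(person: String) {
    ///     println!("Hello, {person}!");
    /// }
    ///
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let response = client
    ///     .send_message_functions("Please greet John.", vec![say_hello()])
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```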
    #[cfg(feature = "functions")]
    pub async fn send_message_functions<S: Into<String>, A: FunctionArgument>(
        &self,
        message: S,
        functions: Vec<FunctionDescriptor<A>>,
    ) -> crate::Result<CompletionResponse> {
        self.send_message_functions_baked(
            message,
            functions
                .into_iter()
                .map(serde_json::to_value)
                .collect::<serde_json::Result<Vec<serde_json::Value>>>()
                .map_err(crate::err::Error::from)?,
        )
        .await
    }

    /// Sends a message with the specified pre-baked function descriptors. ChatGPT is then able to call these functions.
    ///
    /// **NOTE**: Functions are processed [as tokens on the backend](https://platform.openai.com/docs/guides/gpt/function-calling),
    /// so you might want to limit the number of functions or the length of their descriptions.
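    ///
    /// # Example
    ///
    /// A sketch of a pre-baked descriptor in OpenAI's function-calling JSON format
    /// (see the guide linked above):
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # async fn example(client: ChatGPT) -> chatgpt::Result<()> {
    /// let descriptor = serde_json::json!({
    ///     "name": "get_weather",
    ///     "description": "Returns the current weather for a city",
    ///     "parameters": {
    ///         "type": "object",
    ///         "properties": {
    ///             "city": { "type": "string", "description": "City name" }
    ///         },
    ///         "required": ["city"]
    ///     }
    /// });
    /// let response = client
    ///     .send_message_functions_baked("What's the weather in Paris?", vec![descriptor])
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```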
    #[cfg(feature = "functions")]
    pub async fn send_message_functions_baked<S: Into<String>>(
        &self,
        message: S,
        baked_functions: Vec<serde_json::Value>,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: &vec![ChatMessage {
                    role: Role::User,
                    content: message.into(),
                    #[cfg(feature = "functions")]
                    function_call: None,
                }],
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                max_tokens: self.config.max_tokens,
                #[cfg(feature = "functions")]
                functions: &baked_functions,
            })
            .send()
            .await?
            .json()
            .await?;

        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }

    /// Sends the whole message history along with the defined baked function descriptors.
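    ///
    /// # Example
    ///
    /// A sketch reusing a baked descriptor (see [`Self::send_message_functions_baked()`])
    /// together with an existing history:
    /// ```no_run
    /// # use chatgpt::prelude::*;
    /// # use chatgpt::types::ChatMessage;
    /// # async fn example(
    /// #     client: ChatGPT,
    /// #     history: Vec<ChatMessage>,
    /// #     descriptor: serde_json::Value,
    /// # ) -> chatgpt::Result<()> {
    /// let response = client
    ///     .send_history_functions(&history, &vec![descriptor])
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```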
    #[cfg(feature = "functions")]
    pub async fn send_history_functions(
        &self,
        history: &Vec<ChatMessage>,
        functions: &Vec<serde_json::Value>,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: history,
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                max_tokens: self.config.max_tokens,
                functions,
            })
            .send()
            .await?
            .json()
            .await?;
        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }
}