chatgpt/client.rs

use std::path::Path;

use reqwest::header::AUTHORIZATION;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{self, Proxy};
use tokio::fs::File;
use tokio::io::AsyncReadExt;

#[cfg(feature = "streams")]
use reqwest::Response;
#[cfg(feature = "streams")]
use {
    crate::types::InboundChunkPayload, crate::types::InboundResponseChunk,
    crate::types::ResponseChunk, futures_util::Stream,
};

use crate::config::ModelConfiguration;
use crate::converse::Conversation;
use crate::types::{ChatMessage, CompletionRequest, CompletionResponse, Role, ServerResponse};

#[cfg(feature = "functions")]
use crate::functions::{FunctionArgument, FunctionDescriptor};

/// The client that operates the ChatGPT API
#[derive(Debug, Clone)]
pub struct ChatGPT {
    client: reqwest::Client,
    /// The configuration for this ChatGPT client
    pub config: ModelConfiguration,
}

impl ChatGPT {
    /// Constructs a new ChatGPT API client with the provided API key and default configuration
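    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate and that
    /// the `OPENAI_API_KEY` environment variable holds a valid key):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = ChatGPT::new(std::env::var("OPENAI_API_KEY")?)?;
    /// # Ok(())
    /// # }
    /// ```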
    pub fn new<S: Into<String>>(api_key: S) -> crate::Result<Self> {
        Self::new_with_config(api_key, ModelConfiguration::default())
    }

    /// Constructs a new ChatGPT API client with the provided API key, default configuration, and a reqwest proxy
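    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate and that an
    /// HTTP proxy is reachable at the address shown):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    /// use reqwest::Proxy;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let proxy = Proxy::all("http://127.0.0.1:8080")?;
    /// let client = ChatGPT::new_with_proxy(std::env::var("OPENAI_API_KEY")?, proxy)?;
    /// # Ok(())
    /// # }
    /// ```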
    pub fn new_with_proxy<S: Into<String>>(api_key: S, proxy: Proxy) -> crate::Result<Self> {
        Self::new_with_config_proxy(api_key, ModelConfiguration::default(), proxy)
    }

    /// Constructs a new ChatGPT API client with the provided API key and configuration
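    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate and that
    /// [`ModelConfiguration`]'s fields are public, so struct-update syntax applies):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let config = ModelConfiguration {
    ///     temperature: 0.2,
    ///     ..Default::default()
    /// };
    /// let client = ChatGPT::new_with_config(std::env::var("OPENAI_API_KEY")?, config)?;
    /// # Ok(())
    /// # }
    /// ```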
    pub fn new_with_config<S: Into<String>>(
        api_key: S,
        config: ModelConfiguration,
    ) -> crate::Result<Self> {
        let api_key = api_key.into();
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
        );
        let client = reqwest::ClientBuilder::new()
            .default_headers(headers)
            .timeout(config.timeout)
            .build()?;
        Ok(Self { client, config })
    }

    /// Constructs a new ChatGPT API client with the provided API key, configuration, and reqwest proxy
    pub fn new_with_config_proxy<S: Into<String>>(
        api_key: S,
        config: ModelConfiguration,
        proxy: Proxy,
    ) -> crate::Result<Self> {
        let api_key = api_key.into();
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
        );

        let client = reqwest::ClientBuilder::new()
            .default_headers(headers)
            .timeout(config.timeout)
            .proxy(proxy)
            .build()?;
        Ok(Self { client, config })
    }

    /// Restores a conversation from a local JSON conversation file.
    /// The conversation file can be saved beforehand using [`Conversation::save_history_json()`].
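    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate with the
    /// `json` feature enabled, and that `history.json` was written by a previous run):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    ///
    /// # async fn example(client: ChatGPT) -> Result<(), Box<dyn std::error::Error>> {
    /// let mut conversation = client.restore_conversation_json("history.json").await?;
    /// let response = conversation.send_message("What did we talk about earlier?").await?;
    /// println!("{}", response.message().content);
    /// # Ok(())
    /// # }
    /// ```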
    #[cfg(feature = "json")]
    pub async fn restore_conversation_json<P: AsRef<Path>>(
        &self,
        file: P,
    ) -> crate::Result<Conversation> {
        let path = file.as_ref();
        if !path.exists() {
            return Err(crate::err::Error::ParsingError(
                "Conversation history JSON file does not exist".to_string(),
            ));
        }
        let mut file = File::open(path).await?;
        let mut buf = String::new();
        file.read_to_string(&mut buf).await?;
        Ok(Conversation::new_with_history(
            self.clone(),
            serde_json::from_str(&buf)?,
        ))
    }

    /// Restores a conversation from a local Postcard conversation file.
    /// The conversation file can be saved beforehand using [`Conversation::save_history_postcard()`].
    #[cfg(feature = "postcard")]
    pub async fn restore_conversation_postcard<P: AsRef<Path>>(
        &self,
        file: P,
    ) -> crate::Result<Conversation> {
        let path = file.as_ref();
        if !path.exists() {
            return Err(crate::err::Error::ParsingError(
                "Conversation history Postcard file does not exist".to_string(),
            ));
        }
        let mut file = File::open(path).await?;
        let mut buf = Vec::new();
        file.read_to_end(&mut buf).await?;
        Ok(Conversation::new_with_history(
            self.clone(),
            postcard::from_bytes(&buf)?,
        ))
    }

    /// Starts a new conversation with a default starting message.
    ///
    /// Conversations record message history.
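    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    ///
    /// # async fn example(client: ChatGPT) -> Result<(), Box<dyn std::error::Error>> {
    /// let mut conversation = client.new_conversation();
    /// let response = conversation.send_message("What is Rust?").await?;
    /// println!("{}", response.message().content);
    /// # Ok(())
    /// # }
    /// ```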
    pub fn new_conversation(&self) -> Conversation {
        self.new_conversation_directed(
            "You are ChatGPT, an AI model developed by OpenAI. Answer as concisely as possible."
                .to_string(),
        )
    }

    /// Starts a new conversation with a specified starting message.
    ///
    /// Conversations record message history.
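    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    ///
    /// # async fn example(client: ChatGPT) -> Result<(), Box<dyn std::error::Error>> {
    /// let mut conversation = client
    ///     .new_conversation_directed("You are a terse assistant. Answer in one sentence.");
    /// let response = conversation.send_message("Explain ownership in Rust.").await?;
    /// println!("{}", response.message().content);
    /// # Ok(())
    /// # }
    /// ```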
    pub fn new_conversation_directed<S: Into<String>>(&self, direction_message: S) -> Conversation {
        Conversation::new(self.clone(), direction_message.into())
    }

    /// Explicitly sends the whole message history to the API.
    ///
    /// In most cases, if you would like to store message history, you should use the [`Conversation`] struct
    /// together with [`Self::new_conversation()`] and [`Self::new_conversation_directed()`].
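    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate and is built
    /// without the `functions` feature, so [`ChatMessage`] has no `function_call` field):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    /// use chatgpt::types::{ChatMessage, Role};
    ///
    /// # async fn example(client: ChatGPT) -> Result<(), Box<dyn std::error::Error>> {
    /// let history = vec![
    ///     ChatMessage {
    ///         role: Role::System,
    ///         content: "You are a helpful assistant.".to_string(),
    ///     },
    ///     ChatMessage {
    ///         role: Role::User,
    ///         content: "Summarize the Rust borrow checker in one sentence.".to_string(),
    ///     },
    /// ];
    /// let response = client.send_history(&history).await?;
    /// println!("{}", response.message().content);
    /// # Ok(())
    /// # }
    /// ```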
    pub async fn send_history(
        &self,
        history: &Vec<ChatMessage>,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: history,
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?
            .json()
            .await?;
        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }

    /// Explicitly sends the whole message history to the API and returns the response as a stream.
    /// **The stream will be empty** if any errors are returned from the server.
    ///
    /// In most cases, if you would like to store message history, you should use the [`Conversation`] struct
    /// together with [`Self::new_conversation()`] and [`Self::new_conversation_directed()`].
    ///
    /// Requires the `streams` crate feature.
    #[cfg(feature = "streams")]
    pub async fn send_history_streaming(
        &self,
        history: &Vec<ChatMessage>,
    ) -> crate::Result<impl Stream<Item = ResponseChunk>> {
        let response = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                stream: true,
                messages: history,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?;

        Self::process_streaming_response(response)
    }

    /// Sends a single message to the API without preserving message history.
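    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    ///
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = ChatGPT::new(std::env::var("OPENAI_API_KEY")?)?;
    /// let response = client.send_message("Write a haiku about borrow checking.").await?;
    /// println!("{}", response.message().content);
    /// # Ok(())
    /// # }
    /// ```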
    pub async fn send_message<S: Into<String>>(
        &self,
        message: S,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: &vec![ChatMessage {
                    role: Role::User,
                    content: message.into(),
                    #[cfg(feature = "functions")]
                    function_call: None,
                }],
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?
            .json()
            .await?;
        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }

    /// Sends a single message to the API and returns the response as a stream, without preserving message history.
    /// **The stream will be empty** if any errors are returned from the server.
    ///
    /// Requires the `streams` crate feature.
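    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate with the
    /// `streams` feature enabled, and that `futures_util::StreamExt` is available):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    /// use chatgpt::types::ResponseChunk;
    /// use futures_util::StreamExt;
    ///
    /// # async fn example(client: ChatGPT) -> Result<(), Box<dyn std::error::Error>> {
    /// let stream = client.send_message_streaming("Tell me a short story.").await?;
    /// stream
    ///     .for_each(|chunk| async move {
    ///         // Print only the content deltas; ignore role announcements and close markers
    ///         if let ResponseChunk::Content { delta, .. } = chunk {
    ///             print!("{delta}");
    ///         }
    ///     })
    ///     .await;
    /// # Ok(())
    /// # }
    /// ```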
    #[cfg(feature = "streams")]
    pub async fn send_message_streaming<S: Into<String>>(
        &self,
        message: S,
    ) -> crate::Result<impl Stream<Item = ResponseChunk>> {
        let response = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: &vec![ChatMessage {
                    role: Role::User,
                    content: message.into(),
                    #[cfg(feature = "functions")]
                    function_call: None,
                }],
                stream: true,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                max_tokens: self.config.max_tokens,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                #[cfg(feature = "functions")]
                functions: &Vec::new(),
            })
            .send()
            .await?;

        Self::process_streaming_response(response)
    }

    #[cfg(feature = "streams")]
    fn process_streaming_response(
        response: Response,
    ) -> crate::Result<impl Stream<Item = ResponseChunk>> {
        use eventsource_stream::Eventsource;
        use futures_util::StreamExt;

        // Turn non-success HTTP statuses into an error before parsing the SSE event stream
        response
            .error_for_status()
            .map(|response| {
                let response_stream = response.bytes_stream().eventsource();
                response_stream.map(move |part| {
                    let chunk = &part.expect("Stream closed abruptly!").data;
                    if chunk == "[DONE]" {
                        return ResponseChunk::Done;
                    }
                    let data: InboundResponseChunk = serde_json::from_str(chunk)
                        .expect("Invalid inbound streaming response payload!");
                    let choice = data.choices[0].to_owned();
                    match choice.delta {
                        InboundChunkPayload::AnnounceRoles { role } => {
                            ResponseChunk::BeginResponse {
                                role,
                                response_index: choice.index,
                            }
                        }
                        InboundChunkPayload::StreamContent { content } => ResponseChunk::Content {
                            delta: content,
                            response_index: choice.index,
                        },
                        InboundChunkPayload::Close {} => ResponseChunk::CloseResponse {
                            response_index: choice.index,
                        },
                    }
                })
            })
            .map_err(crate::err::Error::from)
    }

    /// Sends a message with specified function descriptors. ChatGPT is then able to call these functions.
    ///
    /// **NOTE**: Functions are processed [as tokens on the backend](https://platform.openai.com/docs/guides/gpt/function-calling),
    /// so you might want to limit the number of functions and the length of their descriptions.
    #[cfg(feature = "functions")]
    pub async fn send_message_functions<S: Into<String>, A: FunctionArgument>(
        &self,
        message: S,
        functions: Vec<FunctionDescriptor<A>>,
    ) -> crate::Result<CompletionResponse> {
        self.send_message_functions_baked(
            message,
            functions
                .into_iter()
                .map(serde_json::to_value)
                .collect::<serde_json::Result<Vec<serde_json::Value>>>()
                .map_err(crate::err::Error::from)?,
        )
        .await
    }

    /// Sends a message with specified pre-baked function descriptors. ChatGPT is then able to call these functions.
    ///
    /// **NOTE**: Functions are processed [as tokens on the backend](https://platform.openai.com/docs/guides/gpt/function-calling),
    /// so you might want to limit the number of functions and the length of their descriptions.
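    ///
    /// # Example
    ///
    /// A minimal sketch (assumes this library is consumed as the `chatgpt` crate with the
    /// `functions` feature enabled; the descriptor below follows OpenAI's function-calling
    /// JSON schema):
    ///
    /// ```no_run
    /// use chatgpt::prelude::*;
    /// use serde_json::json;
    ///
    /// # async fn example(client: ChatGPT) -> Result<(), Box<dyn std::error::Error>> {
    /// let weather_function = json!({
    ///     "name": "get_weather",
    ///     "description": "Returns the current weather for a city",
    ///     "parameters": {
    ///         "type": "object",
    ///         "properties": {
    ///             "city": { "type": "string", "description": "City name" }
    ///         },
    ///         "required": ["city"]
    ///     }
    /// });
    /// let response = client
    ///     .send_message_functions_baked("What is the weather in Berlin?", vec![weather_function])
    ///     .await?;
    /// println!("{:#?}", response.message());
    /// # Ok(())
    /// # }
    /// ```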
    #[cfg(feature = "functions")]
    pub async fn send_message_functions_baked<S: Into<String>>(
        &self,
        message: S,
        baked_functions: Vec<serde_json::Value>,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: &vec![ChatMessage {
                    role: Role::User,
                    content: message.into(),
                    #[cfg(feature = "functions")]
                    function_call: None,
                }],
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                max_tokens: self.config.max_tokens,
                #[cfg(feature = "functions")]
                functions: &baked_functions,
            })
            .send()
            .await?
            .json()
            .await?;

        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }

    /// Sends the whole message history along with the provided pre-baked functions.
    #[cfg(feature = "functions")]
    pub async fn send_history_functions(
        &self,
        history: &Vec<ChatMessage>,
        functions: &Vec<serde_json::Value>,
    ) -> crate::Result<CompletionResponse> {
        let response: ServerResponse = self
            .client
            .post(self.config.api_url.clone())
            .json(&CompletionRequest {
                model: self.config.engine.as_ref(),
                messages: history,
                stream: false,
                temperature: self.config.temperature,
                top_p: self.config.top_p,
                frequency_penalty: self.config.frequency_penalty,
                presence_penalty: self.config.presence_penalty,
                reply_count: self.config.reply_count,
                max_tokens: self.config.max_tokens,
                functions,
            })
            .send()
            .await?
            .json()
            .await?;
        match response {
            ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
                message: error.message,
                error_type: error.error_type,
            }),
            ServerResponse::Completion(completion) => Ok(completion),
        }
    }
}