1use std::path::Path;
2
3use reqwest::header::AUTHORIZATION;
4use reqwest::header::{HeaderMap, HeaderValue};
5use reqwest::{self, Proxy};
6use tokio::fs::File;
7use tokio::io::AsyncReadExt;
8
9#[cfg(feature = "streams")]
10use reqwest::Response;
11#[cfg(feature = "streams")]
12use {
13 crate::types::InboundChunkPayload, crate::types::InboundResponseChunk,
14 crate::types::ResponseChunk, futures_util::Stream,
15};
16
17use crate::config::ModelConfiguration;
18use crate::converse::Conversation;
19use crate::types::{ChatMessage, CompletionRequest, CompletionResponse, Role, ServerResponse};
20
21#[cfg(feature = "functions")]
22use crate::functions::{FunctionArgument, FunctionDescriptor};
23
24#[derive(Debug, Clone)]
26pub struct ChatGPT {
27 client: reqwest::Client,
28 pub config: ModelConfiguration,
30}
31
32impl ChatGPT {
33 pub fn new<S: Into<String>>(api_key: S) -> crate::Result<Self> {
35 Self::new_with_config(api_key, ModelConfiguration::default())
36 }
37
38 pub fn new_with_proxy<S: Into<String>>(api_key: S, proxy: Proxy) -> crate::Result<Self> {
40 Self::new_with_config_proxy(api_key, ModelConfiguration::default(), proxy)
41 }
42
43 pub fn new_with_config<S: Into<String>>(
45 api_key: S,
46 config: ModelConfiguration,
47 ) -> crate::Result<Self> {
48 let api_key = api_key.into();
49 let mut headers = HeaderMap::new();
50 headers.insert(
51 AUTHORIZATION,
52 HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
53 );
54 let client = reqwest::ClientBuilder::new()
55 .default_headers(headers)
56 .timeout(config.timeout)
57 .build()?;
58 Ok(Self { client, config })
59 }
60
61 pub fn new_with_config_proxy<S: Into<String>>(
63 api_key: S,
64 config: ModelConfiguration,
65 proxy: Proxy,
66 ) -> crate::Result<Self> {
67 let api_key = api_key.into();
68 let mut headers = HeaderMap::new();
69 headers.insert(
70 AUTHORIZATION,
71 HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
72 );
73
74 let client = reqwest::ClientBuilder::new()
75 .default_headers(headers)
76 .timeout(config.timeout)
77 .proxy(proxy)
78 .build()?;
79 Ok(Self { client, config })
80 }
81 #[cfg(feature = "json")]
84 pub async fn restore_conversation_json<P: AsRef<Path>>(
85 &self,
86 file: P,
87 ) -> crate::Result<Conversation> {
88 let path = file.as_ref();
89 if !path.exists() {
90 return Err(crate::err::Error::ParsingError(
91 "Conversation history JSON file does not exist".to_string(),
92 ));
93 }
94 let mut file = File::open(path).await?;
95 let mut buf = String::new();
96 file.read_to_string(&mut buf).await?;
97 Ok(Conversation::new_with_history(
98 self.clone(),
99 serde_json::from_str(&buf)?,
100 ))
101 }
102
103 #[cfg(feature = "postcard")]
106 pub async fn restore_conversation_postcard<P: AsRef<Path>>(
107 &self,
108 file: P,
109 ) -> crate::Result<Conversation> {
110 let path = file.as_ref();
111 if !path.exists() {
112 return Err(crate::err::Error::ParsingError(
113 "Conversation history Postcard file does not exist".to_string(),
114 ));
115 }
116 let mut file = File::open(path).await?;
117 let mut buf = Vec::new();
118 file.read_to_end(&mut buf).await?;
119 Ok(Conversation::new_with_history(
120 self.clone(),
121 postcard::from_bytes(&buf)?,
122 ))
123 }
124
125 pub fn new_conversation(&self) -> Conversation {
129 self.new_conversation_directed(
130 "You are ChatGPT, an AI model developed by OpenAI. Answer as concisely as possible."
131 .to_string(),
132 )
133 }
134
135 pub fn new_conversation_directed<S: Into<String>>(&self, direction_message: S) -> Conversation {
139 Conversation::new(self.clone(), direction_message.into())
140 }
141
142 pub async fn send_history(
147 &self,
148 history: &Vec<ChatMessage>,
149 ) -> crate::Result<CompletionResponse> {
150 let response: ServerResponse = self
151 .client
152 .post(self.config.api_url.clone())
153 .json(&CompletionRequest {
154 model: self.config.engine.as_ref(),
155 messages: history,
156 stream: false,
157 temperature: self.config.temperature,
158 top_p: self.config.top_p,
159 max_tokens: self.config.max_tokens,
160 frequency_penalty: self.config.frequency_penalty,
161 presence_penalty: self.config.presence_penalty,
162 reply_count: self.config.reply_count,
163 #[cfg(feature = "functions")]
164 functions: &Vec::new(),
165 })
166 .send()
167 .await?
168 .json()
169 .await?;
170 match response {
171 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
172 message: error.message,
173 error_type: error.error_type,
174 }),
175 ServerResponse::Completion(completion) => Ok(completion),
176 }
177 }
178
179 #[cfg(feature = "streams")]
187 pub async fn send_history_streaming(
188 &self,
189 history: &Vec<ChatMessage>,
190 ) -> crate::Result<impl Stream<Item = ResponseChunk>> {
191 let response = self
192 .client
193 .post(self.config.api_url.clone())
194 .json(&CompletionRequest {
195 model: self.config.engine.as_ref(),
196 stream: true,
197 messages: history,
198 temperature: self.config.temperature,
199 top_p: self.config.top_p,
200 max_tokens: self.config.max_tokens,
201 frequency_penalty: self.config.frequency_penalty,
202 presence_penalty: self.config.presence_penalty,
203 reply_count: self.config.reply_count,
204 #[cfg(feature = "functions")]
205 functions: &Vec::new(),
206 })
207 .send()
208 .await?;
209
210 Self::process_streaming_response(response)
211 }
212
213 pub async fn send_message<S: Into<String>>(
215 &self,
216 message: S,
217 ) -> crate::Result<CompletionResponse> {
218 let response: ServerResponse = self
219 .client
220 .post(self.config.api_url.clone())
221 .json(&CompletionRequest {
222 model: self.config.engine.as_ref(),
223 messages: &vec![ChatMessage {
224 role: Role::User,
225 content: message.into(),
226 #[cfg(feature = "functions")]
227 function_call: None,
228 }],
229 stream: false,
230 temperature: self.config.temperature,
231 top_p: self.config.top_p,
232 max_tokens: self.config.max_tokens,
233 frequency_penalty: self.config.frequency_penalty,
234 presence_penalty: self.config.presence_penalty,
235 reply_count: self.config.reply_count,
236 #[cfg(feature = "functions")]
237 functions: &Vec::new(),
238 })
239 .send()
240 .await?
241 .json()
242 .await?;
243 match response {
244 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
245 message: error.message,
246 error_type: error.error_type,
247 }),
248 ServerResponse::Completion(completion) => Ok(completion),
249 }
250 }
251
252 #[cfg(feature = "streams")]
257 pub async fn send_message_streaming<S: Into<String>>(
258 &self,
259 message: S,
260 ) -> crate::Result<impl Stream<Item = ResponseChunk>> {
261 let response = self
262 .client
263 .post(self.config.api_url.clone())
264 .json(&CompletionRequest {
265 model: self.config.engine.as_ref(),
266 messages: &vec![ChatMessage {
267 role: Role::User,
268 content: message.into(),
269 #[cfg(feature = "functions")]
270 function_call: None,
271 }],
272 stream: true,
273 temperature: self.config.temperature,
274 top_p: self.config.top_p,
275 max_tokens: self.config.max_tokens,
276 frequency_penalty: self.config.frequency_penalty,
277 presence_penalty: self.config.presence_penalty,
278 reply_count: self.config.reply_count,
279 #[cfg(feature = "functions")]
280 functions: &Vec::new(),
281 })
282 .send()
283 .await?;
284
285 Self::process_streaming_response(response)
286 }
287
288 #[cfg(feature = "streams")]
289 fn process_streaming_response(
290 response: Response,
291 ) -> crate::Result<impl Stream<Item = ResponseChunk>> {
292 use eventsource_stream::Eventsource;
293 use futures_util::StreamExt;
294
295 response
297 .error_for_status()
298 .map(|response| {
299 let response_stream = response.bytes_stream().eventsource();
300 response_stream.map(move |part| {
301 let chunk = &part.expect("Stream closed abruptly!").data;
302 if chunk == "[DONE]" {
303 return ResponseChunk::Done;
304 }
305 let data: InboundResponseChunk = serde_json::from_str(chunk)
306 .expect("Invalid inbound streaming response payload!");
307 let choice = data.choices[0].to_owned();
308 match choice.delta {
309 InboundChunkPayload::AnnounceRoles { role } => {
310 ResponseChunk::BeginResponse {
311 role,
312 response_index: choice.index,
313 }
314 }
315 InboundChunkPayload::StreamContent { content } => ResponseChunk::Content {
316 delta: content,
317 response_index: choice.index,
318 },
319 InboundChunkPayload::Close {} => ResponseChunk::CloseResponse {
320 response_index: choice.index,
321 },
322 }
323 })
324 })
325 .map_err(crate::err::Error::from)
326 }
327
328 #[cfg(feature = "functions")]
333 pub async fn send_message_functions<S: Into<String>, A: FunctionArgument>(
334 &self,
335 message: S,
336 functions: Vec<FunctionDescriptor<A>>,
337 ) -> crate::Result<CompletionResponse> {
338 self.send_message_functions_baked(
339 message,
340 functions
341 .into_iter()
342 .map(serde_json::to_value)
343 .collect::<serde_json::Result<Vec<serde_json::Value>>>()
344 .map_err(crate::err::Error::from)?,
345 )
346 .await
347 }
348
349 #[cfg(feature = "functions")]
354 pub async fn send_message_functions_baked<S: Into<String>>(
355 &self,
356 message: S,
357 baked_functions: Vec<serde_json::Value>,
358 ) -> crate::Result<CompletionResponse> {
359 let response: ServerResponse = self
360 .client
361 .post(self.config.api_url.clone())
362 .json(&CompletionRequest {
363 model: self.config.engine.as_ref(),
364 messages: &vec![ChatMessage {
365 role: Role::User,
366 content: message.into(),
367 #[cfg(feature = "functions")]
368 function_call: None,
369 }],
370 stream: false,
371 temperature: self.config.temperature,
372 top_p: self.config.top_p,
373 frequency_penalty: self.config.frequency_penalty,
374 presence_penalty: self.config.presence_penalty,
375 reply_count: self.config.reply_count,
376 max_tokens: self.config.max_tokens,
377 #[cfg(feature = "functions")]
378 functions: &baked_functions,
379 })
380 .send()
381 .await?
382 .json()
383 .await?;
384
385 match response {
386 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
387 message: error.message,
388 error_type: error.error_type,
389 }),
390 ServerResponse::Completion(completion) => Ok(completion),
391 }
392 }
393
394 #[cfg(feature = "functions")]
396 pub async fn send_history_functions(
397 &self,
398 history: &Vec<ChatMessage>,
399 functions: &Vec<serde_json::Value>,
400 ) -> crate::Result<CompletionResponse> {
401 let response: ServerResponse = self
402 .client
403 .post(self.config.api_url.clone())
404 .json(&CompletionRequest {
405 model: self.config.engine.as_ref(),
406 messages: history,
407 stream: false,
408 temperature: self.config.temperature,
409 top_p: self.config.top_p,
410 frequency_penalty: self.config.frequency_penalty,
411 presence_penalty: self.config.presence_penalty,
412 reply_count: self.config.reply_count,
413 max_tokens: self.config.max_tokens,
414 functions,
415 })
416 .send()
417 .await?
418 .json()
419 .await?;
420 match response {
421 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
422 message: error.message,
423 error_type: error.error_type,
424 }),
425 ServerResponse::Completion(completion) => Ok(completion),
426 }
427 }
428}