1use std::path::Path;
2
3use reqwest::header::AUTHORIZATION;
4use reqwest::header::{HeaderMap, HeaderValue};
5use reqwest::{self, Proxy};
6use tokio::fs::File;
7use tokio::io::AsyncReadExt;
8
9#[cfg(feature = "streams")]
10use reqwest::Response;
11#[cfg(feature = "streams")]
12use {
13 crate::types::InboundChunkPayload, crate::types::InboundResponseChunk,
14 crate::types::ResponseChunk, futures_util::Stream,
15};
16
17use crate::config::ModelConfiguration;
18use crate::converse::Conversation;
19use crate::types::{ChatMessage, CompletionRequest, CompletionResponse, Role, ServerResponse};
20
21#[cfg(feature = "functions")]
22use crate::functions::{FunctionArgument, FunctionDescriptor};
23
/// The client that operates the ChatGPT API and is cloned into every
/// [`Conversation`] it spawns.
#[derive(Debug, Clone)]
pub struct ChatGPT {
    // Underlying HTTP client; carries the default `Authorization` header
    // and the request timeout configured at construction time.
    client: reqwest::Client,
    /// Model configuration (engine, sampling parameters, API URL, timeout)
    /// applied to every request sent by this client.
    pub config: ModelConfiguration,
}
31
32impl ChatGPT {
    /// Constructs a new ChatGPT API client with the given API key and the
    /// default [`ModelConfiguration`].
    ///
    /// # Errors
    /// Fails if the API key cannot be encoded as an HTTP header value or if
    /// the underlying HTTP client cannot be built.
    pub fn new<S: Into<String>>(api_key: S) -> crate::Result<Self> {
        Self::new_with_config(api_key, ModelConfiguration::default())
    }
37
    /// Constructs a new ChatGPT API client with the given API key, the
    /// default [`ModelConfiguration`], and an HTTP proxy for all requests.
    ///
    /// # Errors
    /// Fails if the API key cannot be encoded as an HTTP header value or if
    /// the underlying HTTP client cannot be built.
    pub fn new_with_proxy<S: Into<String>>(api_key: S, proxy: Proxy) -> crate::Result<Self> {
        Self::new_with_config_proxy(api_key, ModelConfiguration::default(), proxy)
    }
42
43 pub fn new_with_config<S: Into<String>>(
45 api_key: S,
46 config: ModelConfiguration,
47 ) -> crate::Result<Self> {
48 let api_key = api_key.into();
49 let mut headers = HeaderMap::new();
50 headers.insert(
51 AUTHORIZATION,
52 HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
53 );
54 let client = reqwest::ClientBuilder::new()
55 .default_headers(headers)
56 .timeout(config.timeout)
57 .build()?;
58 Ok(Self { client, config })
59 }
60
61 pub fn new_with_config_proxy<S: Into<String>>(
63 api_key: S,
64 config: ModelConfiguration,
65 proxy: Proxy,
66 ) -> crate::Result<Self> {
67 let api_key = api_key.into();
68 let mut headers = HeaderMap::new();
69 headers.insert(
70 AUTHORIZATION,
71 HeaderValue::from_bytes(format!("Bearer {api_key}").as_bytes())?,
72 );
73
74 let client = reqwest::ClientBuilder::new()
75 .default_headers(headers)
76 .timeout(config.timeout)
77 .proxy(proxy)
78 .build()?;
79 Ok(Self { client, config })
80 }
81 #[cfg(feature = "json")]
84 pub async fn restore_conversation_json<P: AsRef<Path>>(
85 &self,
86 file: P,
87 ) -> crate::Result<Conversation> {
88 let path = file.as_ref();
89 if !path.exists() {
90 return Err(crate::err::Error::ParsingError(
91 "Conversation history JSON file does not exist".to_string(),
92 ));
93 }
94 let mut file = File::open(path).await?;
95 let mut buf = String::new();
96 file.read_to_string(&mut buf).await?;
97 Ok(Conversation::new_with_history(
98 self.clone(),
99 serde_json::from_str(&buf)?,
100 ))
101 }
102
103 #[cfg(feature = "postcard")]
106 pub async fn restore_conversation_postcard<P: AsRef<Path>>(
107 &self,
108 file: P,
109 ) -> crate::Result<Conversation> {
110 let path = file.as_ref();
111 if !path.exists() {
112 return Err(crate::err::Error::ParsingError(
113 "Conversation history Postcard file does not exist".to_string(),
114 ));
115 }
116 let mut file = File::open(path).await?;
117 let mut buf = Vec::new();
118 file.read_to_end(&mut buf).await?;
119 Ok(Conversation::new_with_history(
120 self.clone(),
121 postcard::from_bytes(&buf)?,
122 ))
123 }
124
    /// Starts a new conversation seeded with the default system prompt
    /// ("Answer as concisely as possible").
    pub fn new_conversation(&self) -> Conversation {
        self.new_conversation_directed(
            "You are ChatGPT, an AI model developed by OpenAI. Answer as concisely as possible."
                .to_string(),
        )
    }
134
    /// Starts a new conversation whose system prompt is the given
    /// direction message.
    pub fn new_conversation_directed<S: Into<String>>(&self, direction_message: S) -> Conversation {
        Conversation::new(self.clone(), direction_message.into())
    }
141
142 pub async fn send_history(
147 &self,
148 history: &Vec<ChatMessage>,
149 ) -> crate::Result<CompletionResponse> {
150 let response: ServerResponse = self
151 .client
152 .post(self.config.api_url.clone())
153 .json(&CompletionRequest {
154 model: self.config.engine.as_ref(),
155 messages: history,
156 stream: false,
157 temperature: self.config.temperature,
158 top_p: self.config.top_p,
159 max_tokens: self.config.max_tokens,
160 frequency_penalty: self.config.frequency_penalty,
161 presence_penalty: self.config.presence_penalty,
162 reply_count: self.config.reply_count,
163 #[cfg(feature = "functions")]
164 functions: &Vec::new(),
165 })
166 .send()
167 .await?
168 .json()
169 .await?;
170 match response {
171 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
172 message: error.message,
173 error_type: error.error_type,
174 }),
175 ServerResponse::Completion(completion) => Ok(completion),
176 }
177 }
178
179 #[cfg(feature = "streams")]
187 pub async fn send_history_streaming(
188 &self,
189 history: &Vec<ChatMessage>,
190 ) -> crate::Result<impl Stream<Item = crate::Result<ResponseChunk>>> {
191 let response = self
192 .client
193 .post(self.config.api_url.clone())
194 .json(&CompletionRequest {
195 model: self.config.engine.as_ref(),
196 stream: true,
197 messages: history,
198 temperature: self.config.temperature,
199 top_p: self.config.top_p,
200 max_tokens: self.config.max_tokens,
201 frequency_penalty: self.config.frequency_penalty,
202 presence_penalty: self.config.presence_penalty,
203 reply_count: self.config.reply_count,
204 #[cfg(feature = "functions")]
205 functions: &Vec::new(),
206 })
207 .send()
208 .await?;
209
210 Self::process_streaming_response(response)
211 }
212
213 pub async fn send_message<S: Into<String>>(
215 &self,
216 message: S,
217 ) -> crate::Result<CompletionResponse> {
218 let response: ServerResponse = self
219 .client
220 .post(self.config.api_url.clone())
221 .json(&CompletionRequest {
222 model: self.config.engine.as_ref(),
223 messages: &vec![ChatMessage {
224 role: Role::User,
225 content: message.into(),
226 #[cfg(feature = "functions")]
227 function_call: None,
228 }],
229 stream: false,
230 temperature: self.config.temperature,
231 top_p: self.config.top_p,
232 max_tokens: self.config.max_tokens,
233 frequency_penalty: self.config.frequency_penalty,
234 presence_penalty: self.config.presence_penalty,
235 reply_count: self.config.reply_count,
236 #[cfg(feature = "functions")]
237 functions: &Vec::new(),
238 })
239 .send()
240 .await?
241 .json()
242 .await?;
243 match response {
244 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
245 message: error.message,
246 error_type: error.error_type,
247 }),
248 ServerResponse::Completion(completion) => Ok(completion),
249 }
250 }
251
252 #[cfg(feature = "streams")]
257 pub async fn send_message_streaming<S: Into<String>>(
258 &self,
259 message: S,
260 ) -> crate::Result<impl Stream<Item = crate::Result<ResponseChunk>>> {
261 let response = self
262 .client
263 .post(self.config.api_url.clone())
264 .json(&CompletionRequest {
265 model: self.config.engine.as_ref(),
266 messages: &vec![ChatMessage {
267 role: Role::User,
268 content: message.into(),
269 #[cfg(feature = "functions")]
270 function_call: None,
271 }],
272 stream: true,
273 temperature: self.config.temperature,
274 top_p: self.config.top_p,
275 max_tokens: self.config.max_tokens,
276 frequency_penalty: self.config.frequency_penalty,
277 presence_penalty: self.config.presence_penalty,
278 reply_count: self.config.reply_count,
279 #[cfg(feature = "functions")]
280 functions: &Vec::new(),
281 })
282 .send()
283 .await?;
284
285 Self::process_streaming_response(response)
286 }
287
    /// Converts an HTTP response carrying server-sent events into a stream of
    /// [`ResponseChunk`]s.
    ///
    /// Each SSE event arrives as a `data: <json>\n\n` record, but a network
    /// read may split a record across two byte batches; the `unparsed` buffer
    /// carries the incomplete tail of one batch into the next so no event is
    /// lost. NOTE(review): a malformed event payload panics (see below) rather
    /// than yielding an `Err` item — presumably deliberate, but worth confirming
    /// for a library API.
    #[cfg(feature = "streams")]
    fn process_streaming_response(
        response: Response,
    ) -> crate::Result<impl Stream<Item = crate::Result<ResponseChunk>>> {
        use core::str;

        use futures_util::StreamExt;

        response
            // Turn a non-2xx status into an error before streaming begins.
            .error_for_status()
            .map(|response| response.bytes_stream())
            .map(|stream| {
                // Leftover text from the previous batch that did not end on a
                // complete `\n\n`-terminated event; captured by `move` below.
                let mut unparsed = "".to_string();
                stream.map(move |part| {
                    // Each stream item maps to a Vec of results so transport
                    // errors can be surfaced as a single-element batch.
                    let unwrapped_bytes = match part {
                        Ok(received_bytes) => received_bytes,
                        Err(err) => {
                            return vec![crate::Result::Err(
                                crate::err::Error::ClientError(err),
                            )]
                        }
                    };
                    // SSE payloads must be valid UTF-8 before line parsing.
                    let parsed_bytes = match str::from_utf8(&unwrapped_bytes) {
                        Ok(parsed_bytes) => parsed_bytes,
                        Err(parse_error) => {
                            return vec![crate::Result::Err(
                                crate::err::Error::ParsingError(format!("{}", parse_error)),
                            )]
                        }
                    };
                    // Prepend any carried-over tail from the previous batch,
                    // then reset the carry buffer for this iteration.
                    let mut unparsed_for_iteration = unparsed.clone();
                    let mut content_to_iterate = parsed_bytes;
                    if !unparsed.is_empty() {
                        unparsed_for_iteration += content_to_iterate;
                        content_to_iterate = &unparsed_for_iteration;
                        unparsed = "".to_string();
                    }
                    let mut response_chunks: Vec<ResponseChunk> = vec![];
                    // `split_inclusive` keeps the `\n\n` terminator on each
                    // piece, so a piece *without* it is an incomplete event.
                    for chunk in content_to_iterate.split_inclusive("\n\n").filter_map(|line| line.strip_prefix("data: ")) {
                        if chunk.is_empty() {
                            continue;
                        }
                        let parsed_chunk = if let Some(data) = chunk.strip_suffix("\n\n") {
                            if data == "[DONE]" {
                                ResponseChunk::Done
                            } else {
                                // Parses `chunk` (terminator included);
                                // serde_json tolerates trailing whitespace.
                                let parsed_data: InboundResponseChunk = serde_json::from_str(chunk)
                                    .unwrap_or_else(|_| {
                                        panic!("Invalid inbound streaming response payload: {}. Total err: {:#?}", chunk, unwrapped_bytes)
                                    });
                                // Only the first choice is forwarded downstream.
                                let choice = parsed_data.choices[0].to_owned();
                                match choice.delta {
                                    InboundChunkPayload::AnnounceRoles { role } => {
                                        ResponseChunk::BeginResponse {
                                            role,
                                            response_index: choice.index,
                                        }
                                    }
                                    InboundChunkPayload::StreamContent { content } => {
                                        ResponseChunk::Content {
                                            delta: content,
                                            response_index: choice.index,
                                        }
                                    }
                                    InboundChunkPayload::Close {} => ResponseChunk::CloseResponse {
                                        response_index: choice.index,
                                    },
                                }
                            }
                        } else {
                            // Incomplete event: stash it for the next batch.
                            unparsed = chunk.to_owned();
                            break;
                        };
                        response_chunks.push(parsed_chunk);
                    }

                    response_chunks
                        .into_iter()
                        .map(crate::Result::Ok)
                        .collect::<Vec<crate::Result<ResponseChunk>>>()
                })
                // Flatten each per-batch Vec into individual stream items.
                .flat_map(|results| {
                    futures::stream::iter(results)
                })
            })
            .map_err(crate::err::Error::from)
    }
376
377 #[cfg(feature = "functions")]
382 pub async fn send_message_functions<S: Into<String>, A: FunctionArgument>(
383 &self,
384 message: S,
385 functions: Vec<FunctionDescriptor<A>>,
386 ) -> crate::Result<CompletionResponse> {
387 self.send_message_functions_baked(
388 message,
389 functions
390 .into_iter()
391 .map(serde_json::to_value)
392 .collect::<serde_json::Result<Vec<serde_json::Value>>>()
393 .map_err(crate::err::Error::from)?,
394 )
395 .await
396 }
397
398 #[cfg(feature = "functions")]
403 pub async fn send_message_functions_baked<S: Into<String>>(
404 &self,
405 message: S,
406 baked_functions: Vec<serde_json::Value>,
407 ) -> crate::Result<CompletionResponse> {
408 let response: ServerResponse = self
409 .client
410 .post(self.config.api_url.clone())
411 .json(&CompletionRequest {
412 model: self.config.engine.as_ref(),
413 messages: &vec![ChatMessage {
414 role: Role::User,
415 content: message.into(),
416 #[cfg(feature = "functions")]
417 function_call: None,
418 }],
419 stream: false,
420 temperature: self.config.temperature,
421 top_p: self.config.top_p,
422 frequency_penalty: self.config.frequency_penalty,
423 presence_penalty: self.config.presence_penalty,
424 reply_count: self.config.reply_count,
425 max_tokens: self.config.max_tokens,
426 #[cfg(feature = "functions")]
427 functions: &baked_functions,
428 })
429 .send()
430 .await?
431 .json()
432 .await?;
433
434 match response {
435 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
436 message: error.message,
437 error_type: error.error_type,
438 }),
439 ServerResponse::Completion(completion) => Ok(completion),
440 }
441 }
442
443 #[cfg(feature = "functions")]
445 pub async fn send_history_functions(
446 &self,
447 history: &Vec<ChatMessage>,
448 functions: &Vec<serde_json::Value>,
449 ) -> crate::Result<CompletionResponse> {
450 let response: ServerResponse = self
451 .client
452 .post(self.config.api_url.clone())
453 .json(&CompletionRequest {
454 model: self.config.engine.as_ref(),
455 messages: history,
456 stream: false,
457 temperature: self.config.temperature,
458 top_p: self.config.top_p,
459 frequency_penalty: self.config.frequency_penalty,
460 presence_penalty: self.config.presence_penalty,
461 reply_count: self.config.reply_count,
462 max_tokens: self.config.max_tokens,
463 functions,
464 })
465 .send()
466 .await?
467 .json()
468 .await?;
469 match response {
470 ServerResponse::Error { error } => Err(crate::err::Error::BackendError {
471 message: error.message,
472 error_type: error.error_type,
473 }),
474 ServerResponse::Completion(completion) => Ok(completion),
475 }
476 }
477}