1use std::pin::Pin;
2use futures::{stream::StreamExt, Stream};
3use reqwest::StatusCode;
4use serde_derive::{Deserialize, Serialize};
5use reqwest_eventsource::{EventSource, Event};
6use serde::de::DeserializeOwned;
7
8pub mod api;
9pub mod configuration;
10
11#[derive(Deserialize, Serialize, Debug)]
12pub struct ErrorInfo {
13 pub message: String,
14 #[serde(rename = "type")]
15 pub message_type: String,
16 pub param: Option<String>,
17 pub code: Option<String>,
18}
19
20#[derive(Deserialize, Serialize, Debug)]
21pub struct ReturnErrorType {
22 pub error: ErrorInfo,
23}
24
25#[derive(Deserialize, Serialize, Debug)]
26pub struct OpenAIApiError {
27 pub code: i32,
28 pub error: ErrorInfo,
29}
30
31impl OpenAIApiError {
32 pub fn new(code: i32, error: ErrorInfo) -> Self {
33 Self { code, error }
34 }
35
36 pub fn from(error: reqwest::Error) -> Self {
37 let code = error.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR).as_u16() as i32;
38 let error = ErrorInfo {
39 message: error.to_string(),
40 message_type: "request error".to_string(),
41 param: None,
42 code: None,
43 };
44 Self::new(code, error)
45 }
46}
47
48pub type Error = reqwest::Error;
49
50pub(crate) async fn stream<O>(
51 mut event_source: EventSource,
52) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIApiError>> + Send>>
53where
54 O: DeserializeOwned + std::marker::Send + 'static,
55{
56 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
57
58 tokio::spawn(async move {
59 while let Some(ev) = event_source.next().await {
60 match ev {
61 Err(e) => {
62 if let Err(_e) = tx.send(Err(OpenAIApiError::new(
64 StatusCode::INTERNAL_SERVER_ERROR.as_u16() as i32,
65 ErrorInfo {
66 message: e.to_string(),
67 message_type: "request error".to_string(),
68 param: None,
69 code: None,
70 },
71 ))) {
72 break;
74 }
75 }
76 Ok(event) => match event {
77 Event::Message(message) => {
78 if message.data == "[DONE]" {
80 break;
81 }
82
83 let response = match serde_json::from_str::<O>(&message.data) {
84 Err(e) => {
85 Err(OpenAIApiError::new(
88 StatusCode::INTERNAL_SERVER_ERROR.as_u16() as i32,
89 ErrorInfo {
90 message: e.to_string(),
91 message_type: "deserialization error".to_string(),
92 param: None,
93 code: None,
94 },
95 ))
96 }
97 Ok(output) => Ok(output),
98 };
99
100 if let Err(_e) = tx.send(response) {
101 break;
103 }
104 }
105 Event::Open => continue,
106 },
107 }
108 }
109
110 event_source.close();
111 });
112
113 Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
114}