ru_openai/
lib.rs

1use std::pin::Pin;
2use futures::{stream::StreamExt, Stream};
3use reqwest::StatusCode;
4use serde_derive::{Deserialize, Serialize};
5use reqwest_eventsource::{EventSource, Event};
6use serde::de::DeserializeOwned;
7
8pub mod api;
9pub mod configuration;
10
11#[derive(Deserialize, Serialize, Debug)]
12pub struct ErrorInfo {
13    pub message: String,
14    #[serde(rename = "type")]
15    pub message_type: String,
16    pub param: Option<String>,
17    pub code: Option<String>,
18}
19
20#[derive(Deserialize, Serialize, Debug)]
21pub struct ReturnErrorType {
22    pub error: ErrorInfo,
23}
24
25#[derive(Deserialize, Serialize, Debug)]
26pub struct OpenAIApiError {
27    pub code: i32,
28    pub error: ErrorInfo,
29}
30
31impl OpenAIApiError {
32    pub fn new(code: i32, error: ErrorInfo) -> Self {
33        Self { code, error }
34    }
35
36    pub fn from(error: reqwest::Error) -> Self {
37        let code = error.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR).as_u16() as i32;
38        let error = ErrorInfo {
39            message: error.to_string(),
40            message_type: "request error".to_string(),
41            param: None,
42            code: None,
43        };
44        Self::new(code, error)
45    }
46}
47
48pub type Error = reqwest::Error;
49
50pub(crate) async fn stream<O>(
51    mut event_source: EventSource,
52) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIApiError>> + Send>>
53where
54    O: DeserializeOwned + std::marker::Send + 'static,
55{
56    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
57
58    tokio::spawn(async move {
59        while let Some(ev) = event_source.next().await {
60            match ev {
61                Err(e) => {
62                    // println!("{:?}", e);
63                    if let Err(_e) = tx.send(Err(OpenAIApiError::new(
64                        StatusCode::INTERNAL_SERVER_ERROR.as_u16() as i32,
65                        ErrorInfo {
66                            message: e.to_string(),
67                            message_type: "request error".to_string(),
68                            param: None,
69                            code: None,
70                        },
71                    ))) {
72                        // rx dropped
73                        break;
74                    }
75                }
76                Ok(event) => match event {
77                    Event::Message(message) => {
78                        // println!("{:?}", message);
79                        if message.data == "[DONE]" {
80                            break;
81                        }
82
83                        let response = match serde_json::from_str::<O>(&message.data) {
84                            Err(e) => {
85                                // Err(map_deserialization_error(e, &message.data.as_bytes()))
86
87                                Err(OpenAIApiError::new(
88                                    StatusCode::INTERNAL_SERVER_ERROR.as_u16() as i32,
89                                    ErrorInfo {
90                                        message: e.to_string(),
91                                        message_type: "deserialization error".to_string(),
92                                        param: None,
93                                        code: None,
94                                    },
95                                ))
96                            }
97                            Ok(output) => Ok(output),
98                        };
99
100                        if let Err(_e) = tx.send(response) {
101                            // rx dropped
102                            break;
103                        }
104                    }
105                    Event::Open => continue,
106                },
107            }
108        }
109
110        event_source.close();
111    });
112
113    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
114}