openai_api_stream_rs/
openai.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures_util::stream::Stream;
5use reqwest::{Client, Response};
6use serde_json::Value;
7
8use serde::{Serialize, Deserialize};
9
10#[derive(Serialize, Deserialize, Default)]
11pub struct GptStreamConfig {
12    model: Option<String>,
13    messages: Vec<Message>,
14    temperature: Option<f64>,
15    top_p: Option<f64>,
16    n: Option<usize>,
17    stream: Option<bool>,
18    presence_penalty: Option<f64>,
19    frequency_penalty: Option<f64>,
20}
21
22#[derive(Serialize, Deserialize)]
23pub struct Message {
24    role: String,
25    content: String,
26}
27
/// Client handle for the OpenAI chat-completions streaming endpoint.
///
/// Holds the bearer token used to authenticate every request it sends.
pub struct OpenAIStream {
    api_key: String,
}

// `Debug` is written by hand instead of derived so the API key never ends up in logs or
// panic messages.
impl std::fmt::Debug for OpenAIStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIStream")
            .field("api_key", &"<redacted>")
            .finish()
    }
}
31
32pub struct GptStream {
33    response: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
34    buffer: String,
35    first_chunk: bool,
36}
37
38impl OpenAIStream {
39    pub fn new(api_key: String) -> Self {
40        OpenAIStream { api_key }
41    }
42
43    pub async fn gpt_stream(&self, input: &str) -> Result<GptStream, String> {
44        let api_url = "https://api.openai.com/v1/chat/completions";
45
46        let config: GptStreamConfig = match serde_json::from_str(input) {
47            Ok(config) => config,
48            Err(error) => return Err(format!("JSON parsing error: {}", error)),
49        };
50
51        let payload = serde_json::json!({
52            "model": config.model.unwrap_or("gpt-3.5-turbo".to_string()),
53            "messages": config.messages,
54            "temperature": config.temperature.unwrap_or(1.0),
55            "top_p": config.top_p.unwrap_or(1.0),
56            "n": config.n.unwrap_or(1),
57            "stream": true,
58            "presence_penalty": config.presence_penalty.unwrap_or(0.0),
59            "frequency_penalty": config.frequency_penalty.unwrap_or(0.0)
60        });
61
62        let client = Client::new();
63        let response: Response = match client
64            .post(api_url)
65            .header("Content-Type", "application/json")
66            .header("Authorization", format!("Bearer {}", self.api_key))
67            .json(&payload)
68            .send()
69            .await
70        {
71            Ok(response) => response,
72            Err(error) => return Err(format!("API request error: {}", error)),
73        };
74
75        if response.status().is_success() {
76            Ok(GptStream {
77                response: Box::pin(response.bytes_stream()),
78                buffer: String::new(),
79                first_chunk: true,
80            })
81        } else {
82            let error_text = response.text().await.unwrap_or_else(|_| String::from("Unknown error"));
83            Err(format!("API request error: {}", error_text))
84        }
85    }
86}
87
88impl Stream for GptStream {
89    type Item = String;
90
91    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        loop {
93            match self.response.as_mut().poll_next(cx) {
94                Poll::Ready(Some(Ok(chunk))) => {
95                    let mut utf8_str = String::from_utf8_lossy(&chunk).to_string();
96
97                    if self.first_chunk {
98                        let lines: Vec<&str> = utf8_str.lines().collect();
99                        utf8_str = if lines.len() >= 2 {
100                            lines[lines.len() - 2].to_string()
101                        } else {
102                            utf8_str.clone()
103                        };
104                        self.first_chunk = false;
105                    }
106
107                    let trimmed_str = utf8_str.trim_start_matches("data: ");
108
109                    let json_result: Result<Value, _> = serde_json::from_str(trimmed_str);
110
111                    match json_result {
112                        Ok(json) => {
113                            if let Some(choices) = json.get("choices") {
114                                if let Some(choice) = choices.get(0) {
115                                    if let Some(content) = choice.get("delta").and_then(|delta| delta.get("content")) {
116                                        if let Some(content_str) = content.as_str() {
117                                            self.buffer.push_str(content_str);
118                                            let output = self.buffer.replace("\\n", "\n");
119                                            return Poll::Ready(Some(output));
120                                        }
121                                    }
122                                }
123                            }
124                        }
125                        Err(_) => {}
126                    }
127                }
128                Poll::Ready(Some(Err(error))) => {
129                    eprintln!("Error in stream: {:?}", error);
130                    return Poll::Ready(None);
131                }
132                Poll::Ready(None) => {
133                    return Poll::Ready(None);
134                }
135                Poll::Pending => {
136                    return Poll::Pending;
137                }
138            }
139        }
140    }
141}
142
143
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::StreamExt;

    /// Live end-to-end check against the OpenAI API; prints each cumulative item.
    ///
    /// Ignored by default because it needs network access and a real key. Run it with
    /// `OPENAI_API_KEY=sk-... cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access and the OPENAI_API_KEY environment variable"]
    async fn test_gpt_stream_raw_line() {
        let api_key =
            std::env::var("OPENAI_API_KEY").expect("set OPENAI_API_KEY to run this test");
        let openai_stream = OpenAIStream::new(api_key);

        // Only `model` and `messages` are given; `gpt_stream` fills in defaults for the other
        // sampling options.
        let config_json = r#"
            {
                "model": "gpt-3.5-turbo",
                "messages": [
                    {
                        "role": "user",
                        "content": "One sentence to describe a simple advanced usage of Rust"
                    }
                ]
            }
        "#;

        let mut gpt_stream = openai_stream
            .gpt_stream(config_json)
            .await
            .expect("the API should accept the request");

        // `GptStream` is `Unpin`, so `next()` can be called on it directly.
        while let Some(value) = gpt_stream.next().await {
            println!("{}", value);
        }
    }
}