// openai_api_stream_rs/openai.rs

use std::pin::Pin;
use std::task::{Context, Poll};

use futures_util::stream::Stream;
use reqwest::{Client, Response};
use serde_json::Value;

use serde::{Serialize, Deserialize};

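/// Request options for [`OpenAIStream::gpt_stream`], deserialized from the caller's JSON input.
/// Every field except `messages` is optional and falls back to the defaults applied in
/// `gpt_stream`; `stream` is accepted here, but the request always forces `stream: true`.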
#[derive(Serialize, Deserialize, Default)]
pub struct GptStreamConfig {
    model: Option<String>,
    messages: Vec<Message>,
    temperature: Option<f64>,
    top_p: Option<f64>,
    n: Option<usize>,
    stream: Option<bool>,
    presence_penalty: Option<f64>,
    frequency_penalty: Option<f64>,
}

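/// A single chat message with its `role` (e.g. `"user"`) and text `content`.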
#[derive(Serialize, Deserialize)]
pub struct Message {
    role: String,
    content: String,
}

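/// Entry point for the streaming chat-completions API; holds the API key that is sent
/// as a `Bearer` token with every request.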
pub struct OpenAIStream {
    api_key: String,
}

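/// Streaming handle returned by [`OpenAIStream::gpt_stream`]: `response` is the raw byte
/// stream of the HTTP body, `buffer` accumulates the assistant text decoded so far, and
/// `first_chunk` marks whether the first SSE chunk still needs its special-case handling.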
pub struct GptStream {
    response: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
    buffer: String,
    first_chunk: bool,
}

impl OpenAIStream {
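    /// Creates a client from an OpenAI API key.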
    pub fn new(api_key: String) -> Self {
        OpenAIStream { api_key }
    }

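    /// Parses `input` as a [`GptStreamConfig`] JSON document, sends a streaming
    /// chat-completions request, and returns a [`GptStream`] over the generated text.
    /// Returns an `Err(String)` if the JSON is invalid, the request fails, or the API
    /// answers with a non-success status.
    ///
    /// A minimal usage sketch (assumes a Tokio runtime, a real API key, and that
    /// `OpenAIStream` is already in scope; marked `ignore` so it is not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// use futures_util::stream::StreamExt;
    ///
    /// let stream = OpenAIStream::new("sk-...".to_string())
    ///     .gpt_stream(r#"{"messages":[{"role":"user","content":"Hello"}]}"#)
    ///     .await
    ///     .expect("request failed");
    /// let mut stream = Box::pin(stream);
    /// while let Some(text) = stream.next().await {
    ///     // Each item is the full accumulated reply so far.
    ///     println!("{}", text);
    /// }
    /// ```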
    pub async fn gpt_stream(&self, input: &str) -> Result<GptStream, String> {
        let api_url = "https://api.openai.com/v1/chat/completions";

        let config: GptStreamConfig = match serde_json::from_str(input) {
            Ok(config) => config,
            Err(error) => return Err(format!("JSON parsing error: {}", error)),
        };

        let payload = serde_json::json!({
            "model": config.model.unwrap_or("gpt-3.5-turbo".to_string()),
            "messages": config.messages,
            "temperature": config.temperature.unwrap_or(1.0),
            "top_p": config.top_p.unwrap_or(1.0),
            "n": config.n.unwrap_or(1),
            "stream": true,
            "presence_penalty": config.presence_penalty.unwrap_or(0.0),
            "frequency_penalty": config.frequency_penalty.unwrap_or(0.0)
        });

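        // Issue the request; on success the response body is consumed below as a byte stream.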
        let client = Client::new();
        let response: Response = match client
            .post(api_url)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&payload)
            .send()
            .await
        {
            Ok(response) => response,
            Err(error) => return Err(format!("API request error: {}", error)),
        };

        if response.status().is_success() {
            Ok(GptStream {
                response: Box::pin(response.bytes_stream()),
                buffer: String::new(),
                first_chunk: true,
            })
        } else {
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| String::from("Unknown error"));
            Err(format!("API request error: {}", error_text))
        }
    }
}

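// Each item yielded by the stream is the full accumulated assistant text so far (not just
// the newest delta). Chunks that carry no content produce no item, and a transport error
// ends the stream after being logged.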
impl Stream for GptStream {
    type Item = String;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match self.response.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(chunk))) => {
                    let mut utf8_str = String::from_utf8_lossy(&chunk).to_string();

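                    // The first chunk can contain several `data:` lines (e.g. the initial
                    // role-only delta); keep only its second-to-last line, i.e. the last
                    // `data:` payload before the trailing blank line, so a single JSON
                    // object is parsed below.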
                    if self.first_chunk {
                        let lines: Vec<&str> = utf8_str.lines().collect();
                        utf8_str = if lines.len() >= 2 {
                            lines[lines.len() - 2].to_string()
                        } else {
                            utf8_str.clone()
                        };
                        self.first_chunk = false;
                    }

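                    // Strip the SSE "data: " prefix; anything that is not a single JSON object
                    // (e.g. the final "[DONE]" marker, or a chunk carrying several events) fails
                    // to parse and is simply skipped.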
                    let trimmed_str = utf8_str.trim_start_matches("data: ");

                    let json_result: Result<Value, _> = serde_json::from_str(trimmed_str);

                    if let Ok(json) = json_result {
                        // Pull `choices[0].delta.content` out of the event, append it to the
                        // running buffer, and yield the full text accumulated so far.
                        if let Some(content_str) = json
                            .get("choices")
                            .and_then(|choices| choices.get(0))
                            .and_then(|choice| choice.get("delta"))
                            .and_then(|delta| delta.get("content"))
                            .and_then(|content| content.as_str())
                        {
                            self.buffer.push_str(content_str);
                            // Turn any literal "\n" escape sequences into real newlines.
                            let output = self.buffer.replace("\\n", "\n");
                            return Poll::Ready(Some(output));
                        }
                    }
                }
                Poll::Ready(Some(Err(error))) => {
                    eprintln!("Error in stream: {:?}", error);
                    return Poll::Ready(None);
                }
                Poll::Ready(None) => {
                    return Poll::Ready(None);
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}

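// The test below makes a live request to the OpenAI API, so it only passes when the
// "sk-..." placeholder is replaced with a real API key.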
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::StreamExt;

    #[tokio::test]
    async fn test_gpt_stream_raw_line() {
        let api_key = "sk-...".to_string();
        let openai_stream = OpenAIStream::new(api_key);

        let config_json = r#"
        {
            "model": "gpt-3.5-turbo",
            "messages": [
                {
                    "role": "user",
                    "content": "One sentence to describe a simple advanced usage of Rust"
                }
            ]
        }
        "#;
        let gpt_stream = openai_stream.gpt_stream(config_json).await.unwrap();
        let mut gpt_stream = Box::pin(gpt_stream);

        while let Some(value) = gpt_stream.next().await {
            println!("{}", value);
        }
    }
}