// openrouter_rust/streaming.rs
//! Chat-completion streaming over server-sent events.

use std::collections::VecDeque;
use std::pin::Pin;

use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};

use crate::{
    chat::{ChatCompletionRequest, ChatCompletionResponse, Choice},
    client::OpenRouterClient,
    error::{OpenRouterError, Result},
    types::{Message, Usage},
};

11pub type ChatCompletionStream = Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ChatCompletionChunk {
15 pub id: String,
16 pub object: String,
17 pub created: i64,
18 pub model: String,
19 pub choices: Vec<StreamingChoice>,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub usage: Option<Usage>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub error: Option<ChunkError>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct StreamingChoice {
28 pub index: u32,
29 pub delta: DeltaMessage,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub finish_reason: Option<String>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub native_finish_reason: Option<String>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub error: Option<ChoiceError>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct DeltaMessage {
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub role: Option<String>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub content: Option<String>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ChoiceError {
48 pub code: u16,
49 pub message: String,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ChunkError {
54 pub code: u16,
55 pub message: String,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 pub metadata: Option<serde_json::Value>,
58}
59
60impl OpenRouterClient {
61 pub async fn chat_completion_stream(
62 &self,
63 mut request: ChatCompletionRequest,
64 ) -> Result<ChatCompletionStream> {
65 request.stream = Some(true);
66
67 let url = format!("{}/chat/completions", self.base_url);
68 let headers = self.build_headers()?;
69
70 let response = self
71 .client
72 .post(&url)
73 .headers(headers)
74 .json(&request)
75 .send()
76 .await
77 .map_err(OpenRouterError::HttpError)?;
78
79 let status = response.status();
80
81 if !status.is_success() {
82 let error_text = response.text().await.unwrap_or_default();
83 return Err(OpenRouterError::ApiError {
84 code: status.as_u16(),
85 message: error_text,
86 });
87 }
88
89 let stream = response
90 .bytes_stream()
91 .map(|result| {
92 result.map_err(OpenRouterError::HttpError)
93 })
94 .filter_map(|result| async move {
95 match result {
96 Ok(bytes) => {
97 let text = String::from_utf8_lossy(&bytes);
98 parse_sse_chunk(&text)
99 }
100 Err(e) => Some(Err(e)),
101 }
102 });
103
104 Ok(Box::pin(stream))
105 }
106}
107
108fn parse_sse_chunk(text: &str) -> Option<Result<ChatCompletionChunk>> {
109 let mut result = None;
110
111 for line in text.lines() {
112 let line = line.trim();
113
114 if line.is_empty() || line.starts_with(':') {
115 continue;
116 }
117
118 if line.starts_with("data: ") {
119 let data = &line[6..];
120
121 if data == "[DONE]" {
122 return None;
123 }
124
125 match serde_json::from_str::<ChatCompletionChunk>(data) {
126 Ok(chunk) => {
127 if let Some(ref error) = chunk.error {
128 return Some(Err(OpenRouterError::StreamError(format!(
129 "Stream error: {} - {}",
130 error.code, error.message
131 ))));
132 }
133 result = Some(Ok(chunk));
134 }
135 Err(_) => continue,
136 }
137 }
138 }
139
140 result
141}
142
143pub async fn collect_stream(stream: ChatCompletionStream) -> Result<ChatCompletionResponse> {
144 let mut chunks: Vec<ChatCompletionChunk> = Vec::new();
145 let mut full_content = String::new();
146 let mut role = String::new();
147 let mut last_usage: Option<Usage> = None;
148 let mut finish_reason: Option<String> = None;
149 let mut native_finish_reason: Option<String> = None;
150 let mut id = String::new();
151 let mut object = String::new();
152 let mut created: i64 = 0;
153 let mut model = String::new();
154
155 let mut stream = stream;
156
157 while let Some(result) = stream.next().await {
158 let chunk = result?;
159
160 if id.is_empty() {
161 id = chunk.id.clone();
162 object = chunk.object.clone();
163 created = chunk.created;
164 model = chunk.model.clone();
165 }
166
167 if let Some(ref usage) = chunk.usage {
168 last_usage = Some(usage.clone());
169 }
170
171 for choice in &chunk.choices {
172 if let Some(ref r) = choice.delta.role {
173 role = r.clone();
174 }
175 if let Some(ref content) = choice.delta.content {
176 full_content.push_str(content);
177 }
178 if let Some(ref fr) = choice.finish_reason {
179 finish_reason = Some(fr.clone());
180 }
181 if let Some(ref nfr) = choice.native_finish_reason {
182 native_finish_reason = Some(nfr.clone());
183 }
184 }
185
186 chunks.push(chunk);
187 }
188
189 Ok(ChatCompletionResponse {
190 id,
191 object,
192 created,
193 model,
194 choices: vec![Choice {
195 index: 0,
196 message: Message {
197 role: if role == "assistant" {
198 crate::types::Role::Assistant
199 } else {
200 crate::types::Role::User
201 },
202 content: Some(full_content),
203 name: None,
204 tool_calls: None,
205 },
206 finish_reason,
207 native_finish_reason,
208 error: None,
209 }],
210 usage: last_usage,
211 system_fingerprint: None,
212 })
213}