Skip to main content

openrouter_rust/
streaming.rs

1use crate::{
2    chat::{ChatCompletionRequest, ChatCompletionResponse, Choice},
3    client::OpenRouterClient,
4    error::{OpenRouterError, Result},
5    types::{Message, Usage},
6};
7use futures::{Stream, StreamExt};
8use serde::{Deserialize, Serialize};
9use std::pin::Pin;
10
11pub type ChatCompletionStream = Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ChatCompletionChunk {
15    pub id: String,
16    pub object: String,
17    pub created: i64,
18    pub model: String,
19    pub choices: Vec<StreamingChoice>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub usage: Option<Usage>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub error: Option<ChunkError>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct StreamingChoice {
28    pub index: u32,
29    pub delta: DeltaMessage,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub finish_reason: Option<String>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub native_finish_reason: Option<String>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub error: Option<ChoiceError>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct DeltaMessage {
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub role: Option<String>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub content: Option<String>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ChoiceError {
48    pub code: u16,
49    pub message: String,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct ChunkError {
54    pub code: u16,
55    pub message: String,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub metadata: Option<serde_json::Value>,
58}
59
60impl OpenRouterClient {
61    pub async fn chat_completion_stream(
62        &self,
63        mut request: ChatCompletionRequest,
64    ) -> Result<ChatCompletionStream> {
65        request.stream = Some(true);
66        
67        let url = format!("{}/chat/completions", self.base_url);
68        let headers = self.build_headers()?;
69
70        let response = self
71            .client
72            .post(&url)
73            .headers(headers)
74            .json(&request)
75            .send()
76            .await
77            .map_err(OpenRouterError::HttpError)?;
78
79        let status = response.status();
80        
81        if !status.is_success() {
82            let error_text = response.text().await.unwrap_or_default();
83            return Err(OpenRouterError::ApiError {
84                code: status.as_u16(),
85                message: error_text,
86            });
87        }
88
89        let stream = response
90            .bytes_stream()
91            .map(|result| {
92                result.map_err(OpenRouterError::HttpError)
93            })
94            .filter_map(|result| async move {
95                match result {
96                    Ok(bytes) => {
97                        let text = String::from_utf8_lossy(&bytes);
98                        parse_sse_chunk(&text)
99                    }
100                    Err(e) => Some(Err(e)),
101                }
102            });
103
104        Ok(Box::pin(stream))
105    }
106}
107
108fn parse_sse_chunk(text: &str) -> Option<Result<ChatCompletionChunk>> {
109    let mut result = None;
110    
111    for line in text.lines() {
112        let line = line.trim();
113        
114        if line.is_empty() || line.starts_with(':') {
115            continue;
116        }
117        
118        if line.starts_with("data: ") {
119            let data = &line[6..];
120            
121            if data == "[DONE]" {
122                return None;
123            }
124            
125            match serde_json::from_str::<ChatCompletionChunk>(data) {
126                Ok(chunk) => {
127                    if let Some(ref error) = chunk.error {
128                        return Some(Err(OpenRouterError::StreamError(format!(
129                            "Stream error: {} - {}",
130                            error.code, error.message
131                        ))));
132                    }
133                    result = Some(Ok(chunk));
134                }
135                Err(_) => continue,
136            }
137        }
138    }
139    
140    result
141}
142
143pub async fn collect_stream(stream: ChatCompletionStream) -> Result<ChatCompletionResponse> {
144    let mut chunks: Vec<ChatCompletionChunk> = Vec::new();
145    let mut full_content = String::new();
146    let mut role = String::new();
147    let mut last_usage: Option<Usage> = None;
148    let mut finish_reason: Option<String> = None;
149    let mut native_finish_reason: Option<String> = None;
150    let mut id = String::new();
151    let mut object = String::new();
152    let mut created: i64 = 0;
153    let mut model = String::new();
154
155    let mut stream = stream;
156    
157    while let Some(result) = stream.next().await {
158        let chunk = result?;
159        
160        if id.is_empty() {
161            id = chunk.id.clone();
162            object = chunk.object.clone();
163            created = chunk.created;
164            model = chunk.model.clone();
165        }
166        
167        if let Some(ref usage) = chunk.usage {
168            last_usage = Some(usage.clone());
169        }
170        
171        for choice in &chunk.choices {
172            if let Some(ref r) = choice.delta.role {
173                role = r.clone();
174            }
175            if let Some(ref content) = choice.delta.content {
176                full_content.push_str(content);
177            }
178            if let Some(ref fr) = choice.finish_reason {
179                finish_reason = Some(fr.clone());
180            }
181            if let Some(ref nfr) = choice.native_finish_reason {
182                native_finish_reason = Some(nfr.clone());
183            }
184        }
185        
186        chunks.push(chunk);
187    }
188
189    Ok(ChatCompletionResponse {
190        id,
191        object,
192        created,
193        model,
194        choices: vec![Choice {
195            index: 0,
196            message: Message {
197                role: if role == "assistant" {
198                    crate::types::Role::Assistant
199                } else {
200                    crate::types::Role::User
201                },
202                content: Some(full_content),
203                name: None,
204                tool_calls: None,
205            },
206            finish_reason,
207            native_finish_reason,
208            error: None,
209        }],
210        usage: last_usage,
211        system_fingerprint: None,
212    })
213}