rig/providers/openai/streaming.rs

1use super::completion::CompletionModel;
2use crate::completion::{CompletionError, CompletionRequest};
3use crate::json_utils;
4use crate::json_utils::merge;
5use crate::streaming;
6use crate::streaming::{StreamingCompletionModel, StreamingResult};
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13
14// ================================================================
15// OpenAI Completion Streaming API
16// ================================================================
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct StreamingFunction {
19    #[serde(default)]
20    name: Option<String>,
21    #[serde(default)]
22    arguments: String,
23}
24
25#[derive(Debug, Serialize, Deserialize, Clone)]
26pub struct StreamingToolCall {
27    pub index: usize,
28    pub function: StreamingFunction,
29}
30
31#[derive(Deserialize)]
32struct StreamingDelta {
33    #[serde(default)]
34    content: Option<String>,
35    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
36    tool_calls: Vec<StreamingToolCall>,
37}
38
39#[derive(Deserialize)]
40struct StreamingChoice {
41    delta: StreamingDelta,
42}
43
44#[derive(Deserialize)]
45struct StreamingCompletionResponse {
46    choices: Vec<StreamingChoice>,
47}
48
49impl StreamingCompletionModel for CompletionModel {
50    async fn stream(
51        &self,
52        completion_request: CompletionRequest,
53    ) -> Result<StreamingResult, CompletionError> {
54        let mut request = self.create_completion_request(completion_request)?;
55        request = merge(request, json!({"stream": true}));
56
57        let builder = self.client.post("/chat/completions").json(&request);
58        send_compatible_streaming_request(builder).await
59    }
60}
61
62pub async fn send_compatible_streaming_request(
63    request_builder: RequestBuilder,
64) -> Result<StreamingResult, CompletionError> {
65    let response = request_builder.send().await?;
66
67    if !response.status().is_success() {
68        return Err(CompletionError::ProviderError(format!(
69            "{}: {}",
70            response.status(),
71            response.text().await?
72        )));
73    }
74
75    // Handle OpenAI Compatible SSE chunks
76    Ok(Box::pin(stream! {
77        let mut stream = response.bytes_stream();
78
79        let mut partial_data = None;
80        let mut calls: HashMap<usize, (String, String)> = HashMap::new();
81
82        while let Some(chunk_result) = stream.next().await {
83            let chunk = match chunk_result {
84                Ok(c) => c,
85                Err(e) => {
86                    yield Err(CompletionError::from(e));
87                    break;
88                }
89            };
90
91            let text = match String::from_utf8(chunk.to_vec()) {
92                Ok(t) => t,
93                Err(e) => {
94                    yield Err(CompletionError::ResponseError(e.to_string()));
95                    break;
96                }
97            };
98
99
100            for line in text.lines() {
101                let mut line = line.to_string();
102
103
104
105                // If there was a remaining part, concat with current line
106                if partial_data.is_some() {
107                    line = format!("{}{}", partial_data.unwrap(), line);
108                    partial_data = None;
109                }
110                // Otherwise full data line
111                else {
112                    let Some(data) = line.strip_prefix("data: ") else {
113                        continue;
114                    };
115
116                    // Partial data, split somewhere in the middle
117                    if !line.ends_with("}") {
118                        partial_data = Some(data.to_string());
119                    } else {
120                        line = data.to_string();
121                    }
122                }
123
124                let data = serde_json::from_str::<StreamingCompletionResponse>(&line);
125
126                let Ok(data) = data else {
127                    continue;
128                };
129
130                let choice = data.choices.first().expect("Should have at least one choice");
131
132                let delta = &choice.delta;
133
134                if !delta.tool_calls.is_empty() {
135                    for tool_call in &delta.tool_calls {
136                        let function = tool_call.function.clone();
137
138                        // Start of tool call
139                        // name: Some(String)
140                        // arguments: None
141                        if function.name.is_some() && function.arguments.is_empty() {
142                            calls.insert(tool_call.index, (function.name.clone().unwrap(), "".to_string()));
143                        }
144                        // Part of tool call
145                        // name: None
146                        // arguments: Some(String)
147                        else if function.name.is_none() && !function.arguments.is_empty() {
148                            let Some((name, arguments)) = calls.get(&tool_call.index) else {
149                                continue;
150                            };
151
152                            let new_arguments = &tool_call.function.arguments;
153                            let arguments = format!("{}{}", arguments, new_arguments);
154
155                            calls.insert(tool_call.index, (name.clone(), arguments));
156                        }
157                        // Entire tool call
158                        else {
159                            let name = function.name.unwrap();
160                            let arguments = function.arguments;
161                            let Ok(arguments) = serde_json::from_str(&arguments) else {
162                                continue;
163                            };
164
165                            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
166                        }
167                    }
168                }
169
170                if let Some(content) = &choice.delta.content {
171                    yield Ok(streaming::StreamingChoice::Message(content.clone()))
172                }
173            }
174        }
175
176        for (_, (name, arguments)) in calls {
177            let Ok(arguments) = serde_json::from_str(&arguments) else {
178                continue;
179            };
180
181            yield Ok(streaming::StreamingChoice::ToolCall(name, "".to_string(), arguments))
182        }
183    }))
184}