// rig/providers/openai/completion/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use reqwest_eventsource::Event;
11use reqwest_eventsource::RequestBuilderExt;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::debug;
16
17// ================================================================
18// OpenAI Completion Streaming API
19// ================================================================
20#[derive(Debug, Serialize, Deserialize, Clone)]
21pub struct StreamingFunction {
22    #[serde(default)]
23    pub name: Option<String>,
24    #[serde(default)]
25    pub arguments: String,
26}
27
28#[derive(Debug, Serialize, Deserialize, Clone)]
29pub struct StreamingToolCall {
30    pub index: usize,
31    pub id: Option<String>,
32    pub function: StreamingFunction,
33}
34
35#[derive(Deserialize, Debug)]
36struct StreamingDelta {
37    #[serde(default)]
38    content: Option<String>,
39    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
40    tool_calls: Vec<StreamingToolCall>,
41}
42
43#[derive(Deserialize, Debug)]
44struct StreamingChoice {
45    delta: StreamingDelta,
46}
47
48#[derive(Deserialize, Debug)]
49struct StreamingCompletionChunk {
50    choices: Vec<StreamingChoice>,
51    usage: Option<Usage>,
52}
53
54#[derive(Clone, Serialize, Deserialize)]
55pub struct StreamingCompletionResponse {
56    pub usage: Usage,
57}
58
59impl GetTokenUsage for StreamingCompletionResponse {
60    fn token_usage(&self) -> Option<crate::completion::Usage> {
61        let mut usage = crate::completion::Usage::new();
62        usage.input_tokens = self.usage.prompt_tokens as u64;
63        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
64        usage.total_tokens = self.usage.total_tokens as u64;
65        Some(usage)
66    }
67}
68
69impl CompletionModel {
70    pub(crate) async fn stream(
71        &self,
72        completion_request: CompletionRequest,
73    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
74    {
75        let mut request = self.create_completion_request(completion_request)?;
76        request = merge(
77            request,
78            json!({"stream": true, "stream_options": {"include_usage": true}}),
79        );
80
81        let builder = self.client.post("/chat/completions").json(&request);
82        send_compatible_streaming_request(builder).await
83    }
84}
85
86pub async fn send_compatible_streaming_request(
87    request_builder: RequestBuilder,
88) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
89    // Build the request with proper headers for SSE
90    let mut event_source = request_builder
91        .eventsource()
92        .expect("Cloning request must always succeed");
93
94    let stream = Box::pin(stream! {
95        let mut final_usage = Usage::new();
96
97        // Track in-progress tool calls
98        let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
99
100        while let Some(event_result) = event_source.next().await {
101            match event_result {
102                Ok(Event::Open) => {
103                    tracing::trace!("SSE connection opened");
104                    continue;
105                }
106                Ok(Event::Message(message)) => {
107                    if message.data.trim().is_empty() || message.data == "[DONE]" {
108                        continue;
109                    }
110
111                    let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
112                    let Ok(data) = data else {
113                        let err = data.unwrap_err();
114                        debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
115                        continue;
116                    };
117
118                    if let Some(choice) = data.choices.first() {
119                        let delta = &choice.delta;
120
121                        // Tool calls
122                        if !delta.tool_calls.is_empty() {
123                            for tool_call in &delta.tool_calls {
124                                let function = tool_call.function.clone();
125
126                                // Start of tool call
127                                if function.name.is_some() && function.arguments.is_empty() {
128                                    let id = tool_call.id.clone().unwrap_or_default();
129                                    tool_calls.insert(
130                                        tool_call.index,
131                                        (id, function.name.clone().unwrap(), "".to_string()),
132                                    );
133                                }
134                                // tool call partial (ie, a continuation of a previously received tool call)
135                                // name: None or Empty String
136                                // arguments: Some(String)
137                                else if function.name.clone().is_none_or(|s| s.is_empty())
138                                    && !function.arguments.is_empty()
139                                {
140                                    if let Some((id, name, arguments)) =
141                                        tool_calls.get(&tool_call.index)
142                                    {
143                                        let new_arguments = &tool_call.function.arguments;
144                                        let arguments = format!("{arguments}{new_arguments}");
145                                        tool_calls.insert(
146                                            tool_call.index,
147                                            (id.clone(), name.clone(), arguments),
148                                        );
149                                    } else {
150                                        debug!("Partial tool call received but tool call was never started.");
151                                    }
152                                }
153                                // Complete tool call
154                                else {
155                                    let id = tool_call.id.clone().unwrap_or_default();
156                                    let name = function.name.expect("tool call should have a name");
157                                    let arguments = function.arguments;
158                                    let Ok(arguments) = serde_json::from_str(&arguments) else {
159                                        debug!("Couldn't serialize '{arguments}' as JSON");
160                                        continue;
161                                    };
162
163                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
164                                        id,
165                                        name,
166                                        arguments,
167                                        call_id: None,
168                                    });
169                                }
170                            }
171                        }
172
173                        // Message content
174                        if let Some(content) = &choice.delta.content {
175                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
176                        }
177                    }
178
179                    // Usage updates
180                    if let Some(usage) = data.usage {
181                        final_usage = usage.clone();
182                    }
183                }
184                Err(reqwest_eventsource::Error::StreamEnded) => {
185                    break;
186                }
187                Err(error) => {
188                    tracing::error!(?error, "SSE error");
189                    yield Err(CompletionError::ResponseError(error.to_string()));
190                    break;
191                }
192            }
193        }
194
195        // Ensure event source is closed when stream ends
196        event_source.close();
197
198        // Flush any tool calls that weren’t fully yielded
199        for (_, (id, name, arguments)) in tool_calls {
200            let Ok(arguments) = serde_json::from_str(&arguments) else {
201                continue;
202            };
203
204            yield Ok(RawStreamingChoice::ToolCall {
205                id,
206                name,
207                arguments,
208                call_id: None,
209            });
210        }
211
212        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
213            usage: final_usage.clone()
214        }));
215    });
216
217    Ok(streaming::StreamingCompletionResponse::stream(stream))
218}