rig/providers/openai/completion/streaming.rs

1use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
2use crate::json_utils;
3use crate::json_utils::merge;
4use crate::providers::openai::completion::{CompletionModel, Usage};
5use crate::streaming;
6use crate::streaming::RawStreamingChoice;
7use async_stream::stream;
8use futures::StreamExt;
9use reqwest::RequestBuilder;
10use reqwest_eventsource::Event;
11use reqwest_eventsource::RequestBuilderExt;
12use serde::{Deserialize, Serialize};
13use serde_json::json;
14use std::collections::HashMap;
15use tracing::{debug, info_span};
16use tracing_futures::Instrument;
17
// ================================================================
// OpenAI Completion Streaming API
// ================================================================
21#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct StreamingFunction {
23    #[serde(default)]
24    pub name: Option<String>,
25    #[serde(default)]
26    pub arguments: String,
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct StreamingToolCall {
31    pub index: usize,
32    pub id: Option<String>,
33    pub function: StreamingFunction,
34}
35
36#[derive(Deserialize, Debug)]
37struct StreamingDelta {
38    #[serde(default)]
39    content: Option<String>,
40    #[serde(default, deserialize_with = "json_utils::null_or_vec")]
41    tool_calls: Vec<StreamingToolCall>,
42}
43
44#[derive(Deserialize, Debug)]
45struct StreamingChoice {
46    delta: StreamingDelta,
47}
48
49#[derive(Deserialize, Debug)]
50struct StreamingCompletionChunk {
51    choices: Vec<StreamingChoice>,
52    usage: Option<Usage>,
53}
54
55#[derive(Clone, Serialize, Deserialize)]
56pub struct StreamingCompletionResponse {
57    pub usage: Usage,
58}
59
60impl GetTokenUsage for StreamingCompletionResponse {
61    fn token_usage(&self) -> Option<crate::completion::Usage> {
62        let mut usage = crate::completion::Usage::new();
63        usage.input_tokens = self.usage.prompt_tokens as u64;
64        usage.output_tokens = self.usage.total_tokens as u64 - self.usage.prompt_tokens as u64;
65        usage.total_tokens = self.usage.total_tokens as u64;
66        Some(usage)
67    }
68}
69
70impl CompletionModel<reqwest::Client> {
71    pub(crate) async fn stream(
72        &self,
73        completion_request: CompletionRequest,
74    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
75    {
76        let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?;
77        let request_messages = serde_json::to_string(&request.messages)
78            .expect("Converting to JSON from a Rust struct shouldn't fail");
79        let mut request_as_json = serde_json::to_value(request).expect("this should never fail");
80
81        request_as_json = merge(
82            request_as_json,
83            json!({"stream": true, "stream_options": {"include_usage": true}}),
84        );
85
86        let builder = self
87            .client
88            .post_reqwest("/chat/completions")
89            .json(&request_as_json);
90
91        let span = if tracing::Span::current().is_disabled() {
92            info_span!(
93                target: "rig::completions",
94                "chat",
95                gen_ai.operation.name = "chat",
96                gen_ai.provider.name = "openai",
97                gen_ai.request.model = self.model,
98                gen_ai.response.id = tracing::field::Empty,
99                gen_ai.response.model = self.model,
100                gen_ai.usage.output_tokens = tracing::field::Empty,
101                gen_ai.usage.input_tokens = tracing::field::Empty,
102                gen_ai.input.messages = request_messages,
103                gen_ai.output.messages = tracing::field::Empty,
104            )
105        } else {
106            tracing::Span::current()
107        };
108
109        tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await
110    }
111}
112
113pub async fn send_compatible_streaming_request(
114    request_builder: RequestBuilder,
115) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError> {
116    let span = tracing::Span::current();
117    // Build the request with proper headers for SSE
118    let mut event_source = request_builder
119        .eventsource()
120        .expect("Cloning request must always succeed");
121
122    let stream = stream! {
123        let span = tracing::Span::current();
124        let mut final_usage = Usage::new();
125
126        // Track in-progress tool calls
127        let mut tool_calls: HashMap<usize, (String, String, String)> = HashMap::new();
128
129        let mut text_content = String::new();
130
131        while let Some(event_result) = event_source.next().await {
132            match event_result {
133                Ok(Event::Open) => {
134                    tracing::trace!("SSE connection opened");
135                    continue;
136                }
137                Ok(Event::Message(message)) => {
138                    if message.data.trim().is_empty() || message.data == "[DONE]" {
139                        continue;
140                    }
141
142                    let data = serde_json::from_str::<StreamingCompletionChunk>(&message.data);
143                    let Ok(data) = data else {
144                        let err = data.unwrap_err();
145                        debug!("Couldn't serialize data as StreamingCompletionChunk: {:?}", err);
146                        continue;
147                    };
148
149                    if let Some(choice) = data.choices.first() {
150                        let delta = &choice.delta;
151
152                        // Tool calls
153                        if !delta.tool_calls.is_empty() {
154                            for tool_call in &delta.tool_calls {
155                                let function = tool_call.function.clone();
156
157                                // Start of tool call
158                                if function.name.is_some() && function.arguments.is_empty() {
159                                    let id = tool_call.id.clone().unwrap_or_default();
160                                    tool_calls.insert(
161                                        tool_call.index,
162                                        (id, function.name.clone().unwrap(), "".to_string()),
163                                    );
164                                }
165                                // tool call partial (ie, a continuation of a previously received tool call)
166                                // name: None or Empty String
167                                // arguments: Some(String)
168                                else if function.name.clone().is_none_or(|s| s.is_empty())
169                                    && !function.arguments.is_empty()
170                                {
171                                    if let Some((id, name, arguments)) =
172                                        tool_calls.get(&tool_call.index)
173                                    {
174                                        let new_arguments = &tool_call.function.arguments;
175                                        let arguments = format!("{arguments}{new_arguments}");
176                                        tool_calls.insert(
177                                            tool_call.index,
178                                            (id.clone(), name.clone(), arguments),
179                                        );
180                                    } else {
181                                        debug!("Partial tool call received but tool call was never started.");
182                                    }
183                                }
184                                // Complete tool call
185                                else {
186                                    let id = tool_call.id.clone().unwrap_or_default();
187                                    let name = function.name.expect("tool call should have a name");
188                                    let arguments = function.arguments;
189                                    let Ok(arguments) = serde_json::from_str(&arguments) else {
190                                        debug!("Couldn't serialize '{arguments}' as JSON");
191                                        continue;
192                                    };
193
194                                    yield Ok(streaming::RawStreamingChoice::ToolCall {
195                                        id,
196                                        name,
197                                        arguments,
198                                        call_id: None,
199                                    });
200                                }
201                            }
202                        }
203
204                        // Message content
205                        if let Some(content) = &choice.delta.content {
206                            text_content += content;
207                            yield Ok(streaming::RawStreamingChoice::Message(content.clone()))
208                        }
209                    }
210
211                    // Usage updates
212                    if let Some(usage) = data.usage {
213                        final_usage = usage.clone();
214                    }
215                }
216                Err(reqwest_eventsource::Error::StreamEnded) => {
217                    break;
218                }
219                Err(error) => {
220                    tracing::error!(?error, "SSE error");
221                    yield Err(CompletionError::ResponseError(error.to_string()));
222                    break;
223                }
224            }
225        }
226
227        // Ensure event source is closed when stream ends
228        event_source.close();
229
230        let mut vec_toolcalls = vec![];
231
232        // Flush any tool calls that weren’t fully yielded
233        for (_, (id, name, arguments)) in tool_calls {
234            let Ok(arguments) = serde_json::from_str::<serde_json::Value>(&arguments) else {
235                continue;
236            };
237
238            vec_toolcalls.push(super::ToolCall {
239                r#type: super::ToolType::Function,
240                id: id.clone(),
241                function: super::Function {
242                    name: name.clone(), arguments: arguments.clone()
243                },
244            });
245
246            yield Ok(RawStreamingChoice::ToolCall {
247                id,
248                name,
249                arguments,
250                call_id: None,
251            });
252        }
253
254        let message_output = super::Message::Assistant {
255            content: vec![super::AssistantContent::Text { text: text_content }],
256            refusal: None,
257            audio: None,
258            name: None,
259            tool_calls: vec_toolcalls
260        };
261
262        span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens);
263        span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens);
264        span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing"));
265
266        yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
267            usage: final_usage.clone()
268        }));
269    }.instrument(span);
270
271    Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
272        stream,
273    )))
274}