use crate::completion::{CompletionError, GetTokenUsage};
use crate::http_client::HttpClientExt;
use crate::http_client::sse::{Event, GenericEventSource};
use crate::providers::openai::responses_api::{
ReasoningSummary, ResponsesCompletionModel, ResponsesUsage,
};
use crate::streaming;
use crate::streaming::RawStreamingChoice;
use crate::wasm_compat::WasmCompatSend;
use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use tracing::{Level, debug, enabled, info_span};
use tracing_futures::Instrument as _;
use super::{CompletionResponse, Output};
/// One SSE payload from the OpenAI Responses streaming endpoint.
///
/// `untagged`: serde tries the variants in declaration order, so a payload is
/// first matched as a whole-response lifecycle event and only then as a
/// per-item delta. Variant order is therefore part of the wire contract.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum StreamingCompletionChunk {
    /// Whole-response lifecycle event (`response.created`, `response.completed`, …).
    /// Boxed to keep the enum small — `ResponseChunk` embeds a full `CompletionResponse`.
    Response(Box<ResponseChunk>),
    /// Incremental per-output-item event (text deltas, tool-call arguments, …).
    Delta(ItemChunk),
}
/// Final aggregate returned once the SSE stream ends; carries only the token
/// usage captured from the `response.completed` event.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingCompletionResponse {
    // Usage counters reported by the provider for the whole response.
    pub usage: ResponsesUsage,
}
impl GetTokenUsage for StreamingCompletionResponse {
    /// Copies the provider-specific usage counters into the generic
    /// [`crate::completion::Usage`] structure. Always returns `Some`.
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut counters = crate::completion::Usage::new();
        counters.input_tokens = self.usage.input_tokens;
        counters.output_tokens = self.usage.output_tokens;
        counters.total_tokens = self.usage.total_tokens;
        Some(counters)
    }
}
/// A whole-response lifecycle SSE event: the event type plus a full snapshot
/// of the response at that point.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResponseChunk {
    // Wire field is `type`; renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub kind: ResponseChunkKind,
    // Snapshot of the response; for `response.completed` this carries the
    // final id, model and usage read by the streaming loop.
    pub response: CompletionResponse,
    // Monotonic event counter assigned by the provider.
    pub sequence_number: u64,
}
/// The lifecycle event names the Responses API emits for the response as a
/// whole; each variant maps 1:1 to the wire `type` string.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum ResponseChunkKind {
    #[serde(rename = "response.created")]
    ResponseCreated,
    #[serde(rename = "response.in_progress")]
    ResponseInProgress,
    #[serde(rename = "response.completed")]
    ResponseCompleted,
    #[serde(rename = "response.failed")]
    ResponseFailed,
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete,
}
/// A per-output-item streaming event: common addressing fields plus the
/// event-specific payload flattened alongside them.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ItemChunk {
    // Identifier of the output item this event refers to (absent on some events).
    pub item_id: Option<String>,
    // Position of the item within the response's output array.
    pub output_index: u64,
    // The tagged payload shares the same JSON object as the fields above.
    #[serde(flatten)]
    pub data: ItemChunkKind,
}
/// Per-item event payloads, internally tagged by the SSE `type` field.
///
/// Only the subset of the Responses event catalogue that the streaming loop
/// consumes is modeled here; payloads with any other `type` fail to
/// deserialize and are logged-and-skipped upstream.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ItemChunkKind {
    // A new output item (message, function call, reasoning, …) was started.
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded(StreamingItemDoneOutput),
    // An output item is complete; carries the fully-assembled item.
    #[serde(rename = "response.output_item.done")]
    OutputItemDone(StreamingItemDoneOutput),
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded(ContentPartChunk),
    #[serde(rename = "response.content_part.done")]
    ContentPartDone(ContentPartChunk),
    // Incremental assistant text.
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta(DeltaTextChunk),
    #[serde(rename = "response.output_text.done")]
    OutputTextDone(OutputTextChunk),
    // Incremental refusal text.
    #[serde(rename = "response.refusal.delta")]
    RefusalDelta(DeltaTextChunk),
    #[serde(rename = "response.refusal.done")]
    RefusalDone(RefusalTextChunk),
    // Incremental JSON fragment of a function call's arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgsDelta(DeltaTextChunkWithItemId),
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgsDone(ArgsTextChunk),
    #[serde(rename = "response.reasoning_summary_part.added")]
    ReasoningSummaryPartAdded(SummaryPartChunk),
    #[serde(rename = "response.reasoning_summary_part.done")]
    ReasoningSummaryPartDone(SummaryPartChunk),
    // Incremental reasoning-summary text.
    #[serde(rename = "response.reasoning_summary_text.delta")]
    ReasoningSummaryTextDelta(SummaryTextChunk),
    #[serde(rename = "response.reasoning_summary_text.done")]
    ReasoningSummaryTextDone(SummaryTextChunk),
}
/// Payload of `response.output_item.added` / `response.output_item.done`.
/// Despite the name it is used for both events.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StreamingItemDoneOutput {
    pub sequence_number: u64,
    // The output item itself (message, function call, reasoning, …).
    pub item: Output,
}
/// Payload of `response.content_part.added` / `response.content_part.done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContentPartChunk {
    // Index of the content part within its output item.
    pub content_index: u64,
    pub sequence_number: u64,
    pub part: ContentPartChunkPart,
}
/// The content part carried by `response.content_part.*` events.
///
/// Fix: the wire format tags these parts in snake_case (`"output_text"`,
/// `"summary_text"`), but without `rename_all` serde expected the PascalCase
/// variant names, so every content-part event failed to deserialize and was
/// silently skipped by the streaming loop.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartChunkPart {
    /// `"type": "output_text"` — assistant text part.
    OutputText { text: String },
    /// `"type": "summary_text"` — reasoning-summary text part.
    SummaryText { text: String },
}
/// A text fragment for `response.output_text.delta` / `response.refusal.delta`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    // The newly-produced text fragment; concatenating deltas yields the full text.
    pub delta: String,
}
/// Like [`DeltaTextChunk`] but with a mandatory `item_id`, used for
/// `response.function_call_arguments.delta` where the id routes the fragment
/// to the right tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DeltaTextChunkWithItemId {
    pub item_id: String,
    pub content_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}
/// Payload of `response.output_text.done`: the complete, final text.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OutputTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub text: String,
}
/// Payload of `response.refusal.done`: the complete refusal message.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RefusalTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub refusal: String,
}
/// Payload of `response.function_call_arguments.done`: the finished tool-call
/// arguments as parsed JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ArgsTextChunk {
    pub content_index: u64,
    pub sequence_number: u64,
    pub arguments: serde_json::Value,
}
/// Payload of `response.reasoning_summary_part.added` / `…done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryPartChunk {
    // Index of this part within the reasoning summary.
    pub summary_index: u64,
    pub sequence_number: u64,
    pub part: SummaryPartChunkPart,
}
/// A text fragment for `response.reasoning_summary_text.delta` / `…done`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SummaryTextChunk {
    pub summary_index: u64,
    pub sequence_number: u64,
    pub delta: String,
}
/// The part carried by `response.reasoning_summary_part.*` events.
///
/// Fix: the wire format tags this part as `"summary_text"` (snake_case);
/// without `rename_all` serde expected `"SummaryText"`, so these events
/// always failed to deserialize and were silently skipped upstream.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SummaryPartChunkPart {
    /// `"type": "summary_text"`.
    SummaryText { text: String },
}
impl<T> ResponsesCompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + WasmCompatSend + 'static,
{
    /// Sends a streaming completion request to the OpenAI `/responses`
    /// endpoint and adapts the resulting SSE events into rig's generic
    /// streaming interface.
    ///
    /// Behavior visible in this body:
    /// - text and refusal deltas are yielded immediately as `Message` chunks;
    /// - tool calls yield a name `ToolCallDelta` when the item is added and
    ///   argument deltas as they arrive, while the completed call is buffered
    ///   and re-yielded only after the SSE stream ends;
    /// - reasoning summaries are yielded both as `ReasoningDelta` fragments
    ///   and as a final `Reasoning` chunk when the item completes;
    /// - token usage is captured from the `response.completed` event and
    ///   emitted in the trailing `FinalResponse` chunk;
    /// - undecodable SSE payloads are logged at debug level and skipped.
    ///
    /// # Errors
    /// Returns early on request-serialization or HTTP-build failures; SSE
    /// transport errors mid-stream are yielded as
    /// `CompletionError::ProviderError` and terminate the stream.
    pub(crate) async fn stream(
        &self,
        completion_request: crate::completion::CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        // Reuse the non-streaming request builder, then flip the stream flag.
        let mut request = self.create_completion_request(completion_request)?;
        request.stream = Some(true);
        // Guard the pretty-print behind the TRACE check so serialization work
        // is skipped when tracing is off.
        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "OpenAI Responses streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }
        let body = serde_json::to_vec(&request)?;
        let req = self
            .client
            .post("/responses")?
            .body(body)
            .map_err(|e| CompletionError::HttpError(e.into()))?;
        // Open a new OTel-style span only when the caller has no active one;
        // otherwise record onto the caller's span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        span.record("gen_ai.provider.name", "openai");
        span.record("gen_ai.request.model", &self.model);
        let client = self.client.clone();
        let mut event_source = GenericEventSource::new(client, req);
        let stream = stream! {
            let mut final_usage = ResponsesUsage::new();
            // Completed tool calls are buffered here and re-emitted after the
            // SSE loop ends, so they always follow all streamed text.
            let mut tool_calls: Vec<RawStreamingChoice<StreamingCompletionResponse>> = Vec::new();
            // NOTE(review): appended to below but never read — looks like a
            // dead accumulator; confirm before removing.
            let mut combined_text = String::new();
            let span = tracing::Span::current();
            while let Some(event_result) = event_source.next().await {
                match event_result {
                    Ok(Event::Open) => {
                        tracing::trace!("SSE connection opened");
                        tracing::info!("OpenAI stream started");
                        continue;
                    }
                    Ok(Event::Message(evt)) => {
                        // Skip keep-alive / empty frames.
                        if evt.data.trim().is_empty() {
                            continue;
                        }
                        let data = serde_json::from_str::<StreamingCompletionChunk>(&evt.data);
                        // Unknown or unmodeled event shapes are logged and
                        // skipped rather than failing the whole stream.
                        let Ok(data) = data else {
                            let err = data.unwrap_err();
                            debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err);
                            continue;
                        };
                        if let StreamingCompletionChunk::Delta(chunk) = &data {
                            match &chunk.data {
                                // A new function-call item: surface the tool
                                // name immediately so callers can display it
                                // before arguments arrive.
                                ItemChunkKind::OutputItemAdded(message) => {
                                    if let StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } = message {
                                        yield Ok(streaming::RawStreamingChoice::ToolCallDelta {
                                            id: func.id.clone(),
                                            content: streaming::ToolCallDeltaContent::Name(func.name.clone()),
                                        });
                                    }
                                }
                                ItemChunkKind::OutputItemDone(message) => {
                                    match message {
                                        // Buffer the completed tool call; it is
                                        // yielded after the stream ends.
                                        StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => {
                                            tool_calls.push(streaming::RawStreamingChoice::ToolCall(
                                                streaming::RawStreamingToolCall::new(
                                                    func.id.clone(),
                                                    func.name.clone(),
                                                    func.arguments.clone(),
                                                )
                                                .with_call_id(func.call_id.clone())
                                            ));
                                        }
                                        // Completed reasoning item: join all
                                        // summary texts into one chunk.
                                        StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => {
                                            let reasoning = summary
                                                .iter()
                                                .map(|x| {
                                                    let ReasoningSummary::SummaryText { text } = x;
                                                    text.to_owned()
                                                })
                                                .collect::<Vec<String>>()
                                                .join("\n");
                                            yield Ok(streaming::RawStreamingChoice::Reasoning {
                                                id: Some(id.to_string()),
                                                reasoning,
                                                signature: None,
                                            })
                                        }
                                        _ => continue
                                    }
                                }
                                ItemChunkKind::OutputTextDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::ReasoningSummaryTextDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: delta.delta.clone() })
                                }
                                // Refusal text is surfaced as an ordinary
                                // message chunk.
                                ItemChunkKind::RefusalDelta(delta) => {
                                    combined_text.push_str(&delta.delta);
                                    yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone()))
                                }
                                ItemChunkKind::FunctionCallArgsDelta(delta) => {
                                    yield Ok(streaming::RawStreamingChoice::ToolCallDelta { id: delta.item_id.clone(), content: streaming::ToolCallDeltaContent::Delta(delta.delta.clone()) })
                                }
                                _ => { continue }
                            }
                        }
                        // Lifecycle events: only `response.completed` matters
                        // here — it carries the final id/model/usage.
                        if let StreamingCompletionChunk::Response(chunk) = data {
                            if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk {
                                span.record("gen_ai.response.id", response.id);
                                span.record("gen_ai.response.model", response.model);
                                if let Some(usage) = response.usage {
                                    final_usage = usage;
                                }
                            } else {
                                continue;
                            }
                        }
                    }
                    // Normal end of stream: close the source; the next poll
                    // ends the while-let loop.
                    Err(crate::http_client::Error::StreamEnded) => {
                        event_source.close();
                    }
                    // Transport error: surface it to the caller and stop.
                    Err(error) => {
                        tracing::error!(?error, "SSE error");
                        yield Err(CompletionError::ProviderError(error.to_string()));
                        break;
                    }
                }
            }
            event_source.close();
            // Emit the buffered, fully-assembled tool calls after all text.
            for tool_call in &tool_calls {
                yield Ok(tool_call.to_owned())
            }
            span.record("gen_ai.usage.input_tokens", final_usage.input_tokens);
            span.record("gen_ai.usage.output_tokens", final_usage.output_tokens);
            tracing::info!("OpenAI stream finished");
            // Trailing chunk carrying the usage totals.
            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage
            }));
        }.instrument(span);
        Ok(streaming::StreamingCompletionResponse::stream(Box::pin(
            stream,
        )))
    }
}