1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures_util::Stream;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9use super::partial_json::parse_optional_json;
10use super::sse::SseStream;
11use crate::error::Result;
12use crate::json_payload::JsonPayload;
13use crate::resources::{
14 ChatCompletion, ChatCompletionChoiceLogprobs, ChatCompletionChunk, ChatCompletionChunkDelta,
15 ChatCompletionMessage, ChatCompletionTokenLogprob, ChatCompletionToolCall,
16};
17use crate::response_meta::ResponseMeta;
18
/// Stream of chat completion chunks that also folds every yielded chunk into
/// a running [`ChatCompletion`] snapshot.
///
/// Chunks are applied to the accumulator as they pass through `poll_next`, so
/// `snapshot()` reflects exactly the chunks that have been yielded so far.
#[derive(Debug)]
pub struct ChatCompletionStream {
    /// Underlying `SseStream` that produces the chunks.
    inner: SseStream<ChatCompletionChunk>,
    /// Running aggregate of every successfully yielded chunk.
    accumulator: ChatCompletionAccumulator,
}
25
26impl ChatCompletionStream {
27 pub fn new(inner: SseStream<ChatCompletionChunk>) -> Self {
29 Self {
30 inner,
31 accumulator: ChatCompletionAccumulator::default(),
32 }
33 }
34
35 pub fn snapshot(&self) -> Option<ChatCompletion> {
37 self.accumulator.snapshot()
38 }
39
40 pub async fn into_final_response(mut self) -> Result<Option<ChatCompletion>> {
42 while let Some(chunk) = futures_util::StreamExt::next(&mut self).await {
43 chunk?;
44 }
45 Ok(self.snapshot())
46 }
47
48 pub fn meta(&self) -> &ResponseMeta {
50 self.inner.meta()
51 }
52
53 pub async fn final_chat_completion(self) -> Result<Option<ChatCompletion>> {
55 self.into_final_response().await
56 }
57
58 pub async fn final_chat_completion_checked(self) -> Result<Option<ChatCompletion>> {
60 let response = self.into_final_response().await?;
61 if let Some(response) = &response {
62 response.ensure_not_truncated()?;
63 }
64 Ok(response)
65 }
66
67 pub async fn final_message(self) -> Result<Option<ChatCompletionMessage>> {
69 Ok(self.into_final_response().await?.and_then(|response| {
70 response
71 .choices
72 .into_iter()
73 .next()
74 .map(|choice| choice.message)
75 }))
76 }
77
78 pub async fn final_content(self) -> Result<Option<String>> {
80 Ok(self
81 .final_message()
82 .await?
83 .and_then(|message| message.content))
84 }
85
86 pub async fn final_tool_calls(self) -> Result<Option<Vec<ChatCompletionToolCall>>> {
88 Ok(self
89 .final_message()
90 .await?
91 .map(|message| message.tool_calls)
92 .filter(|tool_calls| !tool_calls.is_empty()))
93 }
94
95 pub fn events(self) -> ChatCompletionEventStream {
97 ChatCompletionEventStream::new(self)
98 }
99}
100
101impl Stream for ChatCompletionStream {
102 type Item = Result<ChatCompletionChunk>;
103
104 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
105 let this = self.get_mut();
106 match Pin::new(&mut this.inner).poll_next(cx) {
107 Poll::Ready(Some(Ok(chunk))) => {
108 this.accumulator.apply(&chunk);
109 Poll::Ready(Some(Ok(chunk)))
110 }
111 other => other,
112 }
113 }
114}
115
/// Payload of a content delta event: the new fragment plus the content
/// accumulated so far for one choice.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatContentSnapshotEvent {
    /// Index of the choice the fragment belongs to.
    pub choice_index: u32,
    /// Content fragment carried by the current chunk.
    pub delta: String,
    /// Content accumulated for the choice so far, including `delta`.
    pub snapshot: String,
    /// Result of running `parse_optional_json` over `snapshot`, when it
    /// produces a value.
    pub parsed: Option<JsonPayload>,
}
128
/// Payload of the content-done event emitted once a choice has a finish
/// reason and non-empty accumulated content.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatContentDoneEvent {
    /// Index of the finished choice.
    pub choice_index: u32,
    /// Full content accumulated for the choice.
    pub content: String,
    /// Result of running `parse_optional_json` over `content`, when it
    /// produces a value.
    pub parsed: Option<JsonPayload>,
}
139
/// Payload of a refusal delta event: the new fragment plus the refusal text
/// accumulated so far for one choice.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ChatRefusalSnapshotEvent {
    /// Index of the choice the fragment belongs to.
    pub choice_index: u32,
    /// Refusal fragment carried by the current chunk.
    pub delta: String,
    /// Refusal text accumulated for the choice so far, including `delta`.
    pub snapshot: String,
}
150
/// Payload of the refusal-done event emitted once a choice has a finish
/// reason and non-empty accumulated refusal text.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ChatRefusalDoneEvent {
    /// Index of the finished choice.
    pub choice_index: u32,
    /// Full refusal text accumulated for the choice.
    pub refusal: String,
}
159
/// Payload of a tool-call arguments delta event: the new arguments fragment
/// plus the accumulated state of that tool call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatToolArgumentsSnapshotEvent {
    /// Index of the choice the tool call belongs to.
    pub choice_index: u32,
    /// Index of the tool call as reported by the chunk (0 when absent).
    pub tool_call_index: u32,
    /// Function name taken from the accumulated tool call.
    pub name: String,
    /// Arguments fragment carried by the current chunk.
    pub arguments_delta: String,
    /// Arguments string accumulated so far for the tool call.
    pub arguments: String,
    /// Result of running `parse_optional_json` over `arguments`, when it
    /// produces a value.
    pub parsed_arguments: Option<JsonPayload>,
}
176
/// Payload of the event emitted when a tool call is considered complete:
/// either a different tool call index started streaming or the choice
/// finished.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatToolArgumentsDoneEvent {
    /// Index of the choice the tool call belongs to.
    pub choice_index: u32,
    /// Index of the completed tool call.
    pub tool_call_index: u32,
    /// Function name taken from the accumulated tool call.
    pub name: String,
    /// Full arguments string accumulated for the tool call.
    pub arguments: String,
    /// Result of running `parse_optional_json` over `arguments`, when it
    /// produces a value.
    pub parsed_arguments: Option<JsonPayload>,
}
191
/// Payload of a logprobs delta event, used for both the content and the
/// refusal logprob lists.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatLogProbsSnapshotEvent {
    /// Index of the choice the logprobs belong to.
    pub choice_index: u32,
    /// Logprob entries carried by the current chunk.
    pub values: Vec<ChatCompletionTokenLogprob>,
    /// Logprob entries accumulated so far for the choice (empty when the
    /// snapshot exposes none for this field).
    pub snapshot: Vec<ChatCompletionTokenLogprob>,
}
202
/// Payload of a logprobs-done event, used for both the content and the
/// refusal logprob lists once a choice has a finish reason.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatLogProbsDoneEvent {
    /// Index of the finished choice.
    pub choice_index: u32,
    /// All logprob entries accumulated for the choice.
    pub values: Vec<ChatCompletionTokenLogprob>,
}
211
/// Events derived from a chat completion chunk stream.
///
/// For every chunk a `Chunk` event is queued first, followed by the delta
/// events derived from it and, once a choice carries a finish reason, the
/// matching done events.
#[derive(Debug, Clone, Serialize, Deserialize)]
// NOTE(review): the `Chunk` variant embeds a full chunk and snapshot, which is
// presumably what trips `large_enum_variant`; boxing it would change the shape
// of this public enum — confirm that is why the lint is allowed.
#[allow(clippy::large_enum_variant)]
pub enum ChatCompletionRuntimeEvent {
    /// Raw chunk together with the snapshot accumulated after applying it.
    Chunk {
        /// The chunk exactly as received from the stream.
        chunk: ChatCompletionChunk,
        /// The completion accumulated so far, including `chunk`.
        snapshot: ChatCompletion,
    },
    /// A content fragment arrived for a choice.
    ContentDelta(ChatContentSnapshotEvent),
    /// The content of a choice is complete.
    ContentDone(ChatContentDoneEvent),
    /// A refusal fragment arrived for a choice.
    RefusalDelta(ChatRefusalSnapshotEvent),
    /// The refusal text of a choice is complete.
    RefusalDone(ChatRefusalDoneEvent),
    /// An arguments fragment arrived for a tool call.
    ToolCallArgumentsDelta(ChatToolArgumentsSnapshotEvent),
    /// The arguments of a tool call are complete.
    ToolCallArgumentsDone(ChatToolArgumentsDoneEvent),
    /// Content logprob entries arrived for a choice.
    LogProbsContentDelta(ChatLogProbsSnapshotEvent),
    /// The content logprobs of a choice are complete.
    LogProbsContentDone(ChatLogProbsDoneEvent),
    /// Refusal logprob entries arrived for a choice.
    LogProbsRefusalDelta(ChatLogProbsSnapshotEvent),
    /// The refusal logprobs of a choice are complete.
    LogProbsRefusalDone(ChatLogProbsDoneEvent),
}
244
/// Per-choice bookkeeping that keeps each "done" event from being emitted
/// more than once and tracks which tool call is currently streaming.
#[derive(Debug, Default, Clone)]
struct ChatChoiceEventState {
    /// Set once `ContentDone` has been emitted for the choice.
    content_done: bool,
    /// Set once `RefusalDone` has been emitted for the choice.
    refusal_done: bool,
    /// Set once `LogProbsContentDone` has been emitted for the choice.
    logprobs_content_done: bool,
    /// Set once `LogProbsRefusalDone` has been emitted for the choice.
    logprobs_refusal_done: bool,
    /// Index of the tool call whose deltas were seen most recently and that
    /// has not been closed yet.
    current_tool_call_index: Option<u32>,
    /// Tool call indices for which `ToolCallArgumentsDone` was already emitted.
    done_tool_calls: HashSet<u32>,
}
254
/// Stream adapter that turns chat completion chunks into
/// [`ChatCompletionRuntimeEvent`]s.
#[derive(Debug)]
pub struct ChatCompletionEventStream {
    /// Chunk stream being adapted; it also owns the running snapshot.
    inner: ChatCompletionStream,
    /// Events derived from already received chunks but not yet yielded.
    queue: VecDeque<ChatCompletionRuntimeEvent>,
    /// Event bookkeeping keyed by choice index.
    choice_states: HashMap<u32, ChatChoiceEventState>,
}
262
263impl ChatCompletionEventStream {
264 fn new(inner: ChatCompletionStream) -> Self {
265 Self {
266 inner,
267 queue: VecDeque::new(),
268 choice_states: HashMap::new(),
269 }
270 }
271
272 pub fn snapshot(&self) -> Option<ChatCompletion> {
274 self.inner.snapshot()
275 }
276
277 pub fn meta(&self) -> &ResponseMeta {
279 self.inner.meta()
280 }
281
282 pub async fn final_chat_completion(mut self) -> Result<Option<ChatCompletion>> {
284 while let Some(event) = futures_util::StreamExt::next(&mut self).await {
285 event?;
286 }
287 Ok(self.snapshot())
288 }
289
290 pub async fn final_chat_completion_checked(self) -> Result<Option<ChatCompletion>> {
292 let response = self.final_chat_completion().await?;
293 if let Some(response) = &response {
294 response.ensure_not_truncated()?;
295 }
296 Ok(response)
297 }
298
299 pub async fn final_message(self) -> Result<Option<ChatCompletionMessage>> {
301 Ok(self.final_chat_completion().await?.and_then(|response| {
302 response
303 .choices
304 .into_iter()
305 .next()
306 .map(|choice| choice.message)
307 }))
308 }
309
310 pub async fn final_content(self) -> Result<Option<String>> {
312 Ok(self
313 .final_message()
314 .await?
315 .and_then(|message| message.content))
316 }
317
318 pub async fn final_tool_calls(self) -> Result<Option<Vec<ChatCompletionToolCall>>> {
320 Ok(self
321 .final_message()
322 .await?
323 .map(|message| message.tool_calls)
324 .filter(|tool_calls| !tool_calls.is_empty()))
325 }
326
327 fn enqueue_events(&mut self, chunk: &ChatCompletionChunk, snapshot: &ChatCompletion) {
328 self.queue.push_back(ChatCompletionRuntimeEvent::Chunk {
329 chunk: chunk.clone(),
330 snapshot: snapshot.clone(),
331 });
332
333 for choice in &chunk.choices {
334 let Some(snapshot_choice) = snapshot
335 .choices
336 .iter()
337 .find(|item| item.index == choice.index)
338 else {
339 continue;
340 };
341 let state = self
342 .choice_states
343 .get(&choice.index)
344 .cloned()
345 .unwrap_or_default();
346 let (events, state) = derive_chat_choice_events(choice, snapshot_choice, state);
347 self.choice_states.insert(choice.index, state);
348 self.queue.extend(events);
349 }
350 }
351}
352
353impl Stream for ChatCompletionEventStream {
354 type Item = Result<ChatCompletionRuntimeEvent>;
355
356 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
357 let this = self.get_mut();
358 if let Some(event) = this.queue.pop_front() {
359 return Poll::Ready(Some(Ok(event)));
360 }
361
362 match Pin::new(&mut this.inner).poll_next(cx) {
363 Poll::Ready(Some(Ok(chunk))) => {
364 if let Some(snapshot) = this.inner.snapshot() {
365 this.enqueue_events(&chunk, &snapshot);
366 }
367 Poll::Ready(this.queue.pop_front().map(Ok))
368 }
369 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
370 Poll::Ready(None) => Poll::Ready(None),
371 Poll::Pending => Poll::Pending,
372 }
373 }
374}
375
/// Running aggregate of streamed chunks from which a [`ChatCompletion`]
/// snapshot can be built.
#[derive(Debug, Default, Clone)]
struct ChatCompletionAccumulator {
    /// Id of the most recently applied chunk; `None` until a chunk is applied.
    id: Option<String>,
    /// Model name of the most recently applied chunk.
    model: Option<String>,
    /// Most recent `created` value reported by a chunk, if any chunk had one.
    created: Option<i64>,
    /// Object name reported by the latest chunk, or `"chat.completion"` when
    /// that chunk left it empty.
    object: String,
    /// Accumulated per-choice state keyed by choice index.
    choices: HashMap<u32, AccumulatedChoice>,
}
384
/// Accumulated state of a single streamed choice.
#[derive(Debug, Default, Clone)]
struct AccumulatedChoice {
    /// Latest role reported by a delta, if any.
    role: Option<String>,
    /// Concatenation of all content fragments.
    content: String,
    /// Concatenation of all refusal fragments.
    refusal: String,
    /// Concatenation of all reasoning-content fragments.
    reasoning_content: String,
    /// Latest finish reason reported for the choice, if any.
    finish_reason: Option<String>,
    /// Accumulated tool calls keyed by the index reported in the deltas
    /// (a missing index is treated as 0).
    tool_calls: HashMap<u32, ChatCompletionToolCall>,
    /// Logprobs merged from every chunk that carried them.
    logprobs: Option<ChatCompletionChoiceLogprobs>,
}
395
396impl ChatCompletionAccumulator {
397 fn apply(&mut self, chunk: &ChatCompletionChunk) {
398 self.id = Some(chunk.id.clone());
399 self.model = Some(chunk.model.clone());
400 self.created = chunk.created.or(self.created);
401 self.object = if chunk.object.is_empty() {
402 "chat.completion".into()
403 } else {
404 chunk.object.clone()
405 };
406
407 for choice in &chunk.choices {
408 let state = self.choices.entry(choice.index).or_default();
409 state.finish_reason = choice.finish_reason.clone().or(state.finish_reason.clone());
410 if let Some(logprobs) = &choice.logprobs {
411 merge_logprobs(&mut state.logprobs, logprobs);
412 }
413 apply_delta(state, &choice.delta);
414 }
415 }
416
417 fn snapshot(&self) -> Option<ChatCompletion> {
418 let id = self.id.clone()?;
419 let model = self.model.clone().unwrap_or_default();
420 let mut choices: Vec<_> = self
421 .choices
422 .iter()
423 .map(|(index, choice)| {
424 let mut extra = BTreeMap::new();
425 if !choice.reasoning_content.is_empty() {
426 extra.insert(
427 "reasoning_content".into(),
428 Value::String(choice.reasoning_content.clone()),
429 );
430 }
431
432 crate::resources::ChatCompletionChoice {
433 index: *index,
434 finish_reason: choice.finish_reason.clone(),
435 message: ChatCompletionMessage {
436 role: choice.role.clone().unwrap_or_else(|| "assistant".into()),
437 content: (!choice.content.is_empty()).then(|| choice.content.clone()),
438 name: None,
439 tool_call_id: None,
440 tool_calls: {
441 let mut tool_calls = choice
442 .tool_calls
443 .iter()
444 .map(|(tool_call_index, tool_call)| {
445 (*tool_call_index, tool_call.clone())
446 })
447 .collect::<Vec<_>>();
448 tool_calls.sort_by_key(|(tool_call_index, _)| *tool_call_index);
449 tool_calls
450 .into_iter()
451 .map(|(_, tool_call)| tool_call)
452 .collect()
453 },
454 refusal: (!choice.refusal.is_empty()).then(|| choice.refusal.clone()),
455 reasoning_content: (!choice.reasoning_content.is_empty())
456 .then(|| choice.reasoning_content.clone()),
457 reasoning_details: Vec::new(),
458 extra,
459 },
460 logprobs: choice.logprobs.clone(),
461 extra: BTreeMap::new(),
462 }
463 })
464 .collect();
465 choices.sort_by_key(|choice| choice.index);
466
467 Some(ChatCompletion {
468 id,
469 object: self.object.clone(),
470 created: self.created,
471 model,
472 choices,
473 usage: None,
474 extra: BTreeMap::new(),
475 })
476 }
477}
478
479fn apply_delta(state: &mut AccumulatedChoice, delta: &ChatCompletionChunkDelta) {
480 if let Some(role) = &delta.role {
481 state.role = Some(role.clone());
482 }
483 if let Some(content) = &delta.content {
484 state.content.push_str(content);
485 }
486 if let Some(refusal) = &delta.refusal {
487 state.refusal.push_str(refusal);
488 }
489 if let Some(reasoning_content) = &delta.reasoning_content {
490 state.reasoning_content.push_str(reasoning_content);
491 }
492
493 for tool_call in &delta.tool_calls {
494 let index = tool_call.index.unwrap_or_default();
495 let entry = state
496 .tool_calls
497 .entry(index)
498 .or_insert_with(|| ChatCompletionToolCall {
499 id: tool_call.id.clone().unwrap_or_default(),
500 call_type: tool_call
501 .call_type
502 .clone()
503 .unwrap_or_else(|| "function".into()),
504 function: crate::resources::ChatCompletionFunctionCall {
505 name: tool_call
506 .function
507 .as_ref()
508 .and_then(|function| function.name.clone())
509 .unwrap_or_default(),
510 arguments: String::new(),
511 },
512 extra: BTreeMap::new(),
513 });
514
515 if let Some(id) = &tool_call.id {
516 entry.id = id.clone();
517 }
518 if let Some(call_type) = &tool_call.call_type {
519 entry.call_type = call_type.clone();
520 }
521 if let Some(function) = &tool_call.function {
522 if let Some(name) = &function.name {
523 entry.function.name = name.clone();
524 }
525 if let Some(arguments) = &function.arguments {
526 entry.function.arguments.push_str(arguments);
527 }
528 }
529 }
530}
531
532fn merge_logprobs(
533 target: &mut Option<ChatCompletionChoiceLogprobs>,
534 incoming: &ChatCompletionChoiceLogprobs,
535) {
536 let target_logprobs = target.get_or_insert_with(ChatCompletionChoiceLogprobs::default);
537 target_logprobs
538 .content
539 .extend(incoming.content.iter().cloned());
540 target_logprobs
541 .refusal
542 .extend(incoming.refusal.iter().cloned());
543 for (key, value) in &incoming.extra {
544 target_logprobs.extra.insert(key.clone(), value.clone());
545 }
546}
547
548fn logprobs_values(
549 logprobs: Option<&ChatCompletionChoiceLogprobs>,
550 field_name: &str,
551) -> Option<Vec<ChatCompletionTokenLogprob>> {
552 logprobs?.values(field_name).map(<[_]>::to_vec)
553}
554
555fn derive_chat_choice_events(
556 choice: &crate::resources::ChatCompletionChunkChoice,
557 snapshot_choice: &crate::resources::ChatCompletionChoice,
558 mut state: ChatChoiceEventState,
559) -> (Vec<ChatCompletionRuntimeEvent>, ChatChoiceEventState) {
560 let mut events = Vec::new();
561
562 if let Some(delta) = &choice.delta.content
563 && let Some(snapshot_content) = snapshot_choice.message.content.clone()
564 {
565 events.push(ChatCompletionRuntimeEvent::ContentDelta(
566 ChatContentSnapshotEvent {
567 choice_index: choice.index,
568 delta: delta.clone(),
569 parsed: parse_optional_json(&snapshot_content).map(JsonPayload::from),
570 snapshot: snapshot_content,
571 },
572 ));
573 }
574
575 if let Some(delta) = &choice.delta.refusal
576 && let Some(snapshot_refusal) = snapshot_choice.message.refusal.clone()
577 {
578 events.push(ChatCompletionRuntimeEvent::RefusalDelta(
579 ChatRefusalSnapshotEvent {
580 choice_index: choice.index,
581 delta: delta.clone(),
582 snapshot: snapshot_refusal,
583 },
584 ));
585 }
586
587 if let Some(values) = logprobs_values(choice.logprobs.as_ref(), "content") {
588 events.push(ChatCompletionRuntimeEvent::LogProbsContentDelta(
589 ChatLogProbsSnapshotEvent {
590 choice_index: choice.index,
591 snapshot: logprobs_values(snapshot_choice.logprobs.as_ref(), "content")
592 .unwrap_or_default(),
593 values,
594 },
595 ));
596 }
597
598 if let Some(values) = logprobs_values(choice.logprobs.as_ref(), "refusal") {
599 events.push(ChatCompletionRuntimeEvent::LogProbsRefusalDelta(
600 ChatLogProbsSnapshotEvent {
601 choice_index: choice.index,
602 snapshot: logprobs_values(snapshot_choice.logprobs.as_ref(), "refusal")
603 .unwrap_or_default(),
604 values,
605 },
606 ));
607 }
608
609 for tool_call in &choice.delta.tool_calls {
610 let tool_call_index = tool_call.index.unwrap_or_default();
611 if state.current_tool_call_index != Some(tool_call_index) {
612 if let Some(previous_index) = state.current_tool_call_index.take() {
613 emit_chat_tool_call_done(
614 &mut events,
615 choice.index,
616 previous_index,
617 snapshot_choice,
618 &mut state,
619 );
620 }
621 state.current_tool_call_index = Some(tool_call_index);
622 }
623
624 if let Some(arguments_delta) = tool_call
625 .function
626 .as_ref()
627 .and_then(|function| function.arguments.clone())
628 && let Some(snapshot_tool_call) = snapshot_choice
629 .message
630 .tool_calls
631 .get(tool_call_index as usize)
632 {
633 events.push(ChatCompletionRuntimeEvent::ToolCallArgumentsDelta(
634 ChatToolArgumentsSnapshotEvent {
635 choice_index: choice.index,
636 tool_call_index,
637 name: snapshot_tool_call.function.name.clone(),
638 parsed_arguments: parse_optional_json(&snapshot_tool_call.function.arguments)
639 .map(JsonPayload::from),
640 arguments_delta,
641 arguments: snapshot_tool_call.function.arguments.clone(),
642 },
643 ));
644 }
645 }
646
647 if choice.finish_reason.is_some() || snapshot_choice.finish_reason.is_some() {
648 emit_chat_choice_done_events(&mut events, choice.index, snapshot_choice, &mut state);
649 }
650
651 (events, state)
652}
653
654fn emit_chat_choice_done_events(
655 events: &mut Vec<ChatCompletionRuntimeEvent>,
656 choice_index: u32,
657 snapshot_choice: &crate::resources::ChatCompletionChoice,
658 state: &mut ChatChoiceEventState,
659) {
660 if !state.content_done
661 && let Some(content) = snapshot_choice.message.content.clone()
662 {
663 events.push(ChatCompletionRuntimeEvent::ContentDone(
664 ChatContentDoneEvent {
665 choice_index,
666 parsed: parse_optional_json(&content).map(JsonPayload::from),
667 content,
668 },
669 ));
670 state.content_done = true;
671 }
672
673 if !state.refusal_done
674 && let Some(refusal) = snapshot_choice.message.refusal.clone()
675 {
676 events.push(ChatCompletionRuntimeEvent::RefusalDone(
677 ChatRefusalDoneEvent {
678 choice_index,
679 refusal,
680 },
681 ));
682 state.refusal_done = true;
683 }
684
685 if !state.logprobs_content_done
686 && let Some(values) = logprobs_values(snapshot_choice.logprobs.as_ref(), "content")
687 {
688 events.push(ChatCompletionRuntimeEvent::LogProbsContentDone(
689 ChatLogProbsDoneEvent {
690 choice_index,
691 values,
692 },
693 ));
694 state.logprobs_content_done = true;
695 }
696
697 if !state.logprobs_refusal_done
698 && let Some(values) = logprobs_values(snapshot_choice.logprobs.as_ref(), "refusal")
699 {
700 events.push(ChatCompletionRuntimeEvent::LogProbsRefusalDone(
701 ChatLogProbsDoneEvent {
702 choice_index,
703 values,
704 },
705 ));
706 state.logprobs_refusal_done = true;
707 }
708
709 if let Some(tool_call_index) = state.current_tool_call_index.take() {
710 emit_chat_tool_call_done(
711 events,
712 choice_index,
713 tool_call_index,
714 snapshot_choice,
715 state,
716 );
717 }
718}
719
720fn emit_chat_tool_call_done(
721 events: &mut Vec<ChatCompletionRuntimeEvent>,
722 choice_index: u32,
723 tool_call_index: u32,
724 snapshot_choice: &crate::resources::ChatCompletionChoice,
725 state: &mut ChatChoiceEventState,
726) {
727 if state.done_tool_calls.contains(&tool_call_index) {
728 return;
729 }
730
731 let Some(snapshot_tool_call) = snapshot_choice
732 .message
733 .tool_calls
734 .get(tool_call_index as usize)
735 else {
736 return;
737 };
738
739 events.push(ChatCompletionRuntimeEvent::ToolCallArgumentsDone(
740 ChatToolArgumentsDoneEvent {
741 choice_index,
742 tool_call_index,
743 name: snapshot_tool_call.function.name.clone(),
744 parsed_arguments: parse_optional_json(&snapshot_tool_call.function.arguments)
745 .map(JsonPayload::from),
746 arguments: snapshot_tool_call.function.arguments.clone(),
747 },
748 ));
749 state.done_tool_calls.insert(tool_call_index);
750}