1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
2use std::fmt;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use async_stream::try_stream;
7use futures_util::{Stream, StreamExt};
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10use serde_json::{Map, Value};
11
12use super::sse::RawSseStream;
13use super::value_helpers::{
14 ensure_array_field, ensure_object, ensure_object_field, ensure_vec_len, merge_object,
15};
16use crate::error::{Result, SerializationError, StreamError};
17use crate::response_meta::ResponseMeta;
18
/// A single decoded event taken from an Assistants SSE stream.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantStreamEvent {
    /// Event name, for example `thread.message.delta` or `error`.
    pub event: String,
    /// JSON payload that accompanied the event.
    pub data: Value,
}
27
28impl AssistantStreamEvent {
29 pub fn is_error(&self) -> bool {
31 self.event == "error"
32 }
33
34 pub fn data_as<T>(&self) -> Result<T>
36 where
37 T: DeserializeOwned,
38 {
39 serde_json::from_value(self.data.clone()).map_err(|error| {
40 SerializationError::new(format!(
41 "Assistants 流事件反序列化失败: event={}, error={error}",
42 self.event
43 ))
44 .into()
45 })
46 }
47}
48
/// Accumulated view of the objects seen so far on an Assistants stream.
///
/// Objects are stored as raw JSON and keyed by their `id`; the `latest_*` fields remember the
/// id carried by the most recently applied event of each kind.
#[derive(Debug, Clone, Default)]
pub struct AssistantStreamSnapshot {
    /// Payload of the last `thread.created` event.
    thread: Option<Value>,
    /// Runs keyed by run id.
    runs: BTreeMap<String, Value>,
    /// Messages keyed by message id, including content merged from deltas.
    messages: BTreeMap<String, Value>,
    /// Run steps keyed by step id, including step details merged from deltas.
    run_steps: BTreeMap<String, Value>,
    /// Id of the run touched by the most recent run event.
    latest_run_id: Option<String>,
    /// Id of the message touched by the most recent message event.
    latest_message_id: Option<String>,
    /// Id of the run step touched by the most recent run-step event.
    latest_run_step_id: Option<String>,
}
60
61impl AssistantStreamSnapshot {
62 pub fn thread_raw(&self) -> Option<&Value> {
64 self.thread.as_ref()
65 }
66
67 pub fn thread<T>(&self) -> Option<T>
69 where
70 T: DeserializeOwned,
71 {
72 self.thread
73 .as_ref()
74 .and_then(|value| serde_json::from_value(value.clone()).ok())
75 }
76
77 pub fn run_raw(&self, run_id: &str) -> Option<&Value> {
79 self.runs.get(run_id)
80 }
81
82 pub fn latest_run_raw(&self) -> Option<&Value> {
84 self.latest_run_id
85 .as_deref()
86 .and_then(|run_id| self.runs.get(run_id))
87 }
88
89 pub fn message_raw(&self, message_id: &str) -> Option<&Value> {
91 self.messages.get(message_id)
92 }
93
94 pub fn latest_message_raw(&self) -> Option<&Value> {
96 self.latest_message_id
97 .as_deref()
98 .and_then(|message_id| self.messages.get(message_id))
99 }
100
101 pub fn run<T>(&self, run_id: &str) -> Option<T>
103 where
104 T: DeserializeOwned,
105 {
106 self.run_raw(run_id)
107 .and_then(|value| serde_json::from_value(value.clone()).ok())
108 }
109
110 pub fn latest_run<T>(&self) -> Option<T>
112 where
113 T: DeserializeOwned,
114 {
115 self.latest_run_raw()
116 .and_then(|value| serde_json::from_value(value.clone()).ok())
117 }
118
119 pub fn message<T>(&self, message_id: &str) -> Option<T>
121 where
122 T: DeserializeOwned,
123 {
124 self.messages
125 .get(message_id)
126 .and_then(|value| serde_json::from_value(value.clone()).ok())
127 }
128
129 pub fn latest_message<T>(&self) -> Option<T>
131 where
132 T: DeserializeOwned,
133 {
134 self.latest_message_id
135 .as_deref()
136 .and_then(|message_id| self.message(message_id))
137 }
138
139 pub fn run_step_raw(&self, step_id: &str) -> Option<&Value> {
141 self.run_steps.get(step_id)
142 }
143
144 pub fn latest_run_step_raw(&self) -> Option<&Value> {
146 self.latest_run_step_id
147 .as_deref()
148 .and_then(|step_id| self.run_steps.get(step_id))
149 }
150
151 pub fn run_step<T>(&self, step_id: &str) -> Option<T>
153 where
154 T: DeserializeOwned,
155 {
156 self.run_steps
157 .get(step_id)
158 .and_then(|value| serde_json::from_value(value.clone()).ok())
159 }
160
161 pub fn latest_run_step<T>(&self) -> Option<T>
163 where
164 T: DeserializeOwned,
165 {
166 self.latest_run_step_id
167 .as_deref()
168 .and_then(|step_id| self.run_step(step_id))
169 }
170
171 fn apply(&mut self, event: &AssistantStreamEvent) {
172 match event.event.as_str() {
173 "thread.created" => {
174 self.thread = Some(event.data.clone());
175 }
176 "thread.run.created"
177 | "thread.run.queued"
178 | "thread.run.in_progress"
179 | "thread.run.requires_action"
180 | "thread.run.completed"
181 | "thread.run.incomplete"
182 | "thread.run.failed"
183 | "thread.run.cancelling"
184 | "thread.run.cancelled"
185 | "thread.run.expired" => {
186 if let Some(id) = event.data.get("id").and_then(Value::as_str) {
187 self.latest_run_id = Some(id.to_owned());
188 self.runs.insert(id.to_owned(), event.data.clone());
189 }
190 }
191 "thread.message.created"
192 | "thread.message.in_progress"
193 | "thread.message.completed"
194 | "thread.message.incomplete" => {
195 if let Some(id) = event.data.get("id").and_then(Value::as_str) {
196 self.latest_message_id = Some(id.to_owned());
197 self.messages.insert(id.to_owned(), event.data.clone());
198 }
199 }
200 "thread.run.step.created"
201 | "thread.run.step.in_progress"
202 | "thread.run.step.completed"
203 | "thread.run.step.failed"
204 | "thread.run.step.cancelled"
205 | "thread.run.step.expired" => {
206 if let Some(id) = event.data.get("id").and_then(Value::as_str) {
207 self.latest_run_step_id = Some(id.to_owned());
208 self.run_steps.insert(id.to_owned(), event.data.clone());
209 }
210 }
211 "thread.message.delta" => {
212 if let Some(id) = event.data.get("id").and_then(Value::as_str) {
213 self.latest_message_id = Some(id.to_owned());
214 let entry = self
215 .messages
216 .entry(id.to_owned())
217 .or_insert_with(|| empty_assistant_snapshot(id, "thread.message"));
218 apply_message_delta(entry, &event.data);
219 }
220 }
221 "thread.run.step.delta" => {
222 if let Some(id) = event.data.get("id").and_then(Value::as_str) {
223 self.latest_run_step_id = Some(id.to_owned());
224 let entry = self
225 .run_steps
226 .entry(id.to_owned())
227 .or_insert_with(|| empty_assistant_snapshot(id, "thread.run.step"));
228 apply_run_step_delta(entry, &event.data);
229 }
230 }
231 _ => {}
232 }
233 }
234}
235
/// Stream of [`AssistantStreamEvent`]s that also keeps an [`AssistantStreamSnapshot`] up to
/// date as events are polled.
pub struct AssistantStream {
    /// Boxed event stream built from the raw SSE stream.
    inner: Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent>> + Send>>,
    /// Response metadata cloned from the raw SSE stream at construction time.
    meta: ResponseMeta,
    /// Snapshot updated with every successfully polled event.
    snapshot: AssistantStreamSnapshot,
}
242
243impl AssistantStream {
244 #[allow(tail_expr_drop_order)]
246 pub fn new(raw: RawSseStream) -> Self {
247 let meta = raw.meta().clone();
248 let stream = try_stream! {
249 let mut raw = raw;
250 while let Some(event) = raw.next().await {
251 let event = event?;
252 if event.data == "[DONE]" {
253 break;
254 }
255
256 let data = serde_json::from_str::<Value>(&event.data).map_err(|error| {
257 StreamError::new(format!(
258 "解析 Assistants SSE 事件失败: event={:?}, error={error}, payload={}",
259 event.event,
260 event.data
261 ))
262 })?;
263 let event_name = event
264 .event
265 .or_else(|| data.get("event").and_then(Value::as_str).map(str::to_owned))
266 .or_else(|| data.get("type").and_then(Value::as_str).map(str::to_owned))
267 .unwrap_or_else(|| "message".into());
268
269 yield AssistantStreamEvent {
270 event: event_name,
271 data,
272 };
273 }
274 };
275
276 Self {
277 inner: Box::pin(stream),
278 meta,
279 snapshot: AssistantStreamSnapshot::default(),
280 }
281 }
282
283 pub fn snapshot(&self) -> &AssistantStreamSnapshot {
285 &self.snapshot
286 }
287
288 pub fn meta(&self) -> &ResponseMeta {
290 &self.meta
291 }
292
293 pub async fn final_snapshot(mut self) -> Result<AssistantStreamSnapshot> {
295 while let Some(event) = self.next().await {
296 event?;
297 }
298 Ok(self.snapshot)
299 }
300
301 pub fn events(self) -> AssistantEventStream {
303 AssistantEventStream::new(self)
304 }
305}
306
307impl fmt::Debug for AssistantStream {
308 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309 f.debug_struct("AssistantStream")
310 .field("meta", &self.meta)
311 .field("snapshot", &self.snapshot)
312 .finish()
313 }
314}
315
316impl Stream for AssistantStream {
317 type Item = Result<AssistantStreamEvent>;
318
319 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320 let this = self.get_mut();
321 match this.inner.as_mut().poll_next(cx) {
322 Poll::Ready(Some(Ok(event))) => {
323 this.snapshot.apply(&event);
324 Poll::Ready(Some(Ok(event)))
325 }
326 other => other,
327 }
328 }
329}
330
/// Emitted for a `thread.message.created` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessageCreatedEvent {
    /// Payload of the creation event.
    pub message: Value,
}
337
/// Emitted for a `thread.message.delta` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessageDeltaEvent {
    /// The payload's `delta` field, or JSON `null` when the field is missing.
    pub delta: Value,
    /// Accumulated message after this delta; falls back to the raw event payload when no
    /// stored message could be looked up.
    pub snapshot: Value,
}
346
/// Emitted for `thread.message.completed` and `thread.message.incomplete` events.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessageDoneEvent {
    /// Stored message for the payload's id, or the raw event payload when none is stored.
    pub message: Value,
}
353
/// Emitted for a `thread.run.step.created` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantRunStepCreatedEvent {
    /// Payload of the creation event.
    pub run_step: Value,
}
360
/// Emitted for a `thread.run.step.delta` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantRunStepDeltaEvent {
    /// The payload's `delta` field, or JSON `null` when the field is missing.
    pub delta: Value,
    /// Accumulated run step after this delta; falls back to the raw event payload when no
    /// stored step could be looked up.
    pub snapshot: Value,
}
369
/// Emitted when a run step reaches `completed`, `failed`, `cancelled` or `expired`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantRunStepDoneEvent {
    /// Stored run step for the payload's id, or the raw event payload when none is stored.
    pub run_step: Value,
}
376
/// Emitted the first time a delta touches a given tool-call index of a run step.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantToolCallCreatedEvent {
    /// Id of the owning run step, when the payload carried one.
    pub run_step_id: Option<String>,
    /// Position of the tool call within the step's `tool_calls` array.
    pub tool_call_index: usize,
    /// Accumulated tool call at that index.
    pub tool_call: Value,
}
387
/// Emitted for every tool-call fragment inside a `thread.run.step.delta` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantToolCallDeltaEvent {
    /// Id of the owning run step, when the payload carried one.
    pub run_step_id: Option<String>,
    /// Position of the tool call within the step's `tool_calls` array.
    pub tool_call_index: usize,
    /// The tool-call fragment exactly as received.
    pub delta: Value,
    /// Accumulated tool call at that index after the fragment was merged.
    pub snapshot: Value,
}
400
/// Emitted once per tool call when the owning run step finishes.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantToolCallDoneEvent {
    /// Id of the owning run step, when the step object carried one.
    pub run_step_id: Option<String>,
    /// Position of the tool call within the step's `tool_calls` array.
    pub tool_call_index: usize,
    /// The tool call as found in the finished run step.
    pub tool_call: Value,
}
411
/// Emitted when a text content part is first observed, either in a created message or in
/// the first delta that touches its index.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantTextCreatedEvent {
    /// Id of the owning message, when known.
    pub message_id: Option<String>,
    /// Position of the part within the message's `content` array.
    pub content_index: usize,
    /// The content part as it looked when first observed.
    pub text: Value,
}
422
/// Emitted for every text fragment inside a `thread.message.delta` event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantTextDeltaEvent {
    /// Id of the owning message, when known.
    pub message_id: Option<String>,
    /// Position of the part within the message's `content` array.
    pub content_index: usize,
    /// The content fragment exactly as received.
    pub delta: Value,
    /// Accumulated content part at that index after the fragment was merged.
    pub snapshot: Value,
}
435
/// Emitted for each text content part of a finished message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantTextDoneEvent {
    /// Id of the owning message, when known.
    pub message_id: Option<String>,
    /// Position of the part within the message's `content` array.
    pub content_index: usize,
    /// The finished text content part.
    pub text: Value,
    /// The whole finished message that contains the part.
    pub message: Value,
}
448
/// Emitted for each `image_file` content part of a finished message.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantImageFileDoneEvent {
    /// Id of the owning message, when known.
    pub message_id: Option<String>,
    /// Position of the part within the message's `content` array.
    pub content_index: usize,
    /// The `image_file` content part.
    pub image_file: Value,
    /// The whole finished message that contains the part.
    pub message: Value,
}
461
/// Events yielded by [`AssistantEventStream`].
///
/// Every raw stream event is surfaced as [`AssistantRuntimeEvent::Event`] first, followed by
/// whatever derived events it gives rise to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssistantRuntimeEvent {
    /// The raw stream event, unmodified.
    Event(AssistantStreamEvent),
    /// A message was created.
    MessageCreated(AssistantMessageCreatedEvent),
    /// A message received a delta.
    MessageDelta(AssistantMessageDeltaEvent),
    /// A message completed or was left incomplete.
    MessageDone(AssistantMessageDoneEvent),
    /// A run step was created.
    RunStepCreated(AssistantRunStepCreatedEvent),
    /// A run step received a delta.
    RunStepDelta(AssistantRunStepDeltaEvent),
    /// A run step completed, failed, was cancelled or expired.
    RunStepDone(AssistantRunStepDoneEvent),
    /// A tool call was observed for the first time.
    ToolCallCreated(AssistantToolCallCreatedEvent),
    /// A tool call received a delta.
    ToolCallDelta(AssistantToolCallDeltaEvent),
    /// A tool call belongs to a run step that finished.
    ToolCallDone(AssistantToolCallDoneEvent),
    /// A text content part was observed for the first time.
    TextCreated(AssistantTextCreatedEvent),
    /// A text content part received a delta.
    TextDelta(AssistantTextDeltaEvent),
    /// A text content part belongs to a message that finished.
    TextDone(AssistantTextDoneEvent),
    /// An image-file content part belongs to a message that finished.
    ImageFileDone(AssistantImageFileDoneEvent),
}
494
/// Stream adapter that expands raw Assistants events into [`AssistantRuntimeEvent`]s.
#[derive(Debug)]
pub struct AssistantEventStream {
    /// Underlying raw event stream, which also owns the snapshot.
    inner: AssistantStream,
    /// Derived events waiting to be handed out.
    queue: VecDeque<AssistantRuntimeEvent>,
    /// Per message id, the content indices for which a text part has already been announced.
    seen_message_texts: HashMap<String, HashSet<usize>>,
    /// Per run step id, the tool-call indices that have already been announced.
    seen_step_tool_calls: HashMap<String, HashSet<usize>>,
}
503
504impl AssistantEventStream {
505 fn new(inner: AssistantStream) -> Self {
506 Self {
507 inner,
508 queue: VecDeque::new(),
509 seen_message_texts: HashMap::new(),
510 seen_step_tool_calls: HashMap::new(),
511 }
512 }
513
514 pub fn snapshot(&self) -> &AssistantStreamSnapshot {
516 self.inner.snapshot()
517 }
518
519 pub fn meta(&self) -> &ResponseMeta {
521 self.inner.meta()
522 }
523
524 pub async fn final_snapshot(mut self) -> Result<AssistantStreamSnapshot> {
526 while let Some(event) = self.next().await {
527 event?;
528 }
529 Ok(self.inner.snapshot)
530 }
531
532 fn enqueue_events(&mut self, event: &AssistantStreamEvent) {
533 self.queue
534 .push_back(AssistantRuntimeEvent::Event(event.clone()));
535
536 match event.event.as_str() {
537 "thread.message.created" => {
538 self.queue.push_back(AssistantRuntimeEvent::MessageCreated(
539 AssistantMessageCreatedEvent {
540 message: event.data.clone(),
541 },
542 ));
543 self.enqueue_text_created_from_message(&event.data);
544 }
545 "thread.message.delta" => {
546 let message_id = event
547 .data
548 .get("id")
549 .and_then(Value::as_str)
550 .map(str::to_owned);
551 let snapshot = message_id
552 .as_deref()
553 .and_then(|id| self.inner.snapshot().message_raw(id))
554 .cloned()
555 .unwrap_or_else(|| event.data.clone());
556 self.queue.push_back(AssistantRuntimeEvent::MessageDelta(
557 AssistantMessageDeltaEvent {
558 delta: event.data.get("delta").cloned().unwrap_or(Value::Null),
559 snapshot: snapshot.clone(),
560 },
561 ));
562 self.enqueue_text_delta(&message_id, event, &snapshot);
563 }
564 "thread.message.completed" | "thread.message.incomplete" => {
565 let message = event
566 .data
567 .get("id")
568 .and_then(Value::as_str)
569 .and_then(|id| self.inner.snapshot().message_raw(id))
570 .cloned()
571 .unwrap_or_else(|| event.data.clone());
572 self.queue.push_back(AssistantRuntimeEvent::MessageDone(
573 AssistantMessageDoneEvent {
574 message: message.clone(),
575 },
576 ));
577 self.enqueue_message_done_content(&message);
578 }
579 "thread.run.step.created" => {
580 self.queue.push_back(AssistantRuntimeEvent::RunStepCreated(
581 AssistantRunStepCreatedEvent {
582 run_step: event.data.clone(),
583 },
584 ));
585 }
586 "thread.run.step.delta" => {
587 let step_id = event
588 .data
589 .get("id")
590 .and_then(Value::as_str)
591 .map(str::to_owned);
592 let snapshot = step_id
593 .as_deref()
594 .and_then(|id| self.inner.snapshot().run_step_raw(id))
595 .cloned()
596 .unwrap_or_else(|| event.data.clone());
597 self.queue.push_back(AssistantRuntimeEvent::RunStepDelta(
598 AssistantRunStepDeltaEvent {
599 delta: event.data.get("delta").cloned().unwrap_or(Value::Null),
600 snapshot: snapshot.clone(),
601 },
602 ));
603 self.enqueue_tool_call_delta(&step_id, event, &snapshot);
604 }
605 "thread.run.step.completed"
606 | "thread.run.step.failed"
607 | "thread.run.step.cancelled"
608 | "thread.run.step.expired" => {
609 let run_step = event
610 .data
611 .get("id")
612 .and_then(Value::as_str)
613 .and_then(|id| self.inner.snapshot().run_step_raw(id))
614 .cloned()
615 .unwrap_or_else(|| event.data.clone());
616 self.queue.push_back(AssistantRuntimeEvent::RunStepDone(
617 AssistantRunStepDoneEvent {
618 run_step: run_step.clone(),
619 },
620 ));
621 self.enqueue_tool_call_done(&run_step);
622 }
623 _ => {}
624 }
625 }
626
627 fn enqueue_text_created_from_message(&mut self, message: &Value) {
628 let message_id = message.get("id").and_then(Value::as_str).map(str::to_owned);
629 let Some(content) = message.get("content").and_then(Value::as_array) else {
630 return;
631 };
632 for (index, part) in content.iter().enumerate() {
633 if part.get("type").and_then(Value::as_str) == Some("text") {
634 self.mark_message_text_seen(&message_id, index);
635 self.queue.push_back(AssistantRuntimeEvent::TextCreated(
636 AssistantTextCreatedEvent {
637 message_id: message_id.clone(),
638 content_index: index,
639 text: part.clone(),
640 },
641 ));
642 }
643 }
644 }
645
646 fn enqueue_text_delta(
647 &mut self,
648 message_id: &Option<String>,
649 event: &AssistantStreamEvent,
650 snapshot: &Value,
651 ) {
652 let Some(content_deltas) = event
653 .data
654 .get("delta")
655 .and_then(|value| value.get("content"))
656 .and_then(Value::as_array)
657 else {
658 return;
659 };
660
661 let snapshot_content = snapshot
662 .get("content")
663 .and_then(Value::as_array)
664 .cloned()
665 .unwrap_or_default();
666
667 for content_delta in content_deltas {
668 let index = content_delta
669 .get("index")
670 .and_then(Value::as_u64)
671 .map(|value| value as usize)
672 .unwrap_or_default();
673 if content_delta.get("type").and_then(Value::as_str) != Some("text") {
674 continue;
675 }
676
677 if !self.message_text_seen(message_id, index)
678 && let Some(snapshot_part) = snapshot_content.get(index)
679 {
680 self.mark_message_text_seen(message_id, index);
681 self.queue.push_back(AssistantRuntimeEvent::TextCreated(
682 AssistantTextCreatedEvent {
683 message_id: message_id.clone(),
684 content_index: index,
685 text: snapshot_part.clone(),
686 },
687 ));
688 }
689
690 if let Some(snapshot_part) = snapshot_content.get(index) {
691 self.queue
692 .push_back(AssistantRuntimeEvent::TextDelta(AssistantTextDeltaEvent {
693 message_id: message_id.clone(),
694 content_index: index,
695 delta: content_delta.clone(),
696 snapshot: snapshot_part.clone(),
697 }));
698 }
699 }
700 }
701
702 fn enqueue_message_done_content(&mut self, message: &Value) {
703 let message_id = message.get("id").and_then(Value::as_str).map(str::to_owned);
704 let Some(content) = message.get("content").and_then(Value::as_array) else {
705 return;
706 };
707 for (index, part) in content.iter().enumerate() {
708 match part.get("type").and_then(Value::as_str) {
709 Some("text") => {
710 self.mark_message_text_seen(&message_id, index);
711 self.queue
712 .push_back(AssistantRuntimeEvent::TextDone(AssistantTextDoneEvent {
713 message_id: message_id.clone(),
714 content_index: index,
715 text: part.clone(),
716 message: message.clone(),
717 }));
718 }
719 Some("image_file") => {
720 self.queue.push_back(AssistantRuntimeEvent::ImageFileDone(
721 AssistantImageFileDoneEvent {
722 message_id: message_id.clone(),
723 content_index: index,
724 image_file: part.clone(),
725 message: message.clone(),
726 },
727 ));
728 }
729 _ => {}
730 }
731 }
732 }
733
734 fn enqueue_tool_call_delta(
735 &mut self,
736 step_id: &Option<String>,
737 event: &AssistantStreamEvent,
738 snapshot: &Value,
739 ) {
740 let Some(tool_call_deltas) = event
741 .data
742 .get("delta")
743 .and_then(|value| value.get("step_details"))
744 .and_then(|value| value.get("tool_calls"))
745 .and_then(Value::as_array)
746 else {
747 return;
748 };
749 let snapshot_calls = snapshot
750 .get("step_details")
751 .and_then(|value| value.get("tool_calls"))
752 .and_then(Value::as_array)
753 .cloned()
754 .unwrap_or_default();
755 for tool_call_delta in tool_call_deltas {
756 let index = tool_call_delta
757 .get("index")
758 .and_then(Value::as_u64)
759 .map(|value| value as usize)
760 .unwrap_or_default();
761 if !self.step_tool_call_seen(step_id, index)
762 && let Some(snapshot_call) = snapshot_calls.get(index)
763 {
764 self.mark_step_tool_call_seen(step_id, index);
765 self.queue.push_back(AssistantRuntimeEvent::ToolCallCreated(
766 AssistantToolCallCreatedEvent {
767 run_step_id: step_id.clone(),
768 tool_call_index: index,
769 tool_call: snapshot_call.clone(),
770 },
771 ));
772 }
773 if let Some(snapshot_call) = snapshot_calls.get(index) {
774 self.queue.push_back(AssistantRuntimeEvent::ToolCallDelta(
775 AssistantToolCallDeltaEvent {
776 run_step_id: step_id.clone(),
777 tool_call_index: index,
778 delta: tool_call_delta.clone(),
779 snapshot: snapshot_call.clone(),
780 },
781 ));
782 }
783 }
784 }
785
786 fn enqueue_tool_call_done(&mut self, run_step: &Value) {
787 let step_id = run_step
788 .get("id")
789 .and_then(Value::as_str)
790 .map(str::to_owned);
791 let Some(tool_calls) = run_step
792 .get("step_details")
793 .and_then(|value| value.get("tool_calls"))
794 .and_then(Value::as_array)
795 else {
796 return;
797 };
798 for (index, tool_call) in tool_calls.iter().enumerate() {
799 self.mark_step_tool_call_seen(&step_id, index);
800 self.queue.push_back(AssistantRuntimeEvent::ToolCallDone(
801 AssistantToolCallDoneEvent {
802 run_step_id: step_id.clone(),
803 tool_call_index: index,
804 tool_call: tool_call.clone(),
805 },
806 ));
807 }
808 }
809
810 fn message_text_seen(&self, message_id: &Option<String>, index: usize) -> bool {
811 message_id
812 .as_deref()
813 .and_then(|id| self.seen_message_texts.get(id))
814 .is_some_and(|set| set.contains(&index))
815 }
816
817 fn mark_message_text_seen(&mut self, message_id: &Option<String>, index: usize) {
818 let Some(message_id) = message_id else {
819 return;
820 };
821 self.seen_message_texts
822 .entry(message_id.clone())
823 .or_default()
824 .insert(index);
825 }
826
827 fn step_tool_call_seen(&self, step_id: &Option<String>, index: usize) -> bool {
828 step_id
829 .as_deref()
830 .and_then(|id| self.seen_step_tool_calls.get(id))
831 .is_some_and(|set| set.contains(&index))
832 }
833
834 fn mark_step_tool_call_seen(&mut self, step_id: &Option<String>, index: usize) {
835 let Some(step_id) = step_id else {
836 return;
837 };
838 self.seen_step_tool_calls
839 .entry(step_id.clone())
840 .or_default()
841 .insert(index);
842 }
843}
844
845impl Stream for AssistantEventStream {
846 type Item = Result<AssistantRuntimeEvent>;
847
848 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
849 let this = self.get_mut();
850 if let Some(event) = this.queue.pop_front() {
851 return Poll::Ready(Some(Ok(event)));
852 }
853
854 match Pin::new(&mut this.inner).poll_next(cx) {
855 Poll::Ready(Some(Ok(event))) => {
856 this.enqueue_events(&event);
857 Poll::Ready(this.queue.pop_front().map(Ok))
858 }
859 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
860 Poll::Ready(None) => Poll::Ready(None),
861 Poll::Pending => Poll::Pending,
862 }
863 }
864}
865
866fn empty_assistant_snapshot(id: &str, object: &str) -> Value {
867 let mut map = Map::new();
868 map.insert("id".into(), Value::String(id.to_owned()));
869 map.insert("object".into(), Value::String(object.to_owned()));
870 Value::Object(map)
871}
872
873fn apply_message_delta(message: &mut Value, event: &Value) {
874 let Some(delta) = event.get("delta") else {
875 return;
876 };
877 if let Some(role) = delta.get("role").and_then(Value::as_str) {
878 ensure_object(message).insert("role".into(), Value::String(role.to_owned()));
879 }
880
881 let Some(content_deltas) = delta.get("content").and_then(Value::as_array) else {
882 return;
883 };
884 let content = ensure_array_field(message, "content");
885 for content_delta in content_deltas {
886 let index = content_delta
887 .get("index")
888 .and_then(Value::as_u64)
889 .map(|value| value as usize)
890 .unwrap_or(content.len());
891 ensure_vec_len(content, index + 1);
892 if content[index].is_null() {
893 content[index] = Value::Object(Map::new());
894 }
895
896 let slot = &mut content[index];
897 let slot_object = ensure_object(slot);
898 if let Some(part_type) = content_delta.get("type").and_then(Value::as_str) {
899 slot_object.insert("type".into(), Value::String(part_type.to_owned()));
900 match part_type {
901 "text" => {
902 let text_object = ensure_object_field(slot, "text");
903 let value = content_delta
904 .get("text")
905 .and_then(|value| value.get("value"))
906 .and_then(Value::as_str)
907 .unwrap_or("");
908 let current = text_object
909 .get("value")
910 .and_then(Value::as_str)
911 .unwrap_or("");
912 text_object.insert("value".into(), Value::String(format!("{current}{value}")));
913 }
914 "refusal" => {
915 let value = content_delta
916 .get("refusal")
917 .and_then(Value::as_str)
918 .unwrap_or("");
919 let current = slot_object
920 .get("refusal")
921 .and_then(Value::as_str)
922 .unwrap_or("");
923 slot_object
924 .insert("refusal".into(), Value::String(format!("{current}{value}")));
925 }
926 _ => merge_object(slot_object, content_delta),
927 }
928 }
929 }
930}
931
932fn apply_run_step_delta(run_step: &mut Value, event: &Value) {
933 let Some(delta) = event.get("delta") else {
934 return;
935 };
936 let Some(step_details) = delta.get("step_details") else {
937 return;
938 };
939 let step_details_object = ensure_object_field(run_step, "step_details");
940 if let Some(step_type) = step_details.get("type").and_then(Value::as_str) {
941 step_details_object.insert("type".into(), Value::String(step_type.to_owned()));
942 match step_type {
943 "message_creation" => {
944 if let Some(message_creation) = step_details.get("message_creation") {
945 let target = step_details_object
946 .entry("message_creation")
947 .or_insert_with(|| Value::Object(Map::new()));
948 merge_object(ensure_object(target), message_creation);
949 }
950 }
951 "tool_calls" => {
952 let tool_calls = step_details
953 .get("tool_calls")
954 .and_then(Value::as_array)
955 .cloned()
956 .unwrap_or_default();
957 let target = step_details_object
958 .entry("tool_calls")
959 .or_insert_with(|| Value::Array(Vec::new()));
960 let target_calls = if let Some(array) = target.as_array_mut() {
961 array
962 } else {
963 *target = Value::Array(Vec::new());
964 target.as_array_mut().expect("tool_calls must be array")
965 };
966 for tool_call in tool_calls {
967 let index = tool_call
968 .get("index")
969 .and_then(Value::as_u64)
970 .map(|value| value as usize)
971 .unwrap_or(target_calls.len());
972 ensure_vec_len(target_calls, index + 1);
973 if target_calls[index].is_null() {
974 target_calls[index] = Value::Object(Map::new());
975 }
976 merge_tool_call_delta(&mut target_calls[index], &tool_call);
977 }
978 }
979 _ => merge_object(step_details_object, step_details),
980 }
981 }
982}
983
984fn merge_tool_call_delta(target: &mut Value, delta: &Value) {
985 let target_object = ensure_object(target);
986 if let Some(delta_object) = delta.as_object() {
987 for (key, value) in delta_object {
988 if matches!(key.as_str(), "function" | "code_interpreter")
989 || matches!(value, Value::Null)
990 {
991 continue;
992 }
993 target_object.insert(key.clone(), value.clone());
994 }
995 }
996 if let Some(function_delta) = delta.get("function") {
997 let function_target = target_object
998 .entry("function")
999 .or_insert_with(|| Value::Object(Map::new()));
1000 let function_object = ensure_object(function_target);
1001 if let Some(arguments) = function_delta.get("arguments").and_then(Value::as_str) {
1002 let current = function_object
1003 .get("arguments")
1004 .and_then(Value::as_str)
1005 .unwrap_or("");
1006 function_object.insert(
1007 "arguments".into(),
1008 Value::String(format!("{current}{arguments}")),
1009 );
1010 }
1011 if let Some(name) = function_delta.get("name").and_then(Value::as_str) {
1012 function_object.insert("name".into(), Value::String(name.to_owned()));
1013 }
1014 }
1015 if let Some(code_interpreter_delta) = delta.get("code_interpreter") {
1016 let code_interpreter_target = target_object
1017 .entry("code_interpreter")
1018 .or_insert_with(|| Value::Object(Map::new()));
1019 let code_interpreter_object = ensure_object(code_interpreter_target);
1020 if let Some(input) = code_interpreter_delta.get("input").and_then(Value::as_str) {
1021 let current = code_interpreter_object
1022 .get("input")
1023 .and_then(Value::as_str)
1024 .unwrap_or("");
1025 code_interpreter_object
1026 .insert("input".into(), Value::String(format!("{current}{input}")));
1027 }
1028 if let Some(outputs) = code_interpreter_delta
1029 .get("outputs")
1030 .and_then(Value::as_array)
1031 {
1032 let target_outputs = code_interpreter_object
1033 .entry("outputs")
1034 .or_insert_with(|| Value::Array(Vec::new()));
1035 let output_array = if let Some(array) = target_outputs.as_array_mut() {
1036 array
1037 } else {
1038 *target_outputs = Value::Array(Vec::new());
1039 target_outputs
1040 .as_array_mut()
1041 .expect("outputs must be array")
1042 };
1043 output_array.extend(outputs.iter().cloned());
1044 }
1045 }
1046}
1047
#[cfg(test)]
mod tests {
    use super::{AssistantStreamEvent, AssistantStreamSnapshot};
    use serde_json::{Value, json};

    /// Wraps a payload in a stream event with the given name.
    fn event(name: &str, data: Value) -> AssistantStreamEvent {
        AssistantStreamEvent {
            event: name.into(),
            data,
        }
    }

    /// Builds a message delta that carries one text fragment for `msg_1`, content index 0.
    fn text_delta(fragment: &str) -> AssistantStreamEvent {
        event(
            "thread.message.delta",
            json!({
                "id": "msg_1",
                "object": "thread.message.delta",
                "delta": {
                    "content": [{
                        "index": 0,
                        "type": "text",
                        "text": { "value": fragment }
                    }]
                }
            }),
        )
    }

    /// Builds a run-step delta that carries one function tool-call fragment for `step_1`.
    fn tool_call_delta(function: Value) -> AssistantStreamEvent {
        event(
            "thread.run.step.delta",
            json!({
                "id": "step_1",
                "object": "thread.run.step.delta",
                "delta": {
                    "step_details": {
                        "type": "tool_calls",
                        "tool_calls": [{
                            "index": 0,
                            "type": "function",
                            "function": function
                        }]
                    }
                }
            }),
        )
    }

    /// Deltas that arrive without any preceding `created` event must still accumulate into
    /// the snapshot.
    #[test]
    fn test_should_merge_assistant_deltas_into_snapshot_before_created_events() {
        let events = [
            text_delta("hel"),
            text_delta("lo"),
            tool_call_delta(json!({
                "name": "lookup_weather",
                "arguments": "{\"city\":\"Sha"
            })),
            tool_call_delta(json!({
                "arguments": "nghai\"}"
            })),
        ];

        let mut snapshot = AssistantStreamSnapshot::default();
        for incoming in &events {
            snapshot.apply(incoming);
        }

        let merged_text = snapshot
            .message_raw("msg_1")
            .and_then(|message| message.pointer("/content/0/text/value"))
            .and_then(Value::as_str);
        assert_eq!(merged_text, Some("hello"));

        let merged_arguments = snapshot
            .run_step_raw("step_1")
            .and_then(|step| step.pointer("/step_details/tool_calls/0/function/arguments"))
            .and_then(Value::as_str);
        assert_eq!(merged_arguments, Some("{\"city\":\"Shanghai\"}"));
    }
}