Skip to main content

openai_core/stream/
assistant.rs

1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
2use std::fmt;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use async_stream::try_stream;
7use futures_util::{Stream, StreamExt};
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10use serde_json::{Map, Value};
11
12use super::sse::RawSseStream;
13use super::value_helpers::{
14    ensure_array_field, ensure_object, ensure_object_field, ensure_vec_len, merge_object,
15};
16use crate::error::{Result, SerializationError, StreamError};
17use crate::response_meta::ResponseMeta;
18
/// A single server-sent event from an Assistants / Beta Threads stream.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantStreamEvent {
    /// SSE event name.
    pub event: String,
    /// JSON payload carried by the event.
    pub data: Value,
}
27
28impl AssistantStreamEvent {
29    /// 判断当前事件是否为错误事件。
30    pub fn is_error(&self) -> bool {
31        self.event == "error"
32    }
33
34    /// 把事件负载解析为指定类型。
35    pub fn data_as<T>(&self) -> Result<T>
36    where
37        T: DeserializeOwned,
38    {
39        serde_json::from_value(self.data.clone()).map_err(|error| {
40            SerializationError::new(format!(
41                "Assistants 流事件反序列化失败: event={}, error={error}",
42                self.event
43            ))
44            .into()
45        })
46    }
47}
48
/// Snapshot accumulated while an Assistants stream is being consumed.
#[derive(Debug, Clone, Default)]
pub struct AssistantStreamSnapshot {
    // Payload of the last `thread.created` event.
    thread: Option<Value>,
    // Run payloads keyed by their `id`.
    runs: BTreeMap<String, Value>,
    // Message payloads keyed by their `id`; deltas are merged into these.
    messages: BTreeMap<String, Value>,
    // Run step payloads keyed by their `id`; deltas are merged into these.
    run_steps: BTreeMap<String, Value>,
    // `id` carried by the most recently applied run event.
    latest_run_id: Option<String>,
    // `id` carried by the most recently applied message event or delta.
    latest_message_id: Option<String>,
    // `id` carried by the most recently applied run step event or delta.
    latest_run_step_id: Option<String>,
}
60
61impl AssistantStreamSnapshot {
62    /// 返回最新的 thread 原始快照。
63    pub fn thread_raw(&self) -> Option<&Value> {
64        self.thread.as_ref()
65    }
66
67    /// 返回最新的 thread 快照。
68    pub fn thread<T>(&self) -> Option<T>
69    where
70        T: DeserializeOwned,
71    {
72        self.thread
73            .as_ref()
74            .and_then(|value| serde_json::from_value(value.clone()).ok())
75    }
76
77    /// 返回指定 run 的原始快照。
78    pub fn run_raw(&self, run_id: &str) -> Option<&Value> {
79        self.runs.get(run_id)
80    }
81
82    /// 返回最新 run 的原始快照。
83    pub fn latest_run_raw(&self) -> Option<&Value> {
84        self.latest_run_id
85            .as_deref()
86            .and_then(|run_id| self.runs.get(run_id))
87    }
88
89    /// 返回指定 message 的原始快照。
90    pub fn message_raw(&self, message_id: &str) -> Option<&Value> {
91        self.messages.get(message_id)
92    }
93
94    /// 返回最新 message 的原始快照。
95    pub fn latest_message_raw(&self) -> Option<&Value> {
96        self.latest_message_id
97            .as_deref()
98            .and_then(|message_id| self.messages.get(message_id))
99    }
100
101    /// 返回指定 run 的结构化快照。
102    pub fn run<T>(&self, run_id: &str) -> Option<T>
103    where
104        T: DeserializeOwned,
105    {
106        self.run_raw(run_id)
107            .and_then(|value| serde_json::from_value(value.clone()).ok())
108    }
109
110    /// 返回最新 run 的结构化快照。
111    pub fn latest_run<T>(&self) -> Option<T>
112    where
113        T: DeserializeOwned,
114    {
115        self.latest_run_raw()
116            .and_then(|value| serde_json::from_value(value.clone()).ok())
117    }
118
119    /// 返回指定 message 的结构化快照。
120    pub fn message<T>(&self, message_id: &str) -> Option<T>
121    where
122        T: DeserializeOwned,
123    {
124        self.messages
125            .get(message_id)
126            .and_then(|value| serde_json::from_value(value.clone()).ok())
127    }
128
129    /// 返回最新 message 的结构化快照。
130    pub fn latest_message<T>(&self) -> Option<T>
131    where
132        T: DeserializeOwned,
133    {
134        self.latest_message_id
135            .as_deref()
136            .and_then(|message_id| self.message(message_id))
137    }
138
139    /// 返回指定 run step 的原始快照。
140    pub fn run_step_raw(&self, step_id: &str) -> Option<&Value> {
141        self.run_steps.get(step_id)
142    }
143
144    /// 返回最新 run step 的原始快照。
145    pub fn latest_run_step_raw(&self) -> Option<&Value> {
146        self.latest_run_step_id
147            .as_deref()
148            .and_then(|step_id| self.run_steps.get(step_id))
149    }
150
151    /// 返回指定 run step 的结构化快照。
152    pub fn run_step<T>(&self, step_id: &str) -> Option<T>
153    where
154        T: DeserializeOwned,
155    {
156        self.run_steps
157            .get(step_id)
158            .and_then(|value| serde_json::from_value(value.clone()).ok())
159    }
160
161    /// 返回最新 run step 的结构化快照。
162    pub fn latest_run_step<T>(&self) -> Option<T>
163    where
164        T: DeserializeOwned,
165    {
166        self.latest_run_step_id
167            .as_deref()
168            .and_then(|step_id| self.run_step(step_id))
169    }
170
171    fn apply(&mut self, event: &AssistantStreamEvent) {
172        match event.event.as_str() {
173            "thread.created" => {
174                self.thread = Some(event.data.clone());
175            }
176            "thread.run.created"
177            | "thread.run.queued"
178            | "thread.run.in_progress"
179            | "thread.run.requires_action"
180            | "thread.run.completed"
181            | "thread.run.incomplete"
182            | "thread.run.failed"
183            | "thread.run.cancelling"
184            | "thread.run.cancelled"
185            | "thread.run.expired" => {
186                if let Some(id) = event.data.get("id").and_then(Value::as_str) {
187                    self.latest_run_id = Some(id.to_owned());
188                    self.runs.insert(id.to_owned(), event.data.clone());
189                }
190            }
191            "thread.message.created"
192            | "thread.message.in_progress"
193            | "thread.message.completed"
194            | "thread.message.incomplete" => {
195                if let Some(id) = event.data.get("id").and_then(Value::as_str) {
196                    self.latest_message_id = Some(id.to_owned());
197                    self.messages.insert(id.to_owned(), event.data.clone());
198                }
199            }
200            "thread.run.step.created"
201            | "thread.run.step.in_progress"
202            | "thread.run.step.completed"
203            | "thread.run.step.failed"
204            | "thread.run.step.cancelled"
205            | "thread.run.step.expired" => {
206                if let Some(id) = event.data.get("id").and_then(Value::as_str) {
207                    self.latest_run_step_id = Some(id.to_owned());
208                    self.run_steps.insert(id.to_owned(), event.data.clone());
209                }
210            }
211            "thread.message.delta" => {
212                if let Some(id) = event.data.get("id").and_then(Value::as_str) {
213                    self.latest_message_id = Some(id.to_owned());
214                    let entry = self
215                        .messages
216                        .entry(id.to_owned())
217                        .or_insert_with(|| empty_assistant_snapshot(id, "thread.message"));
218                    apply_message_delta(entry, &event.data);
219                }
220            }
221            "thread.run.step.delta" => {
222                if let Some(id) = event.data.get("id").and_then(Value::as_str) {
223                    self.latest_run_step_id = Some(id.to_owned());
224                    let entry = self
225                        .run_steps
226                        .entry(id.to_owned())
227                        .or_insert_with(|| empty_assistant_snapshot(id, "thread.run.step"));
228                    apply_run_step_delta(entry, &event.data);
229                }
230            }
231            _ => {}
232        }
233    }
234}
235
/// Streaming wrapper for the Assistants API.
pub struct AssistantStream {
    // Boxed stream of parsed events built in `new`.
    inner: Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent>> + Send>>,
    // Response metadata cloned from the raw SSE stream.
    meta: ResponseMeta,
    // Snapshot updated with every successfully polled event.
    snapshot: AssistantStreamSnapshot,
}
242
243impl AssistantStream {
244    /// 从原始 SSE 流创建 Assistants 流。
245    #[allow(tail_expr_drop_order)]
246    pub fn new(raw: RawSseStream) -> Self {
247        let meta = raw.meta().clone();
248        let stream = try_stream! {
249            let mut raw = raw;
250            while let Some(event) = raw.next().await {
251                let event = event?;
252                if event.data == "[DONE]" {
253                    break;
254                }
255
256                let data = serde_json::from_str::<Value>(&event.data).map_err(|error| {
257                    StreamError::new(format!(
258                        "解析 Assistants SSE 事件失败: event={:?}, error={error}, payload={}",
259                        event.event,
260                        event.data
261                    ))
262                })?;
263                let event_name = event
264                    .event
265                    .or_else(|| data.get("event").and_then(Value::as_str).map(str::to_owned))
266                    .or_else(|| data.get("type").and_then(Value::as_str).map(str::to_owned))
267                    .unwrap_or_else(|| "message".into());
268
269                yield AssistantStreamEvent {
270                    event: event_name,
271                    data,
272                };
273            }
274        };
275
276        Self {
277            inner: Box::pin(stream),
278            meta,
279            snapshot: AssistantStreamSnapshot::default(),
280        }
281    }
282
283    /// 返回截至目前的快照。
284    pub fn snapshot(&self) -> &AssistantStreamSnapshot {
285        &self.snapshot
286    }
287
288    /// 返回底层响应元信息。
289    pub fn meta(&self) -> &ResponseMeta {
290        &self.meta
291    }
292
293    /// 消费整个流并返回最终快照。
294    pub async fn final_snapshot(mut self) -> Result<AssistantStreamSnapshot> {
295        while let Some(event) = self.next().await {
296            event?;
297        }
298        Ok(self.snapshot)
299    }
300
301    /// 把原始 Assistants 事件流转换为带高层派生语义的运行时流。
302    pub fn events(self) -> AssistantEventStream {
303        AssistantEventStream::new(self)
304    }
305}
306
307impl fmt::Debug for AssistantStream {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        f.debug_struct("AssistantStream")
310            .field("meta", &self.meta)
311            .field("snapshot", &self.snapshot)
312            .finish()
313    }
314}
315
316impl Stream for AssistantStream {
317    type Item = Result<AssistantStreamEvent>;
318
319    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
320        let this = self.get_mut();
321        match this.inner.as_mut().poll_next(cx) {
322            Poll::Ready(Some(Ok(event))) => {
323                this.snapshot.apply(&event);
324                Poll::Ready(Some(Ok(event)))
325            }
326            other => other,
327        }
328    }
329}
330
/// Represents a message-created event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessageCreatedEvent {
    /// Message snapshot.
    pub message: Value,
}
337
/// Represents a message delta event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessageDeltaEvent {
    /// Message delta.
    pub delta: Value,
    /// Current message snapshot.
    pub snapshot: Value,
}
346
/// Represents a message-done event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantMessageDoneEvent {
    /// Message snapshot.
    pub message: Value,
}
353
/// Represents a run-step-created event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantRunStepCreatedEvent {
    /// Run step snapshot.
    pub run_step: Value,
}
360
/// Represents a run step delta event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantRunStepDeltaEvent {
    /// Run step delta.
    pub delta: Value,
    /// Current run step snapshot.
    pub snapshot: Value,
}
369
/// Represents a run-step-done event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantRunStepDoneEvent {
    /// Run step snapshot.
    pub run_step: Value,
}
376
/// Represents a tool-call-created event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantToolCallCreatedEvent {
    /// Run step ID.
    pub run_step_id: Option<String>,
    /// Index of the tool call.
    pub tool_call_index: usize,
    /// Tool call snapshot.
    pub tool_call: Value,
}
387
/// Represents a tool call delta event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantToolCallDeltaEvent {
    /// Run step ID.
    pub run_step_id: Option<String>,
    /// Index of the tool call.
    pub tool_call_index: usize,
    /// Tool call delta.
    pub delta: Value,
    /// Current tool call snapshot.
    pub snapshot: Value,
}
400
/// Represents a tool-call-done event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantToolCallDoneEvent {
    /// Run step ID.
    pub run_step_id: Option<String>,
    /// Index of the tool call.
    pub tool_call_index: usize,
    /// Tool call snapshot.
    pub tool_call: Value,
}
411
/// Represents a text-content-created event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantTextCreatedEvent {
    /// Message ID.
    pub message_id: Option<String>,
    /// Index of the content part.
    pub content_index: usize,
    /// Text content snapshot.
    pub text: Value,
}
422
/// Represents a text content delta event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantTextDeltaEvent {
    /// Message ID.
    pub message_id: Option<String>,
    /// Index of the content part.
    pub content_index: usize,
    /// Text delta.
    pub delta: Value,
    /// Current text content snapshot.
    pub snapshot: Value,
}
435
/// Represents a text-content-done event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantTextDoneEvent {
    /// Message ID.
    pub message_id: Option<String>,
    /// Index of the content part.
    pub content_index: usize,
    /// Current text content snapshot.
    pub text: Value,
    /// Current message snapshot.
    pub message: Value,
}
448
/// Represents an image-file-done event.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AssistantImageFileDoneEvent {
    /// Message ID.
    pub message_id: Option<String>,
    /// Index of the content part.
    pub content_index: usize,
    /// Image file content.
    pub image_file: Value,
    /// Current message snapshot.
    pub message: Value,
}
461
/// High-level events derived at runtime from an Assistants stream.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum AssistantRuntimeEvent {
    /// The raw Assistants SSE event.
    Event(AssistantStreamEvent),
    /// A message was created.
    MessageCreated(AssistantMessageCreatedEvent),
    /// A message delta arrived.
    MessageDelta(AssistantMessageDeltaEvent),
    /// A message finished.
    MessageDone(AssistantMessageDoneEvent),
    /// A run step was created.
    RunStepCreated(AssistantRunStepCreatedEvent),
    /// A run step delta arrived.
    RunStepDelta(AssistantRunStepDeltaEvent),
    /// A run step finished.
    RunStepDone(AssistantRunStepDoneEvent),
    /// A tool call was created.
    ToolCallCreated(AssistantToolCallCreatedEvent),
    /// A tool call delta arrived.
    ToolCallDelta(AssistantToolCallDeltaEvent),
    /// A tool call finished.
    ToolCallDone(AssistantToolCallDoneEvent),
    /// A text content part was created.
    TextCreated(AssistantTextCreatedEvent),
    /// A text content delta arrived.
    TextDelta(AssistantTextDeltaEvent),
    /// A text content part finished.
    TextDone(AssistantTextDoneEvent),
    /// An image file content part finished.
    ImageFileDone(AssistantImageFileDoneEvent),
}
494
/// An Assistants stream that also yields derived high-level events.
#[derive(Debug)]
pub struct AssistantEventStream {
    // Wrapped raw event stream, which also owns the snapshot.
    inner: AssistantStream,
    // Runtime events waiting to be handed out by `poll_next`.
    queue: VecDeque<AssistantRuntimeEvent>,
    // Per message id, the content indexes already marked as seen text parts.
    seen_message_texts: HashMap<String, HashSet<usize>>,
    // Per run step id, the tool call indexes already marked as seen.
    seen_step_tool_calls: HashMap<String, HashSet<usize>>,
}
503
504impl AssistantEventStream {
505    fn new(inner: AssistantStream) -> Self {
506        Self {
507            inner,
508            queue: VecDeque::new(),
509            seen_message_texts: HashMap::new(),
510            seen_step_tool_calls: HashMap::new(),
511        }
512    }
513
514    /// 返回当前累计快照。
515    pub fn snapshot(&self) -> &AssistantStreamSnapshot {
516        self.inner.snapshot()
517    }
518
519    /// 返回底层响应元信息。
520    pub fn meta(&self) -> &ResponseMeta {
521        self.inner.meta()
522    }
523
524    /// 消费整个事件流并返回最终快照。
525    pub async fn final_snapshot(mut self) -> Result<AssistantStreamSnapshot> {
526        while let Some(event) = self.next().await {
527            event?;
528        }
529        Ok(self.inner.snapshot)
530    }
531
532    fn enqueue_events(&mut self, event: &AssistantStreamEvent) {
533        self.queue
534            .push_back(AssistantRuntimeEvent::Event(event.clone()));
535
536        match event.event.as_str() {
537            "thread.message.created" => {
538                self.queue.push_back(AssistantRuntimeEvent::MessageCreated(
539                    AssistantMessageCreatedEvent {
540                        message: event.data.clone(),
541                    },
542                ));
543                self.enqueue_text_created_from_message(&event.data);
544            }
545            "thread.message.delta" => {
546                let message_id = event
547                    .data
548                    .get("id")
549                    .and_then(Value::as_str)
550                    .map(str::to_owned);
551                let snapshot = message_id
552                    .as_deref()
553                    .and_then(|id| self.inner.snapshot().message_raw(id))
554                    .cloned()
555                    .unwrap_or_else(|| event.data.clone());
556                self.queue.push_back(AssistantRuntimeEvent::MessageDelta(
557                    AssistantMessageDeltaEvent {
558                        delta: event.data.get("delta").cloned().unwrap_or(Value::Null),
559                        snapshot: snapshot.clone(),
560                    },
561                ));
562                self.enqueue_text_delta(&message_id, event, &snapshot);
563            }
564            "thread.message.completed" | "thread.message.incomplete" => {
565                let message = event
566                    .data
567                    .get("id")
568                    .and_then(Value::as_str)
569                    .and_then(|id| self.inner.snapshot().message_raw(id))
570                    .cloned()
571                    .unwrap_or_else(|| event.data.clone());
572                self.queue.push_back(AssistantRuntimeEvent::MessageDone(
573                    AssistantMessageDoneEvent {
574                        message: message.clone(),
575                    },
576                ));
577                self.enqueue_message_done_content(&message);
578            }
579            "thread.run.step.created" => {
580                self.queue.push_back(AssistantRuntimeEvent::RunStepCreated(
581                    AssistantRunStepCreatedEvent {
582                        run_step: event.data.clone(),
583                    },
584                ));
585            }
586            "thread.run.step.delta" => {
587                let step_id = event
588                    .data
589                    .get("id")
590                    .and_then(Value::as_str)
591                    .map(str::to_owned);
592                let snapshot = step_id
593                    .as_deref()
594                    .and_then(|id| self.inner.snapshot().run_step_raw(id))
595                    .cloned()
596                    .unwrap_or_else(|| event.data.clone());
597                self.queue.push_back(AssistantRuntimeEvent::RunStepDelta(
598                    AssistantRunStepDeltaEvent {
599                        delta: event.data.get("delta").cloned().unwrap_or(Value::Null),
600                        snapshot: snapshot.clone(),
601                    },
602                ));
603                self.enqueue_tool_call_delta(&step_id, event, &snapshot);
604            }
605            "thread.run.step.completed"
606            | "thread.run.step.failed"
607            | "thread.run.step.cancelled"
608            | "thread.run.step.expired" => {
609                let run_step = event
610                    .data
611                    .get("id")
612                    .and_then(Value::as_str)
613                    .and_then(|id| self.inner.snapshot().run_step_raw(id))
614                    .cloned()
615                    .unwrap_or_else(|| event.data.clone());
616                self.queue.push_back(AssistantRuntimeEvent::RunStepDone(
617                    AssistantRunStepDoneEvent {
618                        run_step: run_step.clone(),
619                    },
620                ));
621                self.enqueue_tool_call_done(&run_step);
622            }
623            _ => {}
624        }
625    }
626
627    fn enqueue_text_created_from_message(&mut self, message: &Value) {
628        let message_id = message.get("id").and_then(Value::as_str).map(str::to_owned);
629        let Some(content) = message.get("content").and_then(Value::as_array) else {
630            return;
631        };
632        for (index, part) in content.iter().enumerate() {
633            if part.get("type").and_then(Value::as_str) == Some("text") {
634                self.mark_message_text_seen(&message_id, index);
635                self.queue.push_back(AssistantRuntimeEvent::TextCreated(
636                    AssistantTextCreatedEvent {
637                        message_id: message_id.clone(),
638                        content_index: index,
639                        text: part.clone(),
640                    },
641                ));
642            }
643        }
644    }
645
646    fn enqueue_text_delta(
647        &mut self,
648        message_id: &Option<String>,
649        event: &AssistantStreamEvent,
650        snapshot: &Value,
651    ) {
652        let Some(content_deltas) = event
653            .data
654            .get("delta")
655            .and_then(|value| value.get("content"))
656            .and_then(Value::as_array)
657        else {
658            return;
659        };
660
661        let snapshot_content = snapshot
662            .get("content")
663            .and_then(Value::as_array)
664            .cloned()
665            .unwrap_or_default();
666
667        for content_delta in content_deltas {
668            let index = content_delta
669                .get("index")
670                .and_then(Value::as_u64)
671                .map(|value| value as usize)
672                .unwrap_or_default();
673            if content_delta.get("type").and_then(Value::as_str) != Some("text") {
674                continue;
675            }
676
677            if !self.message_text_seen(message_id, index)
678                && let Some(snapshot_part) = snapshot_content.get(index)
679            {
680                self.mark_message_text_seen(message_id, index);
681                self.queue.push_back(AssistantRuntimeEvent::TextCreated(
682                    AssistantTextCreatedEvent {
683                        message_id: message_id.clone(),
684                        content_index: index,
685                        text: snapshot_part.clone(),
686                    },
687                ));
688            }
689
690            if let Some(snapshot_part) = snapshot_content.get(index) {
691                self.queue
692                    .push_back(AssistantRuntimeEvent::TextDelta(AssistantTextDeltaEvent {
693                        message_id: message_id.clone(),
694                        content_index: index,
695                        delta: content_delta.clone(),
696                        snapshot: snapshot_part.clone(),
697                    }));
698            }
699        }
700    }
701
702    fn enqueue_message_done_content(&mut self, message: &Value) {
703        let message_id = message.get("id").and_then(Value::as_str).map(str::to_owned);
704        let Some(content) = message.get("content").and_then(Value::as_array) else {
705            return;
706        };
707        for (index, part) in content.iter().enumerate() {
708            match part.get("type").and_then(Value::as_str) {
709                Some("text") => {
710                    self.mark_message_text_seen(&message_id, index);
711                    self.queue
712                        .push_back(AssistantRuntimeEvent::TextDone(AssistantTextDoneEvent {
713                            message_id: message_id.clone(),
714                            content_index: index,
715                            text: part.clone(),
716                            message: message.clone(),
717                        }));
718                }
719                Some("image_file") => {
720                    self.queue.push_back(AssistantRuntimeEvent::ImageFileDone(
721                        AssistantImageFileDoneEvent {
722                            message_id: message_id.clone(),
723                            content_index: index,
724                            image_file: part.clone(),
725                            message: message.clone(),
726                        },
727                    ));
728                }
729                _ => {}
730            }
731        }
732    }
733
734    fn enqueue_tool_call_delta(
735        &mut self,
736        step_id: &Option<String>,
737        event: &AssistantStreamEvent,
738        snapshot: &Value,
739    ) {
740        let Some(tool_call_deltas) = event
741            .data
742            .get("delta")
743            .and_then(|value| value.get("step_details"))
744            .and_then(|value| value.get("tool_calls"))
745            .and_then(Value::as_array)
746        else {
747            return;
748        };
749        let snapshot_calls = snapshot
750            .get("step_details")
751            .and_then(|value| value.get("tool_calls"))
752            .and_then(Value::as_array)
753            .cloned()
754            .unwrap_or_default();
755        for tool_call_delta in tool_call_deltas {
756            let index = tool_call_delta
757                .get("index")
758                .and_then(Value::as_u64)
759                .map(|value| value as usize)
760                .unwrap_or_default();
761            if !self.step_tool_call_seen(step_id, index)
762                && let Some(snapshot_call) = snapshot_calls.get(index)
763            {
764                self.mark_step_tool_call_seen(step_id, index);
765                self.queue.push_back(AssistantRuntimeEvent::ToolCallCreated(
766                    AssistantToolCallCreatedEvent {
767                        run_step_id: step_id.clone(),
768                        tool_call_index: index,
769                        tool_call: snapshot_call.clone(),
770                    },
771                ));
772            }
773            if let Some(snapshot_call) = snapshot_calls.get(index) {
774                self.queue.push_back(AssistantRuntimeEvent::ToolCallDelta(
775                    AssistantToolCallDeltaEvent {
776                        run_step_id: step_id.clone(),
777                        tool_call_index: index,
778                        delta: tool_call_delta.clone(),
779                        snapshot: snapshot_call.clone(),
780                    },
781                ));
782            }
783        }
784    }
785
786    fn enqueue_tool_call_done(&mut self, run_step: &Value) {
787        let step_id = run_step
788            .get("id")
789            .and_then(Value::as_str)
790            .map(str::to_owned);
791        let Some(tool_calls) = run_step
792            .get("step_details")
793            .and_then(|value| value.get("tool_calls"))
794            .and_then(Value::as_array)
795        else {
796            return;
797        };
798        for (index, tool_call) in tool_calls.iter().enumerate() {
799            self.mark_step_tool_call_seen(&step_id, index);
800            self.queue.push_back(AssistantRuntimeEvent::ToolCallDone(
801                AssistantToolCallDoneEvent {
802                    run_step_id: step_id.clone(),
803                    tool_call_index: index,
804                    tool_call: tool_call.clone(),
805                },
806            ));
807        }
808    }
809
810    fn message_text_seen(&self, message_id: &Option<String>, index: usize) -> bool {
811        message_id
812            .as_deref()
813            .and_then(|id| self.seen_message_texts.get(id))
814            .is_some_and(|set| set.contains(&index))
815    }
816
817    fn mark_message_text_seen(&mut self, message_id: &Option<String>, index: usize) {
818        let Some(message_id) = message_id else {
819            return;
820        };
821        self.seen_message_texts
822            .entry(message_id.clone())
823            .or_default()
824            .insert(index);
825    }
826
827    fn step_tool_call_seen(&self, step_id: &Option<String>, index: usize) -> bool {
828        step_id
829            .as_deref()
830            .and_then(|id| self.seen_step_tool_calls.get(id))
831            .is_some_and(|set| set.contains(&index))
832    }
833
834    fn mark_step_tool_call_seen(&mut self, step_id: &Option<String>, index: usize) {
835        let Some(step_id) = step_id else {
836            return;
837        };
838        self.seen_step_tool_calls
839            .entry(step_id.clone())
840            .or_default()
841            .insert(index);
842    }
843}
844
845impl Stream for AssistantEventStream {
846    type Item = Result<AssistantRuntimeEvent>;
847
848    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
849        let this = self.get_mut();
850        if let Some(event) = this.queue.pop_front() {
851            return Poll::Ready(Some(Ok(event)));
852        }
853
854        match Pin::new(&mut this.inner).poll_next(cx) {
855            Poll::Ready(Some(Ok(event))) => {
856                this.enqueue_events(&event);
857                Poll::Ready(this.queue.pop_front().map(Ok))
858            }
859            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
860            Poll::Ready(None) => Poll::Ready(None),
861            Poll::Pending => Poll::Pending,
862        }
863    }
864}
865
866fn empty_assistant_snapshot(id: &str, object: &str) -> Value {
867    let mut map = Map::new();
868    map.insert("id".into(), Value::String(id.to_owned()));
869    map.insert("object".into(), Value::String(object.to_owned()));
870    Value::Object(map)
871}
872
873fn apply_message_delta(message: &mut Value, event: &Value) {
874    let Some(delta) = event.get("delta") else {
875        return;
876    };
877    if let Some(role) = delta.get("role").and_then(Value::as_str) {
878        ensure_object(message).insert("role".into(), Value::String(role.to_owned()));
879    }
880
881    let Some(content_deltas) = delta.get("content").and_then(Value::as_array) else {
882        return;
883    };
884    let content = ensure_array_field(message, "content");
885    for content_delta in content_deltas {
886        let index = content_delta
887            .get("index")
888            .and_then(Value::as_u64)
889            .map(|value| value as usize)
890            .unwrap_or(content.len());
891        ensure_vec_len(content, index + 1);
892        if content[index].is_null() {
893            content[index] = Value::Object(Map::new());
894        }
895
896        let slot = &mut content[index];
897        let slot_object = ensure_object(slot);
898        if let Some(part_type) = content_delta.get("type").and_then(Value::as_str) {
899            slot_object.insert("type".into(), Value::String(part_type.to_owned()));
900            match part_type {
901                "text" => {
902                    let text_object = ensure_object_field(slot, "text");
903                    let value = content_delta
904                        .get("text")
905                        .and_then(|value| value.get("value"))
906                        .and_then(Value::as_str)
907                        .unwrap_or("");
908                    let current = text_object
909                        .get("value")
910                        .and_then(Value::as_str)
911                        .unwrap_or("");
912                    text_object.insert("value".into(), Value::String(format!("{current}{value}")));
913                }
914                "refusal" => {
915                    let value = content_delta
916                        .get("refusal")
917                        .and_then(Value::as_str)
918                        .unwrap_or("");
919                    let current = slot_object
920                        .get("refusal")
921                        .and_then(Value::as_str)
922                        .unwrap_or("");
923                    slot_object
924                        .insert("refusal".into(), Value::String(format!("{current}{value}")));
925                }
926                _ => merge_object(slot_object, content_delta),
927            }
928        }
929    }
930}
931
932fn apply_run_step_delta(run_step: &mut Value, event: &Value) {
933    let Some(delta) = event.get("delta") else {
934        return;
935    };
936    let Some(step_details) = delta.get("step_details") else {
937        return;
938    };
939    let step_details_object = ensure_object_field(run_step, "step_details");
940    if let Some(step_type) = step_details.get("type").and_then(Value::as_str) {
941        step_details_object.insert("type".into(), Value::String(step_type.to_owned()));
942        match step_type {
943            "message_creation" => {
944                if let Some(message_creation) = step_details.get("message_creation") {
945                    let target = step_details_object
946                        .entry("message_creation")
947                        .or_insert_with(|| Value::Object(Map::new()));
948                    merge_object(ensure_object(target), message_creation);
949                }
950            }
951            "tool_calls" => {
952                let tool_calls = step_details
953                    .get("tool_calls")
954                    .and_then(Value::as_array)
955                    .cloned()
956                    .unwrap_or_default();
957                let target = step_details_object
958                    .entry("tool_calls")
959                    .or_insert_with(|| Value::Array(Vec::new()));
960                let target_calls = if let Some(array) = target.as_array_mut() {
961                    array
962                } else {
963                    *target = Value::Array(Vec::new());
964                    target.as_array_mut().expect("tool_calls must be array")
965                };
966                for tool_call in tool_calls {
967                    let index = tool_call
968                        .get("index")
969                        .and_then(Value::as_u64)
970                        .map(|value| value as usize)
971                        .unwrap_or(target_calls.len());
972                    ensure_vec_len(target_calls, index + 1);
973                    if target_calls[index].is_null() {
974                        target_calls[index] = Value::Object(Map::new());
975                    }
976                    merge_tool_call_delta(&mut target_calls[index], &tool_call);
977                }
978            }
979            _ => merge_object(step_details_object, step_details),
980        }
981    }
982}
983
984fn merge_tool_call_delta(target: &mut Value, delta: &Value) {
985    let target_object = ensure_object(target);
986    if let Some(delta_object) = delta.as_object() {
987        for (key, value) in delta_object {
988            if matches!(key.as_str(), "function" | "code_interpreter")
989                || matches!(value, Value::Null)
990            {
991                continue;
992            }
993            target_object.insert(key.clone(), value.clone());
994        }
995    }
996    if let Some(function_delta) = delta.get("function") {
997        let function_target = target_object
998            .entry("function")
999            .or_insert_with(|| Value::Object(Map::new()));
1000        let function_object = ensure_object(function_target);
1001        if let Some(arguments) = function_delta.get("arguments").and_then(Value::as_str) {
1002            let current = function_object
1003                .get("arguments")
1004                .and_then(Value::as_str)
1005                .unwrap_or("");
1006            function_object.insert(
1007                "arguments".into(),
1008                Value::String(format!("{current}{arguments}")),
1009            );
1010        }
1011        if let Some(name) = function_delta.get("name").and_then(Value::as_str) {
1012            function_object.insert("name".into(), Value::String(name.to_owned()));
1013        }
1014    }
1015    if let Some(code_interpreter_delta) = delta.get("code_interpreter") {
1016        let code_interpreter_target = target_object
1017            .entry("code_interpreter")
1018            .or_insert_with(|| Value::Object(Map::new()));
1019        let code_interpreter_object = ensure_object(code_interpreter_target);
1020        if let Some(input) = code_interpreter_delta.get("input").and_then(Value::as_str) {
1021            let current = code_interpreter_object
1022                .get("input")
1023                .and_then(Value::as_str)
1024                .unwrap_or("");
1025            code_interpreter_object
1026                .insert("input".into(), Value::String(format!("{current}{input}")));
1027        }
1028        if let Some(outputs) = code_interpreter_delta
1029            .get("outputs")
1030            .and_then(Value::as_array)
1031        {
1032            let target_outputs = code_interpreter_object
1033                .entry("outputs")
1034                .or_insert_with(|| Value::Array(Vec::new()));
1035            let output_array = if let Some(array) = target_outputs.as_array_mut() {
1036                array
1037            } else {
1038                *target_outputs = Value::Array(Vec::new());
1039                target_outputs
1040                    .as_array_mut()
1041                    .expect("outputs must be array")
1042            };
1043            output_array.extend(outputs.iter().cloned());
1044        }
1045    }
1046}
1047
#[cfg(test)]
mod tests {
    use super::{AssistantStreamEvent, AssistantStreamSnapshot};
    use serde_json::json;

    /// Feeds two message deltas and two run-step deltas for ids that were never
    /// announced by a "created" event, then checks that the text fragments and
    /// the function-call argument fragments were concatenated in the snapshot.
    #[test]
    fn test_should_merge_assistant_deltas_into_snapshot_before_created_events() {
        // Builds a text delta for content part 0 of message `msg_1`.
        let text_chunk = |fragment: &str| AssistantStreamEvent {
            event: "thread.message.delta".into(),
            data: json!({
                "id": "msg_1",
                "object": "thread.message.delta",
                "delta": {
                    "content": [{
                        "index": 0,
                        "type": "text",
                        "text": { "value": fragment }
                    }]
                }
            }),
        };
        // Builds a function tool-call delta for call 0 of run step `step_1`.
        let tool_call_chunk = |function: serde_json::Value| AssistantStreamEvent {
            event: "thread.run.step.delta".into(),
            data: json!({
                "id": "step_1",
                "object": "thread.run.step.delta",
                "delta": {
                    "step_details": {
                        "type": "tool_calls",
                        "tool_calls": [{
                            "index": 0,
                            "type": "function",
                            "function": function
                        }]
                    }
                }
            }),
        };

        let events = [
            text_chunk("hel"),
            text_chunk("lo"),
            tool_call_chunk(json!({
                "name": "lookup_weather",
                "arguments": "{\"city\":\"Sha"
            })),
            tool_call_chunk(json!({
                "arguments": "nghai\"}"
            })),
        ];

        let mut snapshot = AssistantStreamSnapshot::default();
        for event in &events {
            snapshot.apply(event);
        }

        let merged_text = snapshot
            .message_raw("msg_1")
            .and_then(|message| message.get("content"))
            .and_then(serde_json::Value::as_array)
            .and_then(|content| content.first())
            .and_then(|part| part.get("text"))
            .and_then(|text| text.get("value"))
            .and_then(serde_json::Value::as_str);
        assert_eq!(merged_text, Some("hello"));

        let merged_arguments = snapshot
            .run_step_raw("step_1")
            .and_then(|step| step.get("step_details"))
            .and_then(|details| details.get("tool_calls"))
            .and_then(serde_json::Value::as_array)
            .and_then(|tool_calls| tool_calls.first())
            .and_then(|tool_call| tool_call.get("function"))
            .and_then(|function| function.get("arguments"))
            .and_then(serde_json::Value::as_str);
        assert_eq!(merged_arguments, Some("{\"city\":\"Shanghai\"}"));
    }
}