async_openai/types/
assistant_stream.rs

1use std::pin::Pin;
2
3use futures::Stream;
4use serde::Deserialize;
5
6use crate::error::{map_deserialization_error, ApiError, OpenAIError};
7
8use super::{
9    MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
10};
11
/// Represents an event emitted when streaming a Run.
///
/// Each event in a server-sent events stream has an `event` and `data` property:
///
/// ```text
/// event: thread.created
/// data: {"id": "thread_123", "object": "thread", ...}
/// ```
///
/// We emit events whenever a new object is created, transitions to a new state, or is being
/// streamed in parts (deltas). For example, we emit `thread.run.created` when a new run
/// is created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses
/// to create a message during a run, we emit a `thread.message.created` event, a
/// `thread.message.in_progress` event, many `thread.message.delta` events, and finally a
/// `thread.message.completed` event.
///
/// We may add additional events over time, so we recommend handling unknown events gracefully
/// in your code. See the [Assistants API quickstart](https://platform.openai.com/docs/assistants/overview) to learn how to
/// integrate the Assistants API with streaming.
///
/// For serde, the enum is adjacently tagged: the `event` field selects the variant and the
/// `data` field carries that variant's payload. The enum is `#[non_exhaustive]`, so matches
/// written outside this crate need a wildcard arm.
#[derive(Debug, Deserialize, Clone)]
#[serde(tag = "event", content = "data")]
#[non_exhaustive]
pub enum AssistantStreamEvent {
    /// Occurs when a new [thread](https://platform.openai.com/docs/api-reference/threads/object) is created.
    ///
    /// NOTE(review): the variant is spelled `TreadCreated` while every sibling uses the `Thread…`
    /// prefix. It looks like a typo, but renaming it would break existing callers.
    #[serde(rename = "thread.created")]
    TreadCreated(ThreadObject),
    /// Occurs when a new [run](https://platform.openai.com/docs/api-reference/runs/object) is created.
    #[serde(rename = "thread.run.created")]
    ThreadRunCreated(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to a `queued` status.
    #[serde(rename = "thread.run.queued")]
    ThreadRunQueued(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to an `in_progress` status.
    #[serde(rename = "thread.run.in_progress")]
    ThreadRunInProgress(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to a `requires_action` status.
    #[serde(rename = "thread.run.requires_action")]
    ThreadRunRequiresAction(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) is completed.
    #[serde(rename = "thread.run.completed")]
    ThreadRunCompleted(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) ends with status `incomplete`.
    #[serde(rename = "thread.run.incomplete")]
    ThreadRunIncomplete(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) fails.
    #[serde(rename = "thread.run.failed")]
    ThreadRunFailed(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) moves to a `cancelling` status.
    #[serde(rename = "thread.run.cancelling")]
    ThreadRunCancelling(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) is cancelled.
    #[serde(rename = "thread.run.cancelled")]
    ThreadRunCancelled(RunObject),
    /// Occurs when a [run](https://platform.openai.com/docs/api-reference/runs/object) expires.
    #[serde(rename = "thread.run.expired")]
    ThreadRunExpired(RunObject),
    /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) is created.
    #[serde(rename = "thread.run.step.created")]
    ThreadRunStepCreated(RunStepObject),
    /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) moves to an `in_progress` state.
    #[serde(rename = "thread.run.step.in_progress")]
    ThreadRunStepInProgress(RunStepObject),
    /// Occurs when parts of a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) are being streamed.
    #[serde(rename = "thread.run.step.delta")]
    ThreadRunStepDelta(RunStepDeltaObject),
    /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) is completed.
    #[serde(rename = "thread.run.step.completed")]
    ThreadRunStepCompleted(RunStepObject),
    /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) fails.
    #[serde(rename = "thread.run.step.failed")]
    ThreadRunStepFailed(RunStepObject),
    /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) is cancelled.
    #[serde(rename = "thread.run.step.cancelled")]
    ThreadRunStepCancelled(RunStepObject),
    /// Occurs when a [run step](https://platform.openai.com/docs/api-reference/run-steps/step-object) expires.
    #[serde(rename = "thread.run.step.expired")]
    ThreadRunStepExpired(RunStepObject),
    /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) is created.
    #[serde(rename = "thread.message.created")]
    ThreadMessageCreated(MessageObject),
    /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) moves to an `in_progress` state.
    #[serde(rename = "thread.message.in_progress")]
    ThreadMessageInProgress(MessageObject),
    /// Occurs when parts of a [Message](https://platform.openai.com/docs/api-reference/messages/object) are being streamed.
    #[serde(rename = "thread.message.delta")]
    ThreadMessageDelta(MessageDeltaObject),
    /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) is completed.
    #[serde(rename = "thread.message.completed")]
    ThreadMessageCompleted(MessageObject),
    /// Occurs when a [message](https://platform.openai.com/docs/api-reference/messages/object) ends before it is completed.
    #[serde(rename = "thread.message.incomplete")]
    ThreadMessageIncomplete(MessageObject),
    /// Occurs when an [error](https://platform.openai.com/docs/guides/error-codes/api-errors) occurs. This can happen due to an internal server error or a timeout.
    #[serde(rename = "error")]
    ErrorEvent(ApiError),
    /// Occurs when a stream ends. The payload is the event's `data` as a plain string.
    #[serde(rename = "done")]
    Done(String),
}
112
/// A pinned, boxed, `Send` stream whose items are either a parsed [`AssistantStreamEvent`] or
/// an [`OpenAIError`].
pub type AssistantEventStream =
    Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, OpenAIError>> + Send>>;
115
116impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
117    type Error = OpenAIError;
118    fn try_from(value: eventsource_stream::Event) -> Result<Self, Self::Error> {
119        match value.event.as_str() {
120            "thread.created" => serde_json::from_str::<ThreadObject>(value.data.as_str())
121                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
122                .map(AssistantStreamEvent::TreadCreated),
123            "thread.run.created" => serde_json::from_str::<RunObject>(value.data.as_str())
124                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
125                .map(AssistantStreamEvent::ThreadRunCreated),
126            "thread.run.queued" => serde_json::from_str::<RunObject>(value.data.as_str())
127                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
128                .map(AssistantStreamEvent::ThreadRunQueued),
129            "thread.run.in_progress" => serde_json::from_str::<RunObject>(value.data.as_str())
130                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
131                .map(AssistantStreamEvent::ThreadRunInProgress),
132            "thread.run.requires_action" => serde_json::from_str::<RunObject>(value.data.as_str())
133                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
134                .map(AssistantStreamEvent::ThreadRunRequiresAction),
135            "thread.run.completed" => serde_json::from_str::<RunObject>(value.data.as_str())
136                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
137                .map(AssistantStreamEvent::ThreadRunCompleted),
138            "thread.run.incomplete" => serde_json::from_str::<RunObject>(value.data.as_str())
139                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
140                .map(AssistantStreamEvent::ThreadRunIncomplete),
141            "thread.run.failed" => serde_json::from_str::<RunObject>(value.data.as_str())
142                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
143                .map(AssistantStreamEvent::ThreadRunFailed),
144            "thread.run.cancelling" => serde_json::from_str::<RunObject>(value.data.as_str())
145                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
146                .map(AssistantStreamEvent::ThreadRunCancelling),
147            "thread.run.cancelled" => serde_json::from_str::<RunObject>(value.data.as_str())
148                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
149                .map(AssistantStreamEvent::ThreadRunCancelled),
150            "thread.run.expired" => serde_json::from_str::<RunObject>(value.data.as_str())
151                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
152                .map(AssistantStreamEvent::ThreadRunExpired),
153            "thread.run.step.created" => serde_json::from_str::<RunStepObject>(value.data.as_str())
154                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
155                .map(AssistantStreamEvent::ThreadRunStepCreated),
156            "thread.run.step.in_progress" => {
157                serde_json::from_str::<RunStepObject>(value.data.as_str())
158                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
159                    .map(AssistantStreamEvent::ThreadRunStepInProgress)
160            }
161            "thread.run.step.delta" => {
162                serde_json::from_str::<RunStepDeltaObject>(value.data.as_str())
163                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
164                    .map(AssistantStreamEvent::ThreadRunStepDelta)
165            }
166            "thread.run.step.completed" => {
167                serde_json::from_str::<RunStepObject>(value.data.as_str())
168                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
169                    .map(AssistantStreamEvent::ThreadRunStepCompleted)
170            }
171            "thread.run.step.failed" => serde_json::from_str::<RunStepObject>(value.data.as_str())
172                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
173                .map(AssistantStreamEvent::ThreadRunStepFailed),
174            "thread.run.step.cancelled" => {
175                serde_json::from_str::<RunStepObject>(value.data.as_str())
176                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
177                    .map(AssistantStreamEvent::ThreadRunStepCancelled)
178            }
179            "thread.run.step.expired" => serde_json::from_str::<RunStepObject>(value.data.as_str())
180                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
181                .map(AssistantStreamEvent::ThreadRunStepExpired),
182            "thread.message.created" => serde_json::from_str::<MessageObject>(value.data.as_str())
183                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
184                .map(AssistantStreamEvent::ThreadMessageCreated),
185            "thread.message.in_progress" => {
186                serde_json::from_str::<MessageObject>(value.data.as_str())
187                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
188                    .map(AssistantStreamEvent::ThreadMessageInProgress)
189            }
190            "thread.message.delta" => {
191                serde_json::from_str::<MessageDeltaObject>(value.data.as_str())
192                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
193                    .map(AssistantStreamEvent::ThreadMessageDelta)
194            }
195            "thread.message.completed" => {
196                serde_json::from_str::<MessageObject>(value.data.as_str())
197                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
198                    .map(AssistantStreamEvent::ThreadMessageCompleted)
199            }
200            "thread.message.incomplete" => {
201                serde_json::from_str::<MessageObject>(value.data.as_str())
202                    .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
203                    .map(AssistantStreamEvent::ThreadMessageIncomplete)
204            }
205            "error" => serde_json::from_str::<ApiError>(value.data.as_str())
206                .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
207                .map(AssistantStreamEvent::ErrorEvent),
208            "done" => Ok(AssistantStreamEvent::Done(value.data)),
209
210            _ => Err(OpenAIError::StreamError(
211                "Unrecognized event: {value:?#}".into(),
212            )),
213        }
214    }
215}