async_openai/types/
assistant_stream.rs1use std::pin::Pin;
2
3use futures::Stream;
4use serde::Deserialize;
5
6use crate::error::{map_deserialization_error, ApiError, OpenAIError};
7
8use super::{
9 MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
10};
11
12#[derive(Debug, Deserialize, Clone)]
33#[serde(tag = "event", content = "data")]
34#[non_exhaustive]
35pub enum AssistantStreamEvent {
36 #[serde(rename = "thread.created")]
38 TreadCreated(ThreadObject),
39 #[serde(rename = "thread.run.created")]
41 ThreadRunCreated(RunObject),
42 #[serde(rename = "thread.run.queued")]
44 ThreadRunQueued(RunObject),
45 #[serde(rename = "thread.run.in_progress")]
47 ThreadRunInProgress(RunObject),
48 #[serde(rename = "thread.run.requires_action")]
50 ThreadRunRequiresAction(RunObject),
51 #[serde(rename = "thread.run.completed")]
53 ThreadRunCompleted(RunObject),
54 #[serde(rename = "thread.run.incomplete")]
56 ThreadRunIncomplete(RunObject),
57 #[serde(rename = "thread.run.failed")]
59 ThreadRunFailed(RunObject),
60 #[serde(rename = "thread.run.cancelling")]
62 ThreadRunCancelling(RunObject),
63 #[serde(rename = "thread.run.cancelled")]
65 ThreadRunCancelled(RunObject),
66 #[serde(rename = "thread.run.expired")]
68 ThreadRunExpired(RunObject),
69 #[serde(rename = "thread.run.step.created")]
71 ThreadRunStepCreated(RunStepObject),
72 #[serde(rename = "thread.run.step.in_progress")]
74 ThreadRunStepInProgress(RunStepObject),
75 #[serde(rename = "thread.run.step.delta")]
77 ThreadRunStepDelta(RunStepDeltaObject),
78 #[serde(rename = "thread.run.step.completed")]
80 ThreadRunStepCompleted(RunStepObject),
81 #[serde(rename = "thread.run.step.failed")]
83 ThreadRunStepFailed(RunStepObject),
84 #[serde(rename = "thread.run.step.cancelled")]
86 ThreadRunStepCancelled(RunStepObject),
87 #[serde(rename = "thread.run.step.expired")]
89 ThreadRunStepExpired(RunStepObject),
90 #[serde(rename = "thread.message.created")]
92 ThreadMessageCreated(MessageObject),
93 #[serde(rename = "thread.message.in_progress")]
95 ThreadMessageInProgress(MessageObject),
96 #[serde(rename = "thread.message.delta")]
98 ThreadMessageDelta(MessageDeltaObject),
99 #[serde(rename = "thread.message.completed")]
101 ThreadMessageCompleted(MessageObject),
102 #[serde(rename = "thread.message.incomplete")]
104 ThreadMessageIncomplete(MessageObject),
105 #[serde(rename = "error")]
107 ErrorEvent(ApiError),
108 #[serde(rename = "done")]
110 Done(String),
111}
112
113pub type AssistantEventStream =
114 Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, OpenAIError>> + Send>>;
115
116impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
117 type Error = OpenAIError;
118 fn try_from(value: eventsource_stream::Event) -> Result<Self, Self::Error> {
119 match value.event.as_str() {
120 "thread.created" => serde_json::from_str::<ThreadObject>(value.data.as_str())
121 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
122 .map(AssistantStreamEvent::TreadCreated),
123 "thread.run.created" => serde_json::from_str::<RunObject>(value.data.as_str())
124 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
125 .map(AssistantStreamEvent::ThreadRunCreated),
126 "thread.run.queued" => serde_json::from_str::<RunObject>(value.data.as_str())
127 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
128 .map(AssistantStreamEvent::ThreadRunQueued),
129 "thread.run.in_progress" => serde_json::from_str::<RunObject>(value.data.as_str())
130 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
131 .map(AssistantStreamEvent::ThreadRunInProgress),
132 "thread.run.requires_action" => serde_json::from_str::<RunObject>(value.data.as_str())
133 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
134 .map(AssistantStreamEvent::ThreadRunRequiresAction),
135 "thread.run.completed" => serde_json::from_str::<RunObject>(value.data.as_str())
136 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
137 .map(AssistantStreamEvent::ThreadRunCompleted),
138 "thread.run.incomplete" => serde_json::from_str::<RunObject>(value.data.as_str())
139 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
140 .map(AssistantStreamEvent::ThreadRunIncomplete),
141 "thread.run.failed" => serde_json::from_str::<RunObject>(value.data.as_str())
142 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
143 .map(AssistantStreamEvent::ThreadRunFailed),
144 "thread.run.cancelling" => serde_json::from_str::<RunObject>(value.data.as_str())
145 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
146 .map(AssistantStreamEvent::ThreadRunCancelling),
147 "thread.run.cancelled" => serde_json::from_str::<RunObject>(value.data.as_str())
148 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
149 .map(AssistantStreamEvent::ThreadRunCancelled),
150 "thread.run.expired" => serde_json::from_str::<RunObject>(value.data.as_str())
151 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
152 .map(AssistantStreamEvent::ThreadRunExpired),
153 "thread.run.step.created" => serde_json::from_str::<RunStepObject>(value.data.as_str())
154 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
155 .map(AssistantStreamEvent::ThreadRunStepCreated),
156 "thread.run.step.in_progress" => {
157 serde_json::from_str::<RunStepObject>(value.data.as_str())
158 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
159 .map(AssistantStreamEvent::ThreadRunStepInProgress)
160 }
161 "thread.run.step.delta" => {
162 serde_json::from_str::<RunStepDeltaObject>(value.data.as_str())
163 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
164 .map(AssistantStreamEvent::ThreadRunStepDelta)
165 }
166 "thread.run.step.completed" => {
167 serde_json::from_str::<RunStepObject>(value.data.as_str())
168 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
169 .map(AssistantStreamEvent::ThreadRunStepCompleted)
170 }
171 "thread.run.step.failed" => serde_json::from_str::<RunStepObject>(value.data.as_str())
172 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
173 .map(AssistantStreamEvent::ThreadRunStepFailed),
174 "thread.run.step.cancelled" => {
175 serde_json::from_str::<RunStepObject>(value.data.as_str())
176 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
177 .map(AssistantStreamEvent::ThreadRunStepCancelled)
178 }
179 "thread.run.step.expired" => serde_json::from_str::<RunStepObject>(value.data.as_str())
180 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
181 .map(AssistantStreamEvent::ThreadRunStepExpired),
182 "thread.message.created" => serde_json::from_str::<MessageObject>(value.data.as_str())
183 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
184 .map(AssistantStreamEvent::ThreadMessageCreated),
185 "thread.message.in_progress" => {
186 serde_json::from_str::<MessageObject>(value.data.as_str())
187 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
188 .map(AssistantStreamEvent::ThreadMessageInProgress)
189 }
190 "thread.message.delta" => {
191 serde_json::from_str::<MessageDeltaObject>(value.data.as_str())
192 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
193 .map(AssistantStreamEvent::ThreadMessageDelta)
194 }
195 "thread.message.completed" => {
196 serde_json::from_str::<MessageObject>(value.data.as_str())
197 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
198 .map(AssistantStreamEvent::ThreadMessageCompleted)
199 }
200 "thread.message.incomplete" => {
201 serde_json::from_str::<MessageObject>(value.data.as_str())
202 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
203 .map(AssistantStreamEvent::ThreadMessageIncomplete)
204 }
205 "error" => serde_json::from_str::<ApiError>(value.data.as_str())
206 .map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
207 .map(AssistantStreamEvent::ErrorEvent),
208 "done" => Ok(AssistantStreamEvent::Done(value.data)),
209
210 _ => Err(OpenAIError::StreamError(
211 "Unrecognized event: {value:?#}".into(),
212 )),
213 }
214 }
215}