use std::pin::Pin;
use futures::Stream;
use serde::Deserialize;
use crate::error::{map_deserialization_error, ApiError, OpenAIError};
use super::{
MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
};
/// A typed event from an assistant stream.
///
/// With serde the enum is adjacently tagged: the `event` field holds the event
/// name (the `rename` on each variant) and the `data` field holds that
/// variant's payload. The enum is `#[non_exhaustive]`, so matches outside this
/// crate need a wildcard arm.
#[derive(Debug, Deserialize, Clone)]
#[serde(tag = "event", content = "data")]
#[non_exhaustive]
pub enum AssistantStreamEvent {
/// `thread.created` event; payload is a [`ThreadObject`].
///
/// NOTE(review): the variant name looks like a misspelling of `ThreadCreated`;
/// renaming it would be a breaking change for callers, so it is left as is.
#[serde(rename = "thread.created")]
TreadCreated(ThreadObject),
/// `thread.run.created` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.created")]
ThreadRunCreated(RunObject),
/// `thread.run.queued` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.queued")]
ThreadRunQueued(RunObject),
/// `thread.run.in_progress` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.in_progress")]
ThreadRunInProgress(RunObject),
/// `thread.run.requires_action` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.requires_action")]
ThreadRunRequiresAction(RunObject),
/// `thread.run.completed` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.completed")]
ThreadRunCompleted(RunObject),
/// `thread.run.incomplete` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.incomplete")]
ThreadRunIncomplete(RunObject),
/// `thread.run.failed` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.failed")]
ThreadRunFailed(RunObject),
/// `thread.run.cancelling` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.cancelling")]
ThreadRunCancelling(RunObject),
/// `thread.run.cancelled` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.cancelled")]
ThreadRunCancelled(RunObject),
/// `thread.run.expired` event; payload is a [`RunObject`].
#[serde(rename = "thread.run.expired")]
ThreadRunExpired(RunObject),
/// `thread.run.step.created` event; payload is a [`RunStepObject`].
#[serde(rename = "thread.run.step.created")]
ThreadRunStepCreated(RunStepObject),
/// `thread.run.step.in_progress` event; payload is a [`RunStepObject`].
#[serde(rename = "thread.run.step.in_progress")]
ThreadRunStepInProgress(RunStepObject),
/// `thread.run.step.delta` event; payload is a [`RunStepDeltaObject`].
#[serde(rename = "thread.run.step.delta")]
ThreadRunStepDelta(RunStepDeltaObject),
/// `thread.run.step.completed` event; payload is a [`RunStepObject`].
#[serde(rename = "thread.run.step.completed")]
ThreadRunStepCompleted(RunStepObject),
/// `thread.run.step.failed` event; payload is a [`RunStepObject`].
#[serde(rename = "thread.run.step.failed")]
ThreadRunStepFailed(RunStepObject),
/// `thread.run.step.cancelled` event; payload is a [`RunStepObject`].
#[serde(rename = "thread.run.step.cancelled")]
ThreadRunStepCancelled(RunStepObject),
/// `thread.run.step.expired` event; payload is a [`RunStepObject`].
#[serde(rename = "thread.run.step.expired")]
ThreadRunStepExpired(RunStepObject),
/// `thread.message.created` event; payload is a [`MessageObject`].
#[serde(rename = "thread.message.created")]
ThreadMessageCreated(MessageObject),
/// `thread.message.in_progress` event; payload is a [`MessageObject`].
#[serde(rename = "thread.message.in_progress")]
ThreadMessageInProgress(MessageObject),
/// `thread.message.delta` event; payload is a [`MessageDeltaObject`].
#[serde(rename = "thread.message.delta")]
ThreadMessageDelta(MessageDeltaObject),
/// `thread.message.completed` event; payload is a [`MessageObject`].
#[serde(rename = "thread.message.completed")]
ThreadMessageCompleted(MessageObject),
/// `thread.message.incomplete` event; payload is a [`MessageObject`].
#[serde(rename = "thread.message.incomplete")]
ThreadMessageIncomplete(MessageObject),
/// `error` event; payload is an [`ApiError`].
#[serde(rename = "error")]
ErrorEvent(ApiError),
/// `done` event; payload is a plain `String`.
#[serde(rename = "done")]
Done(String),
}
/// A pinned, boxed, `Send` stream whose items are either a parsed
/// [`AssistantStreamEvent`] or an [`OpenAIError`].
pub type AssistantEventStream =
Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, OpenAIError>> + Send>>;
impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
type Error = OpenAIError;
fn try_from(value: eventsource_stream::Event) -> Result<Self, Self::Error> {
match value.event.as_str() {
"thread.created" => serde_json::from_str::<ThreadObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::TreadCreated),
"thread.run.created" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunCreated),
"thread.run.queued" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunQueued),
"thread.run.in_progress" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunInProgress),
"thread.run.requires_action" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunRequiresAction),
"thread.run.completed" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunCompleted),
"thread.run.incomplete" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunIncomplete),
"thread.run.failed" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunFailed),
"thread.run.cancelling" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunCancelling),
"thread.run.cancelled" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunCancelled),
"thread.run.expired" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunExpired),
"thread.run.step.created" => serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepCreated),
"thread.run.step.in_progress" => {
serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepInProgress)
}
"thread.run.step.delta" => {
serde_json::from_str::<RunStepDeltaObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepDelta)
}
"thread.run.step.completed" => {
serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepCompleted)
}
"thread.run.step.failed" => serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepFailed),
"thread.run.step.cancelled" => {
serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepCancelled)
}
"thread.run.step.expired" => serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadRunStepExpired),
"thread.message.created" => serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadMessageCreated),
"thread.message.in_progress" => {
serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadMessageInProgress)
}
"thread.message.delta" => {
serde_json::from_str::<MessageDeltaObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadMessageDelta)
}
"thread.message.completed" => {
serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadMessageCompleted)
}
"thread.message.incomplete" => {
serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ThreadMessageIncomplete)
}
"error" => serde_json::from_str::<ApiError>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map(AssistantStreamEvent::ErrorEvent),
"done" => Ok(AssistantStreamEvent::Done(value.data)),
_ => Err(OpenAIError::StreamError(
"Unrecognized event: {value:?#}".into(),
)),
}
}
}