use std::collections::BTreeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use crate::error::WebSocketError;
use crate::json_payload::JsonPayload;
/// Untyped server event envelope: the `type` discriminator plus every other
/// top-level JSON field, kept verbatim.
///
/// The typed enums in this module are built from this envelope and keep it
/// around (`raw`) so that serializing them reproduces the original fields.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct WebSocketServerEvent {
    /// Value of the JSON `type` field; an empty string when the field is absent.
    #[serde(rename = "type", default)]
    pub event_type: String,
    /// All remaining top-level fields, flattened into a map keyed by field name.
    #[serde(flatten)]
    pub data: BTreeMap<String, Value>,
}
impl WebSocketServerEvent {
pub fn is_error(&self) -> bool {
self.event_type == "error"
}
pub fn error_message(&self) -> Option<String> {
self.data
.get("error")
.and_then(|value| {
value
.get("message")
.or_else(|| value.get("error"))
.or_else(|| value.get("detail"))
})
.or_else(|| self.data.get("message"))
.and_then(Value::as_str)
.map(str::to_owned)
}
}
/// Typed view of a `response.created` server event.
#[derive(Debug, Clone, PartialEq)]
pub struct ResponseCreatedEvent {
    /// Top-level `id` string, falling back to `response.id` when absent.
    pub id: Option<String>,
    /// Clone of the `response` field, if present.
    pub response: Option<JsonPayload>,
    /// The untyped event this view was built from.
    pub raw: WebSocketServerEvent,
}
/// Typed view of a `response.output_text.delta` server event.
///
/// Each field is the corresponding top-level string field, or `None` when it
/// is missing or not a JSON string.
#[derive(Debug, Clone, PartialEq)]
pub struct ResponseOutputTextDeltaEvent {
    /// Top-level `delta` string.
    pub delta: Option<String>,
    /// Top-level `response_id` string.
    pub response_id: Option<String>,
    /// Top-level `item_id` string.
    pub item_id: Option<String>,
    /// The untyped event this view was built from.
    pub raw: WebSocketServerEvent,
}
/// Typed view of a `session.created` server event.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionCreatedEvent {
    /// Top-level `id` string, falling back to `session.id` when absent.
    pub id: Option<String>,
    /// Clone of the `session` field, if present.
    pub session: Option<JsonPayload>,
    /// The untyped event this view was built from.
    pub raw: WebSocketServerEvent,
}
/// Server event carried by a `RealtimeStreamMessage`, classified by its `type` field.
///
/// Types without a dedicated variant are kept intact in `Unknown`.
#[derive(Debug, Clone, PartialEq)]
pub enum RealtimeServerEvent {
    /// `type == "session.created"`.
    SessionCreated(SessionCreatedEvent),
    /// `type == "response.created"`.
    ResponseCreated(ResponseCreatedEvent),
    /// `type == "response.output_text.delta"`.
    ResponseOutputTextDelta(ResponseOutputTextDeltaEvent),
    /// Any other `type`, including a missing one.
    Unknown(WebSocketServerEvent),
}
/// Server event carried by a `ResponsesStreamMessage`, classified by its `type` field.
///
/// Unlike `RealtimeServerEvent` there is no session variant, so
/// `session.created` lands in `Unknown` here.
#[derive(Debug, Clone, PartialEq)]
pub enum ResponsesServerEvent {
    /// `type == "response.created"`.
    ResponseCreated(ResponseCreatedEvent),
    /// `type == "response.output_text.delta"`.
    ResponseOutputTextDelta(ResponseOutputTextDeltaEvent),
    /// Any other `type`, including a missing one.
    Unknown(WebSocketServerEvent),
}
impl RealtimeServerEvent {
pub fn event_type(&self) -> &str {
self.raw().event_type.as_str()
}
pub fn raw(&self) -> &WebSocketServerEvent {
match self {
Self::SessionCreated(event) => &event.raw,
Self::ResponseCreated(event) => &event.raw,
Self::ResponseOutputTextDelta(event) => &event.raw,
Self::Unknown(event) => event,
}
}
}
impl ResponsesServerEvent {
pub fn event_type(&self) -> &str {
self.raw().event_type.as_str()
}
pub fn raw(&self) -> &WebSocketServerEvent {
match self {
Self::ResponseCreated(event) => &event.raw,
Self::ResponseOutputTextDelta(event) => &event.raw,
Self::Unknown(event) => event,
}
}
}
impl From<WebSocketServerEvent> for RealtimeServerEvent {
fn from(raw: WebSocketServerEvent) -> Self {
match raw.event_type.as_str() {
"session.created" => Self::SessionCreated(SessionCreatedEvent {
id: extract_event_string(&raw, "id").or_else(|| {
raw.data
.get("session")
.and_then(|value| value.get("id"))
.and_then(Value::as_str)
.map(str::to_owned)
}),
session: raw.data.get("session").cloned().map(JsonPayload::from),
raw,
}),
"response.created" => Self::ResponseCreated(ResponseCreatedEvent {
id: extract_event_string(&raw, "id").or_else(|| {
raw.data
.get("response")
.and_then(|value| value.get("id"))
.and_then(Value::as_str)
.map(str::to_owned)
}),
response: raw.data.get("response").cloned().map(JsonPayload::from),
raw,
}),
"response.output_text.delta" => {
Self::ResponseOutputTextDelta(ResponseOutputTextDeltaEvent {
delta: extract_event_string(&raw, "delta"),
response_id: extract_event_string(&raw, "response_id"),
item_id: extract_event_string(&raw, "item_id"),
raw,
})
}
_ => Self::Unknown(raw),
}
}
}
impl From<WebSocketServerEvent> for ResponsesServerEvent {
fn from(raw: WebSocketServerEvent) -> Self {
match raw.event_type.as_str() {
"response.created" => Self::ResponseCreated(ResponseCreatedEvent {
id: extract_event_string(&raw, "id").or_else(|| {
raw.data
.get("response")
.and_then(|value| value.get("id"))
.and_then(Value::as_str)
.map(str::to_owned)
}),
response: raw.data.get("response").cloned().map(JsonPayload::from),
raw,
}),
"response.output_text.delta" => {
Self::ResponseOutputTextDelta(ResponseOutputTextDeltaEvent {
delta: extract_event_string(&raw, "delta"),
response_id: extract_event_string(&raw, "response_id"),
item_id: extract_event_string(&raw, "item_id"),
raw,
})
}
_ => Self::Unknown(raw),
}
}
}
impl<'de> Deserialize<'de> for RealtimeServerEvent {
    /// Parses the generic envelope first, then classifies it with the `From` impl.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let envelope = WebSocketServerEvent::deserialize(deserializer)?;
        Ok(envelope.into())
    }
}
impl<'de> Deserialize<'de> for ResponsesServerEvent {
    /// Parses the generic envelope first, then classifies it with the `From` impl.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let envelope = WebSocketServerEvent::deserialize(deserializer)?;
        Ok(envelope.into())
    }
}
impl Serialize for RealtimeServerEvent {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
self.raw().serialize(serializer)
}
}
impl Serialize for ResponsesServerEvent {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
self.raw().serialize(serializer)
}
}
/// Returns `raw.data[key]` as an owned string when it is present and a JSON string;
/// any other value (number, null, object, ...) yields `None`.
fn extract_event_string(raw: &WebSocketServerEvent, key: &str) -> Option<String> {
    match raw.data.get(key) {
        Some(Value::String(text)) => Some(text.clone()),
        _ => None,
    }
}
/// Item yielded by a socket stream: either a connection-state notification,
/// a parsed server event of type `T`, or an error.
///
/// NOTE(review): the state variants presumably track the socket lifecycle
/// (connecting -> open -> closing -> closed); confirm against the emitting code.
#[derive(Debug, Clone)]
pub enum SocketStreamMessage<T> {
    /// The connection is being established.
    Connecting,
    /// The connection is open.
    Open,
    /// The connection is shutting down.
    Closing,
    /// The connection has closed.
    Close,
    /// A server event.
    Message(T),
    /// An error reported by the socket layer.
    Error(WebSocketError),
}
/// Stream item whose messages are [`RealtimeServerEvent`]s.
pub type RealtimeStreamMessage = SocketStreamMessage<RealtimeServerEvent>;
/// Stream item whose messages are [`ResponsesServerEvent`]s.
pub type ResponsesStreamMessage = SocketStreamMessage<ResponsesServerEvent>;
/// Close-frame parameters: a numeric status code and a textual reason.
///
/// NOTE(review): presumably passed when closing the socket; confirm at call sites.
#[derive(Debug, Clone)]
pub struct SocketCloseOptions {
    /// WebSocket close status code.
    pub code: u16,
    /// Human-readable close reason.
    pub reason: String,
}
impl Default for SocketCloseOptions {
    /// Status code 1000 (RFC 6455 "normal closure") with reason `"OK"`.
    fn default() -> Self {
        let code: u16 = 1000;
        let reason = String::from("OK");
        Self { code, reason }
    }
}