use anyhow::format_err;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
use crate::types::RiverResultInternal;
/// Control messages of the protocol.
///
/// Serialized as an internally tagged object: the variant's wire name is stored
/// under the `"type"` key (e.g. `{"type":"ACK"}`). For the two handshake
/// variants, the fields of the wrapped payload struct appear next to the tag.
///
/// `rename_all_fields` only affects named fields of struct-style variants;
/// no variant currently has that shape.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all_fields = "camelCase")]
pub enum Control {
    /// Wire tag `"CLOSE"`.
    #[serde(rename = "CLOSE")]
    Close,
    /// Wire tag `"ACK"`.
    #[serde(rename = "ACK")]
    Ack,
    /// Wire tag `"HANDSHAKE_REQ"`; payload is a [`HandshakeRequest`].
    #[serde(rename = "HANDSHAKE_REQ")]
    HandshakeRequest(HandshakeRequest),
    /// Wire tag `"HANDSHAKE_RESP"`; payload is a [`HandshakeResponse`].
    #[serde(rename = "HANDSHAKE_RESP")]
    HandshakeResponse(HandshakeResponse),
}
/// Payload carried by [`Control::HandshakeRequest`].
///
/// Keys are camelCased on the wire: `protocolVersion`, `sessionId`,
/// `expectedSessionState`, `metadata` (serialized in that order).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HandshakeRequest {
    /// Protocol version carried in the handshake.
    pub protocol_version: ProtocolVersion,
    /// Identifier of the session this handshake refers to.
    pub session_id: String,
    /// Sequence-number state the sender expects the session to be in.
    pub expected_session_state: ExpectedSessionState,
    /// Optional free-form metadata. A missing key or JSON `null` deserializes
    /// to `None`; `None` is serialized as `null` (the key is not skipped).
    pub metadata: Option<serde_json::Value>,
}
/// Payload carried by [`Control::HandshakeResponse`].
///
/// The single `status` key wraps a [`HandshakeResponseOk`] in the crate's
/// `RiverResultInternal` result type (imported from `crate::types`).
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HandshakeResponse {
    /// Result of the handshake.
    pub status: RiverResultInternal<HandshakeResponseOk>,
}
/// Success payload inside a [`HandshakeResponse`]'s `status`.
///
/// Serialized as `{"sessionId": ...}`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HandshakeResponseOk {
    /// Session identifier returned by the responder.
    pub session_id: String,
}
/// Error codes a handshake can fail with.
///
/// Serialized as SCREAMING_SNAKE_CASE strings (e.g. `"SESSION_STATE_MISMATCH"`),
/// the same spelling used by this module's `Display` and `TryFrom<String>` impls.
///
/// The enum is fieldless, so it also derives `PartialEq`, `Eq` and `Hash`:
/// callers can compare codes with `==` and use them as map/set keys instead of
/// resorting to `matches!`.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HandshakeError {
    /// Wire code `"SESSION_STATE_MISMATCH"`.
    SessionStateMismatch,
    /// Wire code `"MALFORMED_HANDSHAKE_META"`.
    MalformedHandshakeMeta,
    /// Wire code `"MALFORMED_HANDSHAKE"`.
    MalformedHandshake,
    /// Wire code `"PROTOCOL_VERSION_MISMATCH"`.
    ProtocolVersionMismatch,
    /// Wire code `"REJECTED_BY_CUSTOM_HANDLER"`.
    RejectedByCustomHandler,
}
impl Display for HandshakeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let to_write = match self {
HandshakeError::SessionStateMismatch => "SESSION_STATE_MISMATCH",
HandshakeError::MalformedHandshakeMeta => "MALFORMED_HANDSHAKE_META",
HandshakeError::MalformedHandshake => "MALFORMED_HANDSHAKE",
HandshakeError::ProtocolVersionMismatch => "PROTOCOL_VERSION_MISMATCH",
HandshakeError::RejectedByCustomHandler => "REJECTED_BY_CUSTOM_HANDLER",
};
f.write_str(to_write)
}
}
impl TryFrom<String> for HandshakeError {
    type Error = anyhow::Error;

    /// Parses a SCREAMING_SNAKE_CASE wire code back into a [`HandshakeError`].
    ///
    /// Matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns an `anyhow` error quoting the input when `value` is not one of
    /// the five known codes.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        let code = match value.as_str() {
            "SESSION_STATE_MISMATCH" => Self::SessionStateMismatch,
            "MALFORMED_HANDSHAKE_META" => Self::MalformedHandshakeMeta,
            "MALFORMED_HANDSHAKE" => Self::MalformedHandshake,
            "PROTOCOL_VERSION_MISMATCH" => Self::ProtocolVersionMismatch,
            "REJECTED_BY_CUSTOM_HANDLER" => Self::RejectedByCustomHandler,
            _ => return Err(format_err!("Unknown HandshakeError: `{value}`")),
        };
        Ok(code)
    }
}
/// Protocol version string exchanged during the handshake.
///
/// Known versions map to their exact wire strings (`"v0"`, `"v1"`, `"v1.1"`,
/// `"v2.0"`). Any other string deserializes into the untagged
/// [`ProtocolVersion::Unknown`] fallback instead of failing, and serializes
/// back verbatim.
///
/// Note: equality is derived, so `Unknown("v1".into())` is *not* equal to
/// `V1` even though both serialize to `"v1"`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProtocolVersion {
    /// Wire string `"v0"`.
    #[serde(rename = "v0")]
    V0,
    /// Wire string `"v1"`.
    #[serde(rename = "v1")]
    V1,
    /// Wire string `"v1.1"`.
    #[serde(rename = "v1.1")]
    V1_1,
    /// Wire string `"v2.0"`.
    #[serde(rename = "v2.0")]
    V2_0,
    /// Any version string not matched by the variants above.
    #[serde(untagged)]
    Unknown(String),
}
impl Display for ProtocolVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let to_write = match self {
ProtocolVersion::V0 => "v0",
ProtocolVersion::V1 => "v1",
ProtocolVersion::V1_1 => "v1.1",
ProtocolVersion::V2_0 => "v2.0",
ProtocolVersion::Unknown(version) => version.as_str(),
};
f.write_str(to_write)
}
}
/// Sequence-number state included in a [`HandshakeRequest`].
///
/// Serialized with camelCase keys, in order: `nextExpectedSeq`, `nextSentSeq`.
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExpectedSessionState {
    /// Next expected sequence number (wire key `nextExpectedSeq`).
    pub next_expected_seq: i64,
    /// Next sent sequence number (wire key `nextSentSeq`).
    pub next_sent_seq: i64,
}