use std::net::SocketAddr;
use std::sync::Arc;
use serde::de::DeserializeOwned;
use serde_json::{Value, json};
use thiserror::Error;
use tokio::sync::{RwLock, mpsc};
use validator::Validate;
use wscall_protocol::{
EncryptionKind, ErrorPayload, FileAttachment, PacketEnvelope, ProtocolError,
};
use crate::validation;
/// An item queued for delivery to one connected client.
///
/// Values of this type are sent through the per-client `mpsc::Sender<ServerOutbound>`
/// held in `ServerState::clients`.
pub(crate) enum ServerOutbound {
    /// A protocol packet to deliver to the peer.
    Packet(PacketEnvelope),
    /// A ping carrying the given payload (presumably written as a WebSocket ping frame).
    Ping(Vec<u8>),
    /// A pong carrying the given payload (presumably written as a WebSocket pong frame).
    Pong(Vec<u8>),
    /// Asks the writer side to close the connection.
    Close,
}
/// State shared by every clone of a `ServerHandle`.
pub(crate) struct ServerState {
    /// Outbound queues of the currently registered clients, keyed by a string id
    /// (presumably the connection id — TODO confirm against the code that inserts entries).
    pub clients: RwLock<::std::collections::HashMap<String, mpsc::Sender<ServerOutbound>>>,
}
/// Cloneable handle to the running server.
///
/// All clones point at the same `Arc<ServerState>`, so they observe the same set of
/// registered clients.
#[derive(Clone)]
pub struct ServerHandle {
    /// Shared server state (client registry).
    pub(crate) state: Arc<ServerState>,
    /// Encryption kind the server uses by default (presumably for outgoing packets when no
    /// explicit choice is made — confirm at the call sites).
    pub(crate) default_encryption: EncryptionKind,
}
/// Describes a single client connection: its id, remote address and owning server.
#[derive(Clone)]
pub struct ServerConnectionContext {
    pub(crate) connection_id: String,
    pub(crate) peer_addr: Option<SocketAddr>,
    pub(crate) server: ServerHandle,
}

impl ServerConnectionContext {
    /// Identifier of this connection.
    pub fn connection_id(&self) -> &str {
        self.connection_id.as_str()
    }

    /// Remote socket address of the peer, when known.
    pub fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }

    /// The peer's IP address (port omitted) rendered as a string, when the address is known.
    pub fn peer_ip(&self) -> Option<String> {
        let addr = self.peer_addr?;
        Some(addr.ip().to_string())
    }

    /// Handle to the server this connection belongs to.
    pub fn server(&self) -> &ServerHandle {
        &self.server
    }
}
/// Describes a client connection that has ended, together with the recorded reason.
#[derive(Clone)]
pub struct ServerDisconnectContext {
    pub(crate) connection_id: String,
    pub(crate) peer_addr: Option<SocketAddr>,
    pub(crate) reason: String,
    pub(crate) server: ServerHandle,
}

impl ServerDisconnectContext {
    /// Identifier of the connection that was closed.
    pub fn connection_id(&self) -> &str {
        self.connection_id.as_str()
    }

    /// Remote socket address the peer had, when known.
    pub fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }

    /// The peer's IP address (port omitted) rendered as a string, when the address is known.
    pub fn peer_ip(&self) -> Option<String> {
        self.peer_addr
            .map(|socket| socket.ip())
            .map(|ip| ip.to_string())
    }

    /// Free-form text describing why the connection ended.
    pub fn reason(&self) -> &str {
        self.reason.as_str()
    }

    /// Handle to the server the connection belonged to.
    pub fn server(&self) -> &ServerHandle {
        &self.server
    }
}
/// Context for one API request: the calling connection, the request's id and route,
/// its JSON parameters, any file attachments and request metadata.
#[derive(Clone)]
pub struct ApiContext {
    /// Connection the request arrived on.
    pub(crate) connection_id: String,
    /// Remote address of the caller, when known.
    pub(crate) peer_addr: Option<SocketAddr>,
    /// Identifier of this request.
    pub(crate) request_id: String,
    /// Route the request targets.
    pub(crate) route: String,
    /// Raw JSON parameters; `ApiContext::param` only looks inside when this is an object.
    pub(crate) params: Value,
    /// Files attached to the request.
    pub(crate) attachments: Vec<FileAttachment>,
    /// Additional JSON metadata sent with the request.
    pub(crate) metadata: Value,
    /// Handle to the server that received the request.
    pub(crate) server: ServerHandle,
}
/// Hand-written validation hook for parameter types bound with
/// `ApiContext::bind_and_validate`.
///
/// This is the manual alternative to deriving `validator::Validate` and using
/// `ApiContext::bind_validated`.
pub trait ValidateParams {
    /// Checks already-deserialized parameters.
    ///
    /// # Errors
    ///
    /// Returns the [`ApiError`] that should be reported to the caller when the
    /// parameters are invalid.
    fn validate(&self) -> Result<(), ApiError>;
}
impl ApiContext {
    /// Identifier of the connection the request arrived on.
    pub fn connection_id(&self) -> &str {
        &self.connection_id
    }

    /// Remote socket address of the caller, when known.
    pub fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }

    /// The caller's IP address (port omitted) as a string, when the peer address is known.
    pub fn peer_ip(&self) -> Option<String> {
        self.peer_addr.map(|addr| addr.ip().to_string())
    }

    /// Identifier of this request.
    pub fn request_id(&self) -> &str {
        &self.request_id
    }

    /// Route the request targets.
    pub fn route(&self) -> &str {
        &self.route
    }

    /// Raw JSON parameters of the request.
    pub fn params(&self) -> &Value {
        &self.params
    }

    /// Looks up a top-level parameter by key.
    ///
    /// Returns `None` when the key is absent or when the parameters are not a JSON object.
    /// A key that is present with a JSON `null` value yields `Some(&Value::Null)`.
    pub fn param(&self, key: &str) -> Option<&Value> {
        self.params.as_object()?.get(key)
    }

    /// Like [`param`](Self::param), but treats a missing key as an error.
    ///
    /// # Errors
    ///
    /// Returns a `bad_request` [`ApiError`] (`"missing required param: <key>"`) when the key
    /// is absent or the parameters are not a JSON object. A `null` value counts as present.
    pub fn require_param(&self, key: &str) -> Result<&Value, ApiError> {
        self.param(key)
            .ok_or_else(|| ApiError::bad_request(format!("missing required param: {key}")))
    }

    /// Deserializes the request parameters into `T`.
    ///
    /// # Errors
    ///
    /// Returns a `bad_request` [`ApiError`] (`"invalid params: <serde error>"`) when the
    /// parameters do not match `T`.
    pub fn bind<T>(&self) -> Result<T, ApiError>
    where
        T: DeserializeOwned,
    {
        // `&Value` implements `serde::Deserializer`, so deserialize directly from the borrowed
        // params instead of deep-cloning the entire JSON tree on every bind.
        T::deserialize(&self.params)
            .map_err(|source| ApiError::bad_request(format!("invalid params: {source}")))
    }

    /// Deserializes the parameters into `T`, then runs its [`ValidateParams`] check.
    ///
    /// # Errors
    ///
    /// Returns the error from [`bind`](Self::bind), or whatever `T::validate` returns.
    pub fn bind_and_validate<T>(&self) -> Result<T, ApiError>
    where
        T: DeserializeOwned + ValidateParams,
    {
        let params: T = self.bind()?;
        params.validate()?;
        Ok(params)
    }

    /// Deserializes the parameters into `T`, then runs its `validator::Validate` rules.
    ///
    /// # Errors
    ///
    /// Returns the error from [`bind`](Self::bind), or a `bad_request` [`ApiError`]
    /// (`"params validation failed"`) whose details hold a `validation_errors` entry
    /// built by `validation::errors_to_details`.
    pub fn bind_validated<T>(&self) -> Result<T, ApiError>
    where
        T: DeserializeOwned + Validate,
    {
        let params: T = self.bind()?;
        params.validate().map_err(|source| {
            ApiError::bad_request("params validation failed").with_details(json!({
                "validation_errors": validation::errors_to_details(&source),
            }))
        })?;
        Ok(params)
    }

    /// Files attached to the request.
    pub fn attachments(&self) -> &[FileAttachment] {
        &self.attachments
    }

    /// JSON metadata sent with the request.
    pub fn metadata(&self) -> &Value {
        &self.metadata
    }

    /// Handle to the server that received the request.
    pub fn server(&self) -> &ServerHandle {
        &self.server
    }

    /// One JSON object per attachment with its `id`, `name`, `content_type` and `size`
    /// (attachment contents are not included).
    pub fn attachment_summaries(&self) -> Vec<Value> {
        self.attachments
            .iter()
            .map(|attachment| {
                json!({
                    "id": attachment.id,
                    "name": attachment.name,
                    "content_type": attachment.content_type,
                    "size": attachment.size,
                })
            })
            .collect()
    }
}
/// Context for one incoming event: the sending connection, the event's id and name,
/// its JSON data, any file attachments and metadata.
#[derive(Clone)]
pub struct EventContext {
    pub(crate) connection_id: String,
    pub(crate) peer_addr: Option<SocketAddr>,
    pub(crate) event_id: String,
    pub(crate) name: String,
    pub(crate) data: Value,
    pub(crate) attachments: Vec<FileAttachment>,
    pub(crate) metadata: Value,
    pub(crate) server: ServerHandle,
}

impl EventContext {
    /// Identifier of the connection that sent the event.
    pub fn connection_id(&self) -> &str {
        self.connection_id.as_str()
    }

    /// Remote socket address of the sender, when known.
    pub fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }

    /// The sender's IP address (port omitted) as a string, when the address is known.
    pub fn peer_ip(&self) -> Option<String> {
        let ip = self.peer_addr?.ip();
        Some(ip.to_string())
    }

    /// Identifier of this event.
    pub fn event_id(&self) -> &str {
        self.event_id.as_str()
    }

    /// Name of the event.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// JSON payload of the event.
    pub fn data(&self) -> &Value {
        &self.data
    }

    /// Files attached to the event.
    pub fn attachments(&self) -> &[FileAttachment] {
        self.attachments.as_slice()
    }

    /// JSON metadata sent with the event.
    pub fn metadata(&self) -> &Value {
        &self.metadata
    }

    /// Handle to the server that received the event.
    pub fn server(&self) -> &ServerHandle {
        &self.server
    }
}
/// Information about an [`ApiError`] raised while handling a message from a client.
///
/// Derives `Debug` (all fields are `Debug`) so handlers can log the whole context directly.
#[derive(Debug, Clone)]
pub struct ExceptionContext {
    /// Connection on which the failing message arrived.
    pub connection_id: String,
    /// Request id, when the failing message had one (`None` otherwise).
    pub request_id: Option<String>,
    /// What the message was addressed to (presumably the API route or event name — confirm
    /// at the construction sites).
    pub target: String,
    /// Static label for the kind of message being handled (values are set by the caller).
    pub message_kind: &'static str,
    /// The error that was raised.
    pub error: ApiError,
}
/// Application-level error returned by API handlers.
///
/// Displays as `"{code}: {message}"`.
#[derive(Debug, Clone, Error)]
#[error("{code}: {message}")]
pub struct ApiError {
    /// Machine-readable error code, e.g. `"bad_request"`.
    pub code: String,
    /// Human-readable description.
    pub message: String,
    /// HTTP-style status number (400, 404, 500 for the built-in constructors).
    pub status: u16,
    /// Optional structured JSON details.
    pub details: Option<Value>,
}

impl ApiError {
    /// Builds an error with code `"bad_request"` and status 400.
    pub fn bad_request(message: impl Into<String>) -> Self {
        Self::new("bad_request", message, 400)
    }

    /// Builds an error with code `"not_found"` and status 404.
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new("not_found", message, 404)
    }

    /// Builds an error with code `"internal_error"` and status 500.
    pub fn internal(message: impl Into<String>) -> Self {
        Self::new("internal_error", message, 500)
    }

    /// Builds an error with an arbitrary code, message and status, and no details.
    pub fn new(code: impl Into<String>, message: impl Into<String>, status: u16) -> Self {
        let code: String = code.into();
        let message: String = message.into();
        Self {
            code,
            message,
            status,
            details: None,
        }
    }

    /// Attaches structured details, replacing any that were already set.
    pub fn with_details(self, details: Value) -> Self {
        Self {
            details: Some(details),
            ..self
        }
    }

    /// Converts this error into the wire-level [`ErrorPayload`], moving every field across.
    pub fn into_payload(self) -> ErrorPayload {
        let Self {
            code,
            message,
            status,
            details,
        } = self;
        ErrorPayload {
            code,
            message,
            status,
            details,
        }
    }
}
#[derive(Debug, Error)]
pub enum ServerError {
#[error("io error: {0}")]
Io(#[from] std::io::Error),
#[error("websocket error: {0}")]
WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
#[error("protocol error: {0}")]
Protocol(#[from] ProtocolError),
#[error("api error: {0:?}")]
Api(#[from] ApiError),
#[error("connection idle timeout: {0}")]
IdleTimeout(String),
#[error("outbound queue is full for connection {0}")]
OutboundQueueFull(String),
}