use crate::server::op_sink::WsOpSink;
use crate::server::protocol::{ErrorCode, ErrorData, RequestEnvelope};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::sync::Arc;
use crate::config::MonocleConfig;
#[derive(Clone)]
pub struct WsContext {
pub config: MonocleConfig,
}
impl WsContext {
pub fn from_config(config: MonocleConfig) -> Self {
Self { config }
}
pub fn data_dir(&self) -> &str {
&self.config.data_dir
}
}
impl Default for WsContext {
fn default() -> Self {
Self::from_config(MonocleConfig::default())
}
}
#[derive(Debug, Clone)]
pub struct WsRequest {
pub id: String,
pub op_id: Option<String>,
pub method: String,
pub params: Value,
}
impl WsRequest {
pub fn from_envelope(envelope: RequestEnvelope) -> Self {
let id = envelope
.id
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
Self {
id,
op_id: None,
method: envelope.method,
params: envelope.params,
}
}
}
/// Convenience alias for results produced by WebSocket method code.
pub type WsResult<T> = Result<T, WsError>;

/// Error type returned by WebSocket method handlers.
///
/// Pairs a protocol [`ErrorCode`] with a human-readable message and an
/// optional structured JSON payload.
#[derive(Debug, Clone)]
pub struct WsError {
    /// Protocol-level classification of the failure.
    pub code: ErrorCode,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional structured JSON payload attached to the error.
    pub details: Option<Value>,
}
impl WsError {
pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
Self {
code,
message: message.into(),
details: None,
}
}
pub fn with_details(code: ErrorCode, message: impl Into<String>, details: Value) -> Self {
Self {
code,
message: message.into(),
details: Some(details),
}
}
pub fn invalid_params(message: impl Into<String>) -> Self {
Self::new(ErrorCode::InvalidParams, message)
}
pub fn operation_failed(message: impl Into<String>) -> Self {
Self::new(ErrorCode::OperationFailed, message)
}
pub fn not_initialized(resource: &str) -> Self {
Self::new(
ErrorCode::NotInitialized,
format!("{} data not initialized", resource),
)
}
pub fn internal(message: impl Into<String>) -> Self {
Self::new(ErrorCode::InternalError, message)
}
pub fn to_error_data(&self) -> ErrorData {
match &self.details {
Some(details) => {
ErrorData::with_details(self.code, self.message.clone(), details.clone())
}
None => ErrorData::new(self.code, self.message.clone()),
}
}
}
impl std::fmt::Display for WsError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}: {}", self.code, self.message)
}
}
impl std::error::Error for WsError {}
/// Converts an [`anyhow::Error`] into an `OperationFailed` [`WsError`].
///
/// The message uses anyhow's alternate `{:#}` rendering, which joins the whole
/// context chain (`outer context: ...: root cause`). Plain `to_string()` would
/// keep only the outermost context, so errors wrapped with `.context(...)` would
/// reach the client without their underlying cause.
impl From<anyhow::Error> for WsError {
    fn from(err: anyhow::Error) -> Self {
        Self::operation_failed(format!("{:#}", err))
    }
}
impl From<serde_json::Error> for WsError {
fn from(err: serde_json::Error) -> Self {
Self::invalid_params(err.to_string())
}
}
/// A strongly typed WebSocket method.
///
/// An implementor declares:
/// - a method name;
/// - the type that the request's JSON `params` is deserialized into;
/// - an optional synchronous validation hook;
/// - the async handler that does the work and reports through a [`WsOpSink`].
///
/// Use [`make_handler`] to turn an implementor into a type-erased [`DynHandler`].
#[async_trait]
pub trait WsMethod: Send + Sync + 'static {
    /// Method name string.
    /// NOTE(review): presumably matched against `WsRequest::method` by the
    /// dispatcher, which is not shown here.
    const METHOD: &'static str;

    /// Streaming flag, `false` unless overridden.
    /// NOTE(review): not read in this file; meaning is defined by its consumer.
    const IS_STREAMING: bool = false;

    /// Parameter type deserialized from the request's JSON `params`.
    type Params: DeserializeOwned + Send;

    /// Checks the decoded parameters before `handle` runs.
    ///
    /// The default accepts everything.
    fn validate(_params: &Self::Params) -> WsResult<()> {
        Ok(())
    }

    /// Executes the method.
    ///
    /// Receives the shared context, the original request, its decoded
    /// parameters and the sink for emitting output.
    async fn handle(
        ctx: Arc<WsContext>,
        req: WsRequest,
        params: Self::Params,
        sink: WsOpSink,
    ) -> WsResult<()>;
}
/// Type-erased, thread-safe method handler.
///
/// Given the shared context, the request and its output sink, it returns a
/// boxed `'static` future resolving to the method's result. Values of this type
/// are produced by [`make_handler`].
pub type DynHandler = Box<
    dyn Fn(Arc<WsContext>, WsRequest, WsOpSink) -> futures::future::BoxFuture<'static, WsResult<()>>
        + Sync
        + Send,
>;
pub fn make_handler<M: WsMethod>() -> DynHandler {
Box::new(move |ctx, req, sink| {
Box::pin(async move {
let params: M::Params = serde_json::from_value(req.params.clone())?;
M::validate(¶ms)?;
M::handle(ctx, req, params, sink).await
})
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `time.parse` envelope with empty params and the given id.
    fn time_parse_envelope(id: Option<&str>) -> RequestEnvelope {
        RequestEnvelope {
            id: id.map(str::to_string),
            method: "time.parse".to_string(),
            params: serde_json::json!({}),
        }
    }

    #[test]
    fn test_ws_context_default() {
        let ctx = WsContext::default();
        assert!(ctx.data_dir().contains("monocle"));
    }

    #[test]
    fn test_ws_context_from_config() {
        let config = MonocleConfig::default();
        let expected = config.data_dir.clone();
        let ctx = WsContext::from_config(config);
        assert_eq!(ctx.data_dir(), &expected);
    }

    #[test]
    fn test_ws_request_from_envelope() {
        // An explicit id is carried through unchanged.
        let explicit = WsRequest::from_envelope(time_parse_envelope(Some("test-id")));
        assert_eq!(explicit.id, "test-id");
        assert_eq!(explicit.op_id, None);
        assert_eq!(explicit.method, "time.parse");

        // A missing id is replaced by a freshly generated one.
        let generated = WsRequest::from_envelope(time_parse_envelope(None));
        assert!(!generated.id.is_empty());
        assert_ne!(generated.id, "test-id");
    }

    #[test]
    fn test_ws_error_conversion() {
        let err = WsError::invalid_params("missing field");
        assert_eq!(err.code, ErrorCode::InvalidParams);
        assert!(err.message.contains("missing field"));
        assert_eq!(err.to_error_data().code, ErrorCode::InvalidParams);
    }

    #[test]
    fn test_ws_error_from_anyhow() {
        let ws_err = WsError::from(anyhow::anyhow!("something went wrong"));
        assert_eq!(ws_err.code, ErrorCode::OperationFailed);
        assert!(ws_err.message.contains("something went wrong"));
    }
}