use crate::server::handler::{make_handler, DynHandler, WsContext, WsMethod, WsRequest};
use crate::server::op_sink::WsOpSink;
use crate::server::operations::OperationRegistry;
use crate::server::protocol::{ErrorData, RequestEnvelope};
use crate::server::sink::WsSink;
use std::collections::HashMap;
use std::sync::Arc;
/// Method table for the WebSocket server.
///
/// Maps each method name to its type-erased handler and to a flag saying
/// whether calls to that method are tracked as streaming operations.
pub struct Router {
    /// Handler for each registered method name.
    handlers: HashMap<&'static str, DynHandler>,
    /// Streaming flag for each registered method name (same keys as `handlers`).
    streaming_methods: HashMap<&'static str, bool>,
}
impl Router {
    /// Creates an empty router with no registered methods.
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            streaming_methods: HashMap::new(),
        }
    }

    /// Registers the handler for `M` under `M::METHOD`.
    ///
    /// Registering a method name that is already present replaces both its
    /// handler and its streaming flag. Returns `&mut Self` so registrations can
    /// be chained.
    pub fn register<M: WsMethod>(&mut self) -> &mut Self {
        self.handlers.insert(M::METHOD, make_handler::<M>());
        self.streaming_methods.insert(M::METHOD, M::IS_STREAMING);
        self
    }

    /// Returns `true` if a handler is registered under `method`.
    pub fn has_method(&self, method: &str) -> bool {
        self.handlers.contains_key(method)
    }

    /// Returns the streaming flag recorded for `method`.
    ///
    /// Unknown methods are reported as non-streaming.
    pub fn is_streaming(&self, method: &str) -> bool {
        self.streaming_methods.get(method).copied().unwrap_or(false)
    }

    /// Returns every registered method name in ascending lexicographic order.
    ///
    /// The names are sorted because `HashMap` iteration order is unspecified.
    /// Without sorting, listings built from this would change from run to run.
    pub fn method_names(&self) -> Vec<&'static str> {
        let mut names: Vec<&'static str> = self.handlers.keys().copied().collect();
        // Keys are unique, so an unstable sort yields a fully deterministic order.
        names.sort_unstable();
        names
    }

    /// Looks up the handler registered under `method`, if any.
    pub fn get_handler(&self, method: &str) -> Option<&DynHandler> {
        self.handlers.get(method)
    }
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
/// Entry point that turns raw WebSocket text messages into handler calls.
///
/// The router, handler context and operation registry are each held behind an
/// `Arc`, and the accessors hand out those shared handles.
pub struct Dispatcher {
    /// Method table used to find the handler for each request.
    router: Arc<Router>,
    /// Context cloned (by `Arc`) into every handler invocation.
    context: Arc<WsContext>,
    /// Registry where streaming requests are registered and later completed,
    /// failed or cancelled.
    operations: Arc<OperationRegistry>,
}
impl Dispatcher {
    /// Builds a dispatcher that takes ownership of `router`, `context` and
    /// `operations`, wrapping each in an `Arc`.
    pub fn new(router: Router, context: WsContext, operations: OperationRegistry) -> Self {
        Self {
            router: Arc::new(router),
            context: Arc::new(context),
            operations: Arc::new(operations),
        }
    }

    /// Builds a dispatcher around `router` with a default [`WsContext`] and an
    /// empty [`OperationRegistry`].
    pub fn with_router(router: Router) -> Self {
        Self::new(router, WsContext::default(), OperationRegistry::new())
    }

    /// Shared handler context.
    pub fn context(&self) -> &Arc<WsContext> {
        &self.context
    }

    /// Shared operation registry.
    pub fn operations(&self) -> &Arc<OperationRegistry> {
        &self.operations
    }

    /// Shared method router.
    pub fn router(&self) -> &Arc<Router> {
        &self.router
    }

    /// Parses one incoming text message, runs the matching handler and reports
    /// failures back over `sink`.
    ///
    /// Nothing is returned to the caller. Every failure is sent to the client
    /// as an error frame, and send failures are ignored:
    /// - a message that does not deserialize into a [`RequestEnvelope`] gets
    ///   an `invalid_request` error;
    /// - an unregistered method gets an `unknown_method` error;
    /// - a streaming method the registry refuses to register gets a
    ///   `rate_limited` error;
    /// - a handler error is converted with `to_error_data()`.
    ///
    /// Streaming requests are registered before the handler runs. They are
    /// then marked completed or failed in the registry according to the
    /// handler's result.
    pub async fn dispatch(&self, message: &str, sink: WsSink) {
        let envelope: RequestEnvelope = match serde_json::from_str(message) {
            Ok(env) => env,
            Err(e) => {
                // Reuse the client's id when it can be recovered so the client
                // can match this error to its pending request.
                let id = Self::recover_request_id(message);
                let op_sink = WsOpSink::new(sink, id, None);
                let _ = op_sink
                    .send_error(ErrorData::invalid_request(format!(
                        "Failed to parse request: {}",
                        e
                    )))
                    .await;
                return;
            }
        };

        let mut request = WsRequest::from_envelope(envelope);
        let id = request.id.clone();
        let method = request.method.clone();

        let handler = match self.router.get_handler(&method) {
            Some(h) => h,
            None => {
                let op_sink = WsOpSink::new(sink, id, None);
                let _ = op_sink.send_error(ErrorData::unknown_method(&method)).await;
                return;
            }
        };

        let op_id = if self.router.is_streaming(&method) {
            match self.operations.register(id.clone(), method.clone()).await {
                // NOTE(review): the cancel token returned by the registry is
                // dropped here and never reaches the handler. Presumably the
                // registry keeps its own handle for cancellation; confirm.
                Ok((op_id, _cancel_token)) => Some(op_id),
                // NOTE(review): every registration error is reported as
                // `rate_limited`, whatever its actual cause.
                Err(_) => {
                    let op_sink = WsOpSink::new(sink, id, None);
                    let _ = op_sink.send_error(ErrorData::rate_limited()).await;
                    return;
                }
            }
        } else {
            None
        };
        request.op_id = op_id.clone();

        let op_sink = WsOpSink::new(sink, id, op_id.clone());
        match handler(Arc::clone(&self.context), request, op_sink.clone()).await {
            Ok(_) => {
                if let Some(op_id) = op_id {
                    let _ = self.operations.complete_and_remove(&op_id).await;
                }
            }
            Err(e) => {
                let _ = op_sink.send_error(e.to_error_data()).await;
                if let Some(op_id) = op_id {
                    let _ = self.operations.fail_and_remove(&op_id).await;
                }
            }
        }
    }

    /// Cancels the running operation `op_id` and answers the cancel request
    /// `request_id` over `sink`.
    ///
    /// On success the client receives `{"cancelled": true, "op_id": ...}`.
    /// Otherwise it receives `invalid_params` for an unknown or non-running
    /// operation, or `operation_failed` for any other registry error. Send
    /// failures are ignored.
    pub async fn cancel(&self, op_id: &str, request_id: String, sink: WsSink) {
        let op_sink = WsOpSink::new(sink, request_id, None);
        match self.operations.cancel_and_remove(op_id).await {
            Ok(()) => {
                let _ = op_sink
                    .send_result(serde_json::json!({
                        "cancelled": true,
                        "op_id": op_id
                    }))
                    .await;
            }
            Err(crate::server::operations::RegistryError::OperationNotFound) => {
                let _ = op_sink
                    .send_error(ErrorData::invalid_params(format!(
                        "Unknown operation: {}",
                        op_id
                    )))
                    .await;
            }
            Err(crate::server::operations::RegistryError::OperationNotRunning) => {
                let _ = op_sink
                    .send_error(ErrorData::invalid_params(format!(
                        "Operation {} is not running",
                        op_id
                    )))
                    .await;
            }
            Err(e) => {
                let _ = op_sink
                    .send_error(ErrorData::operation_failed(e.to_string()))
                    .await;
            }
        }
    }

    /// Best-effort recovery of a request id from a message that failed to
    /// deserialize into a [`RequestEnvelope`].
    ///
    /// Returns the string at key `"id"` if `message` is still a JSON object
    /// carrying one. Otherwise returns a fresh UUID v4.
    // NOTE(review): assumes the envelope carries its id under the `"id"` key
    // as a JSON string; confirm against `RequestEnvelope`.
    fn recover_request_id(message: &str) -> String {
        serde_json::from_str::<serde_json::Value>(message)
            .ok()
            .and_then(|value| value.get("id")?.as_str().map(str::to_owned))
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built router exposes no method names.
    #[test]
    fn test_router_new() {
        let names = Router::new().method_names();
        assert!(names.is_empty());
    }

    /// Lookups on an empty router report the method as absent.
    #[test]
    fn test_router_has_method() {
        let empty = Router::new();
        let found = empty.has_method("time.parse");
        assert!(!found);
    }

    /// Unregistered methods default to non-streaming.
    #[test]
    fn test_router_is_streaming_unknown() {
        let empty = Router::new();
        let streaming = empty.is_streaming("unknown.method");
        assert!(!streaming);
    }
}