use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{KnafehError, RpcStatusCode};
use crate::rpc::message::{RpcRequest, RpcResponse};
use crate::rpc::service::{MethodDescriptor, MethodKind, Service};
use crate::rpc::stream::{RpcStreamRequest, RpcStreamResponse};
pub struct MethodRouter {
services: HashMap<String, Arc<dyn Service>>,
methods: HashMap<String, (Arc<dyn Service>, MethodDescriptor)>,
}
impl MethodRouter {
pub fn new() -> Self {
Self {
services: HashMap::new(),
methods: HashMap::new(),
}
}
pub fn add_service(&mut self, service: Arc<dyn Service>) {
let service_name = service.name().to_string();
for method in service.methods() {
let key = format!("{}/{}", service_name, method.name);
self.methods.insert(key, (Arc::clone(&service), method));
}
self.services.insert(service_name, service);
}
pub fn lookup(&self, path: &str) -> Option<(&dyn Service, &MethodDescriptor)> {
self.methods
.get(path)
.map(|(svc, desc)| (svc.as_ref(), desc))
}
pub async fn route_unary(&self, request: RpcRequest) -> Result<RpcResponse, KnafehError> {
let method_path = &request.method;
let (service, descriptor) =
self.methods
.get(method_path)
.ok_or_else(|| KnafehError::Service {
code: RpcStatusCode::NotFound,
message: format!("method not found: {method_path}"),
})?;
if descriptor.kind != MethodKind::Unary {
return Err(KnafehError::Service {
code: RpcStatusCode::Unimplemented,
message: format!("method {method_path} is {:?}, not Unary", descriptor.kind),
});
}
let method_name = &descriptor.name;
service.call_unary(method_name, request).await
}
pub async fn route_server_stream(
&self,
request: RpcRequest,
) -> Result<RpcStreamResponse, KnafehError> {
let method_path = &request.method;
let (service, descriptor) =
self.methods
.get(method_path)
.ok_or_else(|| KnafehError::Service {
code: RpcStatusCode::NotFound,
message: format!("method not found: {method_path}"),
})?;
if descriptor.kind != MethodKind::ServerStreaming {
return Err(KnafehError::Service {
code: RpcStatusCode::Unimplemented,
message: format!(
"method {method_path} is {:?}, not ServerStreaming",
descriptor.kind
),
});
}
let method_name = &descriptor.name;
service.call_server_stream(method_name, request).await
}
pub async fn route_client_stream(
&self,
method_path: &str,
stream: RpcStreamRequest,
) -> Result<RpcResponse, KnafehError> {
let (service, descriptor) =
self.methods
.get(method_path)
.ok_or_else(|| KnafehError::Service {
code: RpcStatusCode::NotFound,
message: format!("method not found: {method_path}"),
})?;
if descriptor.kind != MethodKind::ClientStreaming {
return Err(KnafehError::Service {
code: RpcStatusCode::Unimplemented,
message: format!(
"method {method_path} is {:?}, not ClientStreaming",
descriptor.kind
),
});
}
let method_name = &descriptor.name;
service.call_client_stream(method_name, stream).await
}
pub async fn route_bidi_stream(
&self,
method_path: &str,
stream: RpcStreamRequest,
) -> Result<RpcStreamResponse, KnafehError> {
let (service, descriptor) =
self.methods
.get(method_path)
.ok_or_else(|| KnafehError::Service {
code: RpcStatusCode::NotFound,
message: format!("method not found: {method_path}"),
})?;
if descriptor.kind != MethodKind::BidiStreaming {
return Err(KnafehError::Service {
code: RpcStatusCode::Unimplemented,
message: format!(
"method {method_path} is {:?}, not BidiStreaming",
descriptor.kind
),
});
}
let method_name = &descriptor.name;
service.call_bidi_stream(method_name, stream).await
}
pub fn service_names(&self) -> Vec<&str> {
self.services.keys().map(|s| s.as_str()).collect()
}
pub fn method_paths(&self) -> Vec<&str> {
self.methods.keys().map(|s| s.as_str()).collect()
}
}
impl Default for MethodRouter {
fn default() -> Self {
Self::new()
}
}