use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::{Map, Value};
use futures_core::Stream;
use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
pub mod adapter;
pub mod builder;
#[cfg(feature = "router")]
pub mod config;
pub mod contract;
pub mod docs;
pub mod graphiql;
pub mod graphql;
pub mod grpc;
pub mod grpc_explorer;
pub mod handler;
pub mod metadata;
pub mod method;
pub mod openapi;
pub mod rest;
pub mod scalar;
pub mod schema;
pub mod ts_codegen;
#[cfg(feature = "router-graphql")]
pub mod graphql_prod;
#[cfg(feature = "router-grpc")]
pub mod grpc_prod;
pub use adapter::ProtocolAdapter;
pub use builder::RouteBuilder;
#[cfg(feature = "router")]
pub use config::{GraphQLConfig, GrpcConfig, RestConfig, RouterConfig, ServerConfig};
pub use contract::{
ContractTestConfig, ContractTestResult, ContractTestResults, ContractTestable, ContractTester,
};
pub use docs::DocsConfig;
pub use graphiql::{graphiql_html, GraphiQLConfig, GraphiQLTheme};
pub use graphql::{GraphQLAdapter, GraphQLOperation, OperationType};
#[cfg(feature = "router-graphql")]
pub use graphql_prod::GraphQLProductionAdapter;
pub use grpc::{GrpcAdapter, GrpcMethod, GrpcMethodType, GrpcRequest, GrpcStatus};
pub use grpc_explorer::{grpc_explorer_html, GrpcExplorerConfig, GrpcExplorerTheme};
#[cfg(feature = "router-grpc")]
pub use grpc_prod::{protobuf, status, streaming, GrpcProductionAdapter, GrpcService};
pub use handler::{
ErasedHandler, ErasedStreamHandler, Handler, HandlerCallFn, HandlerFn, HandlerWithArgs,
HandlerWithState, HandlerWithStateOnly, IntoHandlerResult, IntoStreamItem, Json,
SharedStateMap, State, StreamError, StreamHandler, StreamHandlerCallFn, StreamReceiver,
StreamSender, StreamingHandlerFn, StreamingHandlerWithArgs, StreamingHandlerWithState,
StreamingHandlerWithStateOnly, DEFAULT_STREAM_CAPACITY,
};
pub use handler::resolve_state;
pub use handler::resolve_state_erased;
pub use metadata::RouteMetadata;
pub use method::Method;
pub use openapi::{OpenApiGenerator, OpenApiServer};
pub use rest::{RestAdapter, RestRequest, RestResponse, RestRoute};
pub use scalar::{scalar_html, ScalarConfig, ScalarLayout, ScalarTheme};
pub use schema::ToJsonSchema;
pub use ts_codegen::{generate_ts_client, HandlerMeta, TsField, TsType};
/// Pumps every item of `stream` into `tx` and yields the handler's final
/// return value.
///
/// Draining stops early when the receiving side reports
/// `StreamError::Closed`. If an item fails to serialize, the error is
/// reported as the JSON object `{"error": ...}` instead of `"null"`.
async fn drive_stream<T, St>(stream: St, tx: &StreamSender) -> String
where
    T: IntoStreamItem,
    St: Stream<Item = T> + Send,
{
    // `Stream::poll_next` needs a pinned receiver; pin it on the stack.
    tokio::pin!(stream);
    while let Some(item) = std::future::poll_fn(|cx| stream.as_mut().poll_next(cx)).await {
        match tx.send(item).await {
            Ok(()) => continue,
            // Nobody is listening any more, so stop producing.
            Err(StreamError::Closed) => break,
            Err(StreamError::Serialize(e)) => {
                return serde_json::json!({"error": e}).to_string();
            }
        }
    }
    "null".to_string()
}
/// Rewrite rule applied to every JSON object key of handler arguments
/// before they reach a handler (see `Router::with_key_transform`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KeyTransform {
    /// `camelCase` / `PascalCase` keys become `snake_case`
    /// (e.g. `userId` -> `user_id`).
    CamelToSnake,
}
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// Every ASCII uppercase letter is lowercased. Each one that is not the very
/// first character is preceded by `_`, so runs of capitals are split per
/// letter (`HTTPServer` -> `h_t_t_p_server`). This exactly inverts a
/// letter-by-letter snake -> camel conversion. Non-ASCII characters pass
/// through unchanged.
fn camel_to_snake(s: &str) -> String {
    // A few extra bytes of headroom for the inserted underscores.
    let mut snake = String::with_capacity(s.len() + 4);
    for (byte_pos, c) in s.char_indices() {
        if !c.is_ascii_uppercase() {
            snake.push(c);
            continue;
        }
        // Byte offset 0 is only ever the first character.
        if byte_pos != 0 {
            snake.push('_');
        }
        snake.push(c.to_ascii_lowercase());
    }
    snake
}
fn transform_keys(value: Value, transform: KeyTransform) -> Value {
match value {
Value::Object(map) => {
let new_map: Map<String, Value> = map
.into_iter()
.map(|(k, v)| {
let new_key = match transform {
KeyTransform::CamelToSnake => camel_to_snake(&k),
};
(new_key, transform_keys(v, transform))
})
.collect();
Value::Object(new_map)
}
Value::Array(arr) => {
Value::Array(arr.into_iter().map(|v| transform_keys(v, transform)).collect())
}
other => other,
}
}
/// Parses `args` as JSON, rewrites its keys with `transform`, and
/// re-serializes it.
///
/// Input that is not valid JSON is returned unchanged (best effort), so
/// whatever consumes the arguments sees exactly what the caller sent.
fn apply_key_transform(args: &str, transform: KeyTransform) -> String {
    serde_json::from_str::<Value>(args)
        .map(|parsed| transform_keys(parsed, transform).to_string())
        .unwrap_or_else(|_| args.to_string())
}
/// Registry of named handlers (plain and streaming), the protocol adapters
/// that route requests to them, and the shared state handlers can request.
pub struct Router {
    /// Request/response handlers keyed by registration name.
    handlers: HashMap<String, Box<dyn Handler>>,
    /// Streaming handlers keyed by registration name; kept apart from
    /// `handlers` (queried by `is_streaming`).
    streaming_handlers: HashMap<String, Box<dyn StreamHandler>>,
    /// Protocol adapters keyed by the adapter's own `name()`.
    adapters: HashMap<String, Box<dyn ProtocolAdapter>>,
    /// Route metadata recorded by `add_route` and the HTTP verb helpers.
    routes: Vec<RouteMetadata>,
    /// Type-keyed state values. The `Arc` is cloned into stateful handlers
    /// at registration time, so they share this map rather than a snapshot.
    states: SharedStateMap,
    /// TypeScript signature metadata used by `generate_ts_client`.
    handler_metas: HashMap<String, HandlerMeta>,
    /// Optional key rewrite applied to handler argument JSON before dispatch.
    key_transform: Option<KeyTransform>,
    /// Configuration passed to `with_config`; not read by the code visible
    /// here, hence the `dead_code` allowance.
    #[cfg(feature = "router")]
    #[allow(dead_code)]
    config: Option<RouterConfig>,
}
impl Router {
    /// Creates an empty router: no handlers, adapters, routes or state, and
    /// no argument key transform.
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            streaming_handlers: HashMap::new(),
            adapters: HashMap::new(),
            routes: Vec::new(),
            states: Arc::new(std::sync::RwLock::new(HashMap::new())),
            handler_metas: HashMap::new(),
            key_transform: None,
            #[cfg(feature = "router")]
            config: None,
        }
    }

    /// Creates a router from `config`, adding a default REST, GraphQL and/or
    /// gRPC adapter for every protocol the config enables.
    #[cfg(feature = "router")]
    pub fn with_config(config: RouterConfig) -> Self {
        // Reuse `new` so the field list is defined in exactly one place.
        let mut router = Self::new();
        if config.has_protocol("rest") {
            router.add_adapter(Box::new(RestAdapter::new()));
        }
        if config.has_protocol("graphql") {
            router.add_adapter(Box::new(GraphQLAdapter::new()));
        }
        if config.has_protocol("grpc") {
            router.add_adapter(Box::new(GrpcAdapter::new()));
        }
        // All protocol checks are done; move the config in instead of cloning.
        router.config = Some(config);
        router
    }

    /// Enables rewriting of JSON object keys in handler arguments before
    /// dispatch (e.g. `userId` -> `user_id` for [`KeyTransform::CamelToSnake`]).
    pub fn with_key_transform(mut self, transform: KeyTransform) -> Self {
        self.key_transform = Some(transform);
        self
    }

    /// Builder form of [`Router::inject_state`].
    ///
    /// # Panics
    /// Panics if the state lock is poisoned.
    pub fn with_state<S: Send + Sync + 'static>(mut self, state: S) -> Self {
        self.insert_state::<S>(state);
        self
    }

    /// Makes `state` available to handlers that take `State<Arc<S>>`.
    /// Injecting a second value of the same type `S` replaces the first.
    ///
    /// # Panics
    /// Panics if the state lock is poisoned.
    pub fn inject_state<S: Send + Sync + 'static>(&mut self, state: S) {
        self.insert_state::<S>(state);
    }

    /// Stores `state` in the shared map under `TypeId::of::<S>()`. In debug
    /// builds, prints a warning when an existing value is replaced.
    fn insert_state<S: Send + Sync + 'static>(&mut self, state: S) {
        let id = std::any::TypeId::of::<S>();
        let mut map = self.states.write().expect("state lock poisoned");
        // One `insert` both stores the value and reports whether a previous
        // value of the same type was overwritten; no separate lookup needed.
        let replaced = map.insert(id, Arc::new(state)).is_some();
        if replaced {
            #[cfg(debug_assertions)]
            eprintln!(
                "allframe: with_state called twice for type `{}` — previous value replaced",
                std::any::type_name::<S>()
            );
        }
    }

    /// Returns a handle to the shared state map (a cheap `Arc` clone).
    pub fn shared_states(&self) -> SharedStateMap {
        self.states.clone()
    }

    /// Registers an argument-less handler that returns a raw `String`.
    /// Re-registering a name replaces the previous handler.
    pub fn register<F, Fut>(&mut self, name: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::no_args(handler)));
    }

    /// Registers a handler whose JSON arguments are deserialized into `T`.
    pub fn register_with_args<T, F, Fut>(&mut self, name: &str, handler: F)
    where
        T: DeserializeOwned + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_args(handler)));
    }

    /// Registers a handler that receives injected state `S` and arguments `T`.
    pub fn register_with_state<S, T, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        T: DeserializeOwned + Send + 'static,
        F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let state = self.states.clone();
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_state(handler, state)));
    }

    /// Registers a handler that receives only injected state `S`.
    pub fn register_with_state_only<S, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let state = self.states.clone();
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_state_only(handler, state)));
    }

    /// Registers an argument-less handler whose `Serialize` output is
    /// wrapped in [`Json`].
    pub fn register_typed<R, F, Fut>(&mut self, name: &str, handler: F)
    where
        R: Serialize + Send + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
    {
        let wrapped = move || {
            let fut = handler();
            async move { Json(fut.await) }
        };
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::no_args(wrapped)));
    }

    /// Typed ([`Json`]-wrapped) variant of [`Router::register_with_args`].
    pub fn register_typed_with_args<T, R, F, Fut>(&mut self, name: &str, handler: F)
    where
        T: DeserializeOwned + Send + 'static,
        R: Serialize + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
    {
        let wrapped = move |args: T| {
            let fut = handler(args);
            async move { Json(fut.await) }
        };
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_args(wrapped)));
    }

    /// Typed ([`Json`]-wrapped) variant of [`Router::register_with_state`].
    pub fn register_typed_with_state<S, T, R, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        T: DeserializeOwned + Send + 'static,
        R: Serialize + Send + 'static,
        F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
    {
        let state = self.states.clone();
        let wrapped = move |s: State<Arc<S>>, args: T| {
            let fut = handler(s, args);
            async move { Json(fut.await) }
        };
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_state(wrapped, state)));
    }

    /// Typed ([`Json`]-wrapped) variant of [`Router::register_with_state_only`].
    pub fn register_typed_with_state_only<S, R, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        R: Serialize + Send + 'static,
        F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
    {
        let state = self.states.clone();
        let wrapped = move |s: State<Arc<S>>| {
            let fut = handler(s);
            async move { Json(fut.await) }
        };
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_state_only(wrapped, state)));
    }

    /// Registers an argument-less handler returning `Result<R, E>`.
    pub fn register_result<R, E, F, Fut>(&mut self, name: &str, handler: F)
    where
        R: Serialize + Send + 'static,
        E: std::fmt::Display + Send + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<R, E>> + Send + 'static,
    {
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::no_args(handler)));
    }

    /// `Result`-returning variant of [`Router::register_with_args`].
    pub fn register_result_with_args<T, R, E, F, Fut>(&mut self, name: &str, handler: F)
    where
        T: DeserializeOwned + Send + 'static,
        R: Serialize + Send + 'static,
        E: std::fmt::Display + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<R, E>> + Send + 'static,
    {
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_args(handler)));
    }

    /// `Result`-returning variant of [`Router::register_with_state`].
    pub fn register_result_with_state<S, T, R, E, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        T: DeserializeOwned + Send + 'static,
        R: Serialize + Send + 'static,
        E: std::fmt::Display + Send + 'static,
        F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<R, E>> + Send + 'static,
    {
        let state = self.states.clone();
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_state(handler, state)));
    }

    /// `Result`-returning variant of [`Router::register_with_state_only`].
    pub fn register_result_with_state_only<S, R, E, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        R: Serialize + Send + 'static,
        E: std::fmt::Display + Send + 'static,
        F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<R, E>> + Send + 'static,
    {
        let state = self.states.clone();
        self.handlers
            .insert(name.to_string(), Box::new(ErasedHandler::with_state_only(handler, state)));
    }

    /// Number of non-streaming handlers registered.
    pub fn handlers_count(&self) -> usize {
        self.handlers.len()
    }

    /// Registers an already type-erased handler.
    pub fn register_erased(&mut self, name: &str, handler: ErasedHandler) {
        self.handlers.insert(name.to_string(), Box::new(handler));
    }

    /// Registers an already type-erased streaming handler.
    pub fn register_streaming_erased(&mut self, name: &str, handler: ErasedStreamHandler) {
        self.streaming_handlers
            .insert(name.to_string(), Box::new(handler));
    }

    /// Registers a streaming handler that pushes items through a
    /// [`StreamSender`] and finishes with a final result `R`.
    pub fn register_streaming<F, Fut, R>(&mut self, name: &str, handler: F)
    where
        F: Fn(StreamSender) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoHandlerResult + 'static,
    {
        self.streaming_handlers
            .insert(name.to_string(), Box::new(ErasedStreamHandler::no_args(handler)));
    }

    /// Streaming handler that also receives deserialized arguments `T`.
    pub fn register_streaming_with_args<T, F, Fut, R>(&mut self, name: &str, handler: F)
    where
        T: DeserializeOwned + Send + 'static,
        F: Fn(T, StreamSender) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoHandlerResult + 'static,
    {
        self.streaming_handlers
            .insert(name.to_string(), Box::new(ErasedStreamHandler::with_args(handler)));
    }

    /// Streaming handler that receives injected state `S` and arguments `T`.
    pub fn register_streaming_with_state<S, T, F, Fut, R>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        T: DeserializeOwned + Send + 'static,
        F: Fn(State<Arc<S>>, T, StreamSender) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoHandlerResult + 'static,
    {
        let state = self.states.clone();
        self.streaming_handlers
            .insert(name.to_string(), Box::new(ErasedStreamHandler::with_state(handler, state)));
    }

    /// Streaming handler that receives only injected state `S`.
    pub fn register_streaming_with_state_only<S, F, Fut, R>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        F: Fn(State<Arc<S>>, StreamSender) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = R> + Send + 'static,
        R: IntoHandlerResult + 'static,
    {
        let state = self.states.clone();
        self.streaming_handlers
            .insert(name.to_string(), Box::new(ErasedStreamHandler::with_state_only(handler, state)));
    }

    /// Registers a handler that returns a [`Stream`]. Each item is forwarded
    /// to the client; the final result is `"null"` (or an error object if an
    /// item fails to serialize).
    pub fn register_stream<T, St, F, Fut>(&mut self, name: &str, handler: F)
    where
        T: IntoStreamItem + 'static,
        St: Stream<Item = T> + Send + 'static,
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = St> + Send + 'static,
    {
        self.register_streaming(name, move |tx: StreamSender| {
            let stream_fut = handler();
            async move { drive_stream(stream_fut.await, &tx).await }
        });
    }

    /// [`Stream`]-returning variant of [`Router::register_streaming_with_args`].
    pub fn register_stream_with_args<T, Item, St, F, Fut>(&mut self, name: &str, handler: F)
    where
        T: DeserializeOwned + Send + 'static,
        Item: IntoStreamItem + 'static,
        St: Stream<Item = Item> + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = St> + Send + 'static,
    {
        self.register_streaming_with_args::<T, _, _, _>(name, move |args: T, tx: StreamSender| {
            let stream_fut = handler(args);
            async move { drive_stream(stream_fut.await, &tx).await }
        });
    }

    /// [`Stream`]-returning variant of [`Router::register_streaming_with_state`].
    pub fn register_stream_with_state<S, T, Item, St, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        T: DeserializeOwned + Send + 'static,
        Item: IntoStreamItem + 'static,
        St: Stream<Item = Item> + Send + 'static,
        F: Fn(State<Arc<S>>, T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = St> + Send + 'static,
    {
        self.register_streaming_with_state::<S, T, _, _, _>(
            name,
            move |state: State<Arc<S>>, args: T, tx: StreamSender| {
                let stream_fut = handler(state, args);
                async move { drive_stream(stream_fut.await, &tx).await }
            },
        );
    }

    /// [`Stream`]-returning variant of
    /// [`Router::register_streaming_with_state_only`], completing the
    /// `register_stream*` family alongside the other registration families.
    pub fn register_stream_with_state_only<S, Item, St, F, Fut>(&mut self, name: &str, handler: F)
    where
        S: Send + Sync + 'static,
        Item: IntoStreamItem + 'static,
        St: Stream<Item = Item> + Send + 'static,
        F: Fn(State<Arc<S>>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = St> + Send + 'static,
    {
        self.register_streaming_with_state_only::<S, _, _, _>(
            name,
            move |state: State<Arc<S>>, tx: StreamSender| {
                let stream_fut = handler(state);
                async move { drive_stream(stream_fut.await, &tx).await }
            },
        );
    }

    /// Returns `true` if `name` is registered as a streaming handler.
    pub fn is_streaming(&self, name: &str) -> bool {
        self.streaming_handlers.contains_key(name)
    }

    /// Starts a streaming handler without spawning it.
    ///
    /// Returns the item receiver and the handler future; the caller must
    /// drive the future for items to be produced.
    ///
    /// # Errors
    /// Returns an error message if no streaming handler named `name` exists.
    #[allow(clippy::type_complexity)]
    pub fn call_streaming_handler(
        &self,
        name: &str,
        args: &str,
    ) -> Result<
        (
            StreamReceiver,
            Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>>,
        ),
        String,
    > {
        let handler = self
            .streaming_handlers
            .get(name)
            .ok_or_else(|| format!("Streaming handler '{}' not found", name))?;
        let transformed = self.maybe_transform_args(args);
        let args = transformed.as_deref().unwrap_or(args);
        let (tx, rx) = StreamSender::channel();
        let fut = handler.call_streaming(args, tx);
        Ok((rx, fut))
    }

    /// Like [`Router::call_streaming_handler`], but runs the handler on a
    /// `tokio` task and returns its join handle.
    ///
    /// # Errors
    /// Returns an error message if no streaming handler named `name` exists.
    #[allow(clippy::type_complexity)]
    pub fn spawn_streaming_handler(
        self: &Arc<Self>,
        name: &str,
        args: &str,
    ) -> Result<
        (
            StreamReceiver,
            tokio::task::JoinHandle<Result<String, String>>,
        ),
        String,
    > {
        if !self.streaming_handlers.contains_key(name) {
            return Err(format!("Streaming handler '{}' not found", name));
        }
        let router = self.clone();
        let name = name.to_string();
        // The task needs owned args, transformed or not.
        let args = self
            .maybe_transform_args(args)
            .unwrap_or_else(|| args.to_string());
        let (tx, rx) = StreamSender::channel();
        let handle = tokio::spawn(async move {
            // `Router` is immutable behind the `Arc`, so the existence check
            // above still holds here.
            let handler = router
                .streaming_handlers
                .get(&name)
                .expect("handler verified to exist");
            handler.call_streaming(&args, tx).await
        });
        Ok((rx, handle))
    }

    /// Attaches TypeScript signature metadata to a non-streaming handler.
    ///
    /// # Panics
    /// Panics if no handler named `name` is registered.
    pub fn describe_handler(
        &mut self,
        name: &str,
        args: Vec<TsField>,
        returns: TsType,
    ) {
        assert!(
            self.handlers.contains_key(name),
            "describe_handler: handler '{}' not registered",
            name
        );
        self.handler_metas
            .insert(name.to_string(), HandlerMeta::new(args, returns));
    }

    /// Attaches TypeScript signature metadata to a streaming handler.
    ///
    /// # Panics
    /// Panics if no streaming handler named `name` is registered.
    pub fn describe_streaming_handler(
        &mut self,
        name: &str,
        args: Vec<TsField>,
        item_type: TsType,
        final_type: TsType,
    ) {
        assert!(
            self.streaming_handlers.contains_key(name),
            "describe_streaming_handler: streaming handler '{}' not registered",
            name
        );
        self.handler_metas
            .insert(name.to_string(), HandlerMeta::streaming(args, item_type, final_type));
    }

    /// Generates a TypeScript client from the metadata attached with
    /// `describe_handler` / `describe_streaming_handler`.
    pub fn generate_ts_client(&self) -> String {
        generate_ts_client(&self.handler_metas)
    }

    /// Returns the TypeScript metadata recorded for `name`, if any.
    pub fn handler_meta(&self, name: &str) -> Option<&HandlerMeta> {
        self.handler_metas.get(name)
    }

    /// Adds `adapter`, replacing any adapter with the same `name()`.
    pub fn add_adapter(&mut self, adapter: Box<dyn ProtocolAdapter>) {
        self.adapters.insert(adapter.name().to_string(), adapter);
    }

    /// Returns `true` if an adapter named `name` is registered.
    pub fn has_adapter(&self, name: &str) -> bool {
        self.adapters.contains_key(name)
    }

    /// Borrows the adapter named `name`, if registered.
    pub fn get_adapter(&self, name: &str) -> Option<&dyn ProtocolAdapter> {
        self.adapters.get(name).map(|b| &**b)
    }

    /// Forwards the raw `request` to the adapter for `protocol`.
    ///
    /// # Errors
    /// Fails if no such adapter exists, or with the adapter's own error.
    pub async fn route_request(&self, protocol: &str, request: &str) -> Result<String, String> {
        let adapter = self
            .get_adapter(protocol)
            .ok_or_else(|| format!("Adapter not found: {}", protocol))?;
        adapter.handle(request).await
    }

    /// Runs handler `name` with empty JSON-object arguments (`{}`).
    pub async fn execute(&self, name: &str) -> Result<String, String> {
        self.execute_with_args(name, "{}").await
    }

    /// Applies the configured key transform to `args`, or `None` when no
    /// transform is configured (so callers can borrow the original).
    fn maybe_transform_args(&self, args: &str) -> Option<String> {
        self.key_transform.map(|t| apply_key_transform(args, t))
    }

    /// Runs handler `name` with the JSON `args`, after applying any
    /// configured key transform.
    ///
    /// # Errors
    /// Fails if the handler is unknown, or with the handler's own error.
    pub async fn execute_with_args(&self, name: &str, args: &str) -> Result<String, String> {
        // Resolve the handler first so unknown names fail fast, without
        // paying for the JSON parse/serialize of the key transform.
        let handler = self
            .handlers
            .get(name)
            .ok_or_else(|| format!("Handler '{}' not found", name))?;
        let transformed = self.maybe_transform_args(args);
        handler.call(transformed.as_deref().unwrap_or(args)).await
    }

    /// Names of all registered handlers (plain and streaming), sorted so
    /// the output is deterministic despite the underlying hash maps.
    pub fn list_handlers(&self) -> Vec<String> {
        let mut names: Vec<String> = self.handlers.keys().cloned().collect();
        names.extend(self.streaming_handlers.keys().cloned());
        names.sort_unstable();
        names
    }

    /// Alias for [`Router::execute_with_args`].
    pub async fn call_handler(&self, name: &str, request: &str) -> Result<String, String> {
        self.execute_with_args(name, request).await
    }

    /// Whether a REST adapter is registered; `_name` is currently ignored.
    pub fn can_handle_rest(&self, _name: &str) -> bool {
        self.has_adapter("rest")
    }

    /// Whether a GraphQL adapter is registered; `_name` is currently ignored.
    pub fn can_handle_graphql(&self, _name: &str) -> bool {
        self.has_adapter("graphql")
    }

    /// Whether a gRPC adapter is registered; `_name` is currently ignored.
    pub fn can_handle_grpc(&self, _name: &str) -> bool {
        self.has_adapter("grpc")
    }

    /// Names of all registered adapters, sorted for deterministic output.
    pub fn enabled_protocols(&self) -> Vec<String> {
        let mut protocols: Vec<String> = self.adapters.keys().cloned().collect();
        protocols.sort_unstable();
        protocols
    }

    /// Records route metadata (used e.g. for OpenAPI generation).
    pub fn add_route(&mut self, metadata: RouteMetadata) {
        self.routes.push(metadata);
    }

    /// All recorded route metadata, in insertion order.
    pub fn routes(&self) -> &[RouteMetadata] {
        &self.routes
    }

    /// Registers `handler` as `"GET:{path}"` and records a REST GET route.
    pub fn get<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("GET:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::GET, "rest"));
    }

    /// Registers `handler` as `"POST:{path}"` and records a REST POST route.
    pub fn post<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("POST:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::POST, "rest"));
    }

    /// Registers `handler` as `"PUT:{path}"` and records a REST PUT route.
    pub fn put<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("PUT:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::PUT, "rest"));
    }

    /// Registers `handler` as `"DELETE:{path}"` and records a REST DELETE route.
    pub fn delete<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("DELETE:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::DELETE, "rest"));
    }

    /// Registers `handler` as `"PATCH:{path}"` and records a REST PATCH route.
    pub fn patch<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("PATCH:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::PATCH, "rest"));
    }

    /// Registers `handler` as `"HEAD:{path}"` and records a REST HEAD route.
    pub fn head<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("HEAD:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::HEAD, "rest"));
    }

    /// Registers `handler` as `"OPTIONS:{path}"` and records a REST OPTIONS route.
    pub fn options<F, Fut>(&mut self, path: &str, handler: F)
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = String> + Send + 'static,
    {
        let handler_name = format!("OPTIONS:{}", path);
        self.register(&handler_name, handler);
        self.add_route(RouteMetadata::new(path, Method::OPTIONS, "rest"));
    }

    /// Sends `"{method} {path}"` to the REST adapter.
    ///
    /// # Errors
    /// Fails if the REST adapter is not enabled, or with the adapter's error.
    pub async fn call_rest(&self, method: &str, path: &str) -> Result<String, String> {
        let adapter = self
            .adapters
            .get("rest")
            .ok_or_else(|| "REST adapter not enabled".to_string())?;
        let request = format!("{} {}", method, path);
        adapter.handle(&request).await
    }

    /// Sends `query` to the GraphQL adapter.
    ///
    /// # Errors
    /// Fails if the GraphQL adapter is not enabled, or with the adapter's error.
    pub async fn call_graphql(&self, query: &str) -> Result<String, String> {
        let adapter = self
            .adapters
            .get("graphql")
            .ok_or_else(|| "GraphQL adapter not enabled".to_string())?;
        adapter.handle(query).await
    }

    /// Sends `"{method}:{request}"` to the gRPC adapter.
    ///
    /// # Errors
    /// Fails if the gRPC adapter is not enabled, or with the adapter's error.
    pub async fn call_grpc(&self, method: &str, request: &str) -> Result<String, String> {
        let adapter = self
            .adapters
            .get("grpc")
            .ok_or_else(|| "gRPC adapter not enabled".to_string())?;
        let grpc_request = format!("{}:{}", method, request);
        adapter.handle(&grpc_request).await
    }

    /// Scalar API-reference HTML for this router using the default config.
    pub fn scalar(&self, title: &str, version: &str) -> String {
        let config = scalar::ScalarConfig::default();
        self.scalar_docs(config, title, version)
    }

    /// Scalar API-reference HTML embedding this router's OpenAPI spec.
    /// If the spec cannot be serialized, `{}` is embedded instead.
    pub fn scalar_docs(&self, config: scalar::ScalarConfig, title: &str, version: &str) -> String {
        let spec = OpenApiGenerator::new(title, version).generate(self);
        let spec_json = serde_json::to_string(&spec).unwrap_or_else(|_| "{}".to_string());
        scalar::scalar_html(&config, title, &spec_json)
    }
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
/// Registers many handlers on a router in a single call.
///
/// Each entry is `"name" => handler_path`, optionally prefixed by modifiers
/// that select the `Router` registration method:
///
/// | prefix                 | method                               |
/// |------------------------|--------------------------------------|
/// | *(none)*               | `register`                           |
/// | `args`                 | `register_with_args`                 |
/// | `streaming`            | `register_streaming`                 |
/// | `streaming args`       | `register_streaming_with_args`       |
/// | `state`                | `register_with_state_only`           |
/// | `state args`           | `register_with_state`                |
/// | `state streaming`      | `register_streaming_with_state_only` |
/// | `state streaming args` | `register_streaming_with_state`      |
///
/// Entries are comma-separated. A trailing comma after the last entry is
/// optional, and an empty list is accepted.
///
/// ```ignore
/// register_handlers!(router, [
///     "health" => health,
///     args "get_user" => get_user,
///     state streaming args "watch" => watch
/// ]);
/// ```
#[macro_export]
macro_rules! register_handlers {
    ($router:expr, [ $($entry:tt)* ]) => {
        // Append a comma so the final entry matches the `..., $($rest)*` arms
        // whether or not the caller wrote a trailing comma. The terminal arm
        // below then swallows a possible doubled comma.
        $crate::register_handlers!(@entries $router, $($entry)* ,)
    };
    (@entries $router:expr, $(,)?) => {};
    (@entries $router:expr, $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, args $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_with_args($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, streaming $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_streaming($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, streaming args $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_streaming_with_args($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, state $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_with_state_only($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, state args $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_with_state($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, state streaming $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_streaming_with_state_only($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
    (@entries $router:expr, state streaming args $name:literal => $handler:path, $($rest:tt)*) => {
        $router.register_streaming_with_state($name, $handler);
        $crate::register_handlers!(@entries $router, $($rest)*)
    };
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_router_creation() {
let router = Router::new();
assert_eq!(router.handlers_count(), 0);
}
#[tokio::test]
async fn test_handler_registration() {
let mut router = Router::new();
router.register("test", || async { "Hello".to_string() });
assert_eq!(router.handlers_count(), 1);
}
#[tokio::test]
async fn test_handler_execution() {
let mut router = Router::new();
router.register("test", || async { "Hello".to_string() });
let result = router.execute("test").await;
assert_eq!(result, Ok("Hello".to_string()));
}
#[tokio::test]
async fn test_router_starts_with_no_routes() {
let router = Router::new();
let routes = router.routes();
assert_eq!(routes.len(), 0);
}
#[tokio::test]
async fn test_add_route_metadata() {
let mut router = Router::new();
let metadata = RouteMetadata::new("/users", "GET", "rest");
router.add_route(metadata.clone());
let routes = router.routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].path, "/users");
assert_eq!(routes[0].method, "GET");
assert_eq!(routes[0].protocol, "rest");
}
#[tokio::test]
async fn test_add_multiple_routes() {
let mut router = Router::new();
router.add_route(RouteMetadata::new("/users", "GET", "rest"));
router.add_route(RouteMetadata::new("/users", "POST", "rest"));
router.add_route(RouteMetadata::new("/posts", "GET", "rest"));
let routes = router.routes();
assert_eq!(routes.len(), 3);
}
#[tokio::test]
async fn test_routes_with_different_protocols() {
let mut router = Router::new();
router.add_route(RouteMetadata::new("/users", "GET", "rest"));
router.add_route(RouteMetadata::new("users", "query", "graphql"));
router.add_route(RouteMetadata::new("UserService.GetUser", "unary", "grpc"));
let routes = router.routes();
assert_eq!(routes.len(), 3);
assert_eq!(routes[0].protocol, "rest");
assert_eq!(routes[1].protocol, "graphql");
assert_eq!(routes[2].protocol, "grpc");
}
#[tokio::test]
async fn test_routes_returns_immutable_reference() {
let mut router = Router::new();
router.add_route(RouteMetadata::new("/test", "GET", "rest"));
let routes1 = router.routes();
let routes2 = router.routes();
assert_eq!(routes1.len(), routes2.len());
assert_eq!(routes1[0].path, routes2[0].path);
}
#[tokio::test]
async fn test_route_get_method() {
let mut router = Router::new();
router.get("/users", || async { "User list".to_string() });
let routes = router.routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].path, "/users");
assert_eq!(routes[0].method, "GET");
assert_eq!(routes[0].protocol, "rest");
}
#[tokio::test]
async fn test_route_post_method() {
let mut router = Router::new();
router.post("/users", || async { "User created".to_string() });
let routes = router.routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].path, "/users");
assert_eq!(routes[0].method, "POST");
assert_eq!(routes[0].protocol, "rest");
}
#[tokio::test]
async fn test_route_put_method() {
let mut router = Router::new();
router.put("/users/1", || async { "User updated".to_string() });
let routes = router.routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].method, "PUT");
}
#[tokio::test]
async fn test_route_delete_method() {
let mut router = Router::new();
router.delete("/users/1", || async { "User deleted".to_string() });
let routes = router.routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].method, "DELETE");
}
#[tokio::test]
async fn test_route_patch_method() {
let mut router = Router::new();
router.patch("/users/1", || async { "User patched".to_string() });
let routes = router.routes();
assert_eq!(routes.len(), 1);
assert_eq!(routes[0].method, "PATCH");
}
#[tokio::test]
async fn test_multiple_routes_different_methods() {
let mut router = Router::new();
router.get("/users", || async { "List".to_string() });
router.post("/users", || async { "Create".to_string() });
router.put("/users/1", || async { "Update".to_string() });
router.delete("/users/1", || async { "Delete".to_string() });
let routes = router.routes();
assert_eq!(routes.len(), 4);
assert_eq!(routes[0].method, "GET");
assert_eq!(routes[1].method, "POST");
assert_eq!(routes[2].method, "PUT");
assert_eq!(routes[3].method, "DELETE");
}
#[tokio::test]
async fn test_route_method_with_path_params() {
let mut router = Router::new();
router.get("/users/{id}", || async { "User details".to_string() });
router.get("/users/{id}/posts/{post_id}", || async {
"Post details".to_string()
});
let routes = router.routes();
assert_eq!(routes.len(), 2);
assert_eq!(routes[0].path, "/users/{id}");
assert_eq!(routes[1].path, "/users/{id}/posts/{post_id}");
}
#[tokio::test]
async fn test_route_registration_and_execution() {
let mut router = Router::new();
router.get("/test", || async { "GET response".to_string() });
router.post("/test", || async { "POST response".to_string() });
assert_eq!(router.routes().len(), 2);
assert_eq!(router.handlers_count(), 2);
let result1 = router.execute("GET:/test").await;
let result2 = router.execute("POST:/test").await;
assert_eq!(result1, Ok("GET response".to_string()));
assert_eq!(result2, Ok("POST response".to_string()));
}
#[tokio::test]
async fn test_scalar_generates_html() {
let mut router = Router::new();
router.get("/users", || async { "Users".to_string() });
let html = router.scalar("Test API", "1.0.0");
assert!(html.contains("<!DOCTYPE html>"));
assert!(html.contains("<title>Test API - API Documentation</title>"));
assert!(html.contains("@scalar/api-reference"));
}
#[tokio::test]
async fn test_scalar_contains_openapi_spec() {
let mut router = Router::new();
router.get("/users", || async { "Users".to_string() });
router.post("/users", || async { "User created".to_string() });
let html = router.scalar("Test API", "1.0.0");
assert!(html.contains("openapi"));
assert!(html.contains("Test API"));
assert!(html.contains("1.0.0"));
}
#[tokio::test]
async fn test_scalar_docs_with_custom_config() {
let mut router = Router::new();
router.get("/users", || async { "Users".to_string() });
let config = scalar::ScalarConfig::new()
.theme(scalar::ScalarTheme::Light)
.show_sidebar(false);
let html = router.scalar_docs(config, "Custom API", "2.0.0");
assert!(html.contains("Custom API"));
assert!(html.contains(r#""theme":"light""#));
assert!(html.contains(r#""showSidebar":false"#));
}
#[tokio::test]
async fn test_scalar_docs_with_custom_css() {
let mut router = Router::new();
router.get("/test", || async { "Test".to_string() });
let config = scalar::ScalarConfig::new().custom_css("body { font-family: 'Inter'; }");
let html = router.scalar_docs(config, "API", "1.0");
assert!(html.contains("<style>body { font-family: 'Inter'; }</style>"));
}
#[tokio::test]
async fn test_scalar_with_multiple_routes() {
let mut router = Router::new();
router.get("/users", || async { "Users".to_string() });
router.post("/users", || async { "Create".to_string() });
router.get("/users/{id}", || async { "User details".to_string() });
router.delete("/users/{id}", || async { "Delete".to_string() });
let html = router.scalar("API", "1.0.0");
assert!(html.contains("/users"));
}
#[tokio::test]
async fn test_get_adapter_returns_adapter() {
let mut router = Router::new();
router.add_adapter(Box::new(RestAdapter::new()));
let adapter = router.get_adapter("rest");
assert!(adapter.is_some());
assert_eq!(adapter.unwrap().name(), "rest");
}
#[tokio::test]
async fn test_get_adapter_returns_none_for_missing() {
let router = Router::new();
let adapter = router.get_adapter("rest");
assert!(adapter.is_none());
}
#[tokio::test]
async fn test_route_request_success() {
let mut router = Router::new();
router.register("test_handler", || async { "Success!".to_string() });
let mut rest_adapter = RestAdapter::new();
rest_adapter.route("GET", "/test", "test_handler");
router.add_adapter(Box::new(rest_adapter));
let result = router.route_request("rest", "GET /test").await;
assert!(result.is_ok());
let response = result.unwrap();
assert!(response.contains("HTTP 200") || response.contains("test_handler"));
}
#[tokio::test]
async fn test_route_request_unknown_adapter() {
let router = Router::new();
let result = router.route_request("unknown", "test").await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("Adapter not found"));
}
#[tokio::test]
async fn test_enabled_protocols_empty() {
let router = Router::new();
let protocols = router.enabled_protocols();
assert_eq!(protocols.len(), 0);
}
#[tokio::test]
async fn test_enabled_protocols_multiple() {
let mut router = Router::new();
router.add_adapter(Box::new(RestAdapter::new()));
router.add_adapter(Box::new(GraphQLAdapter::new()));
router.add_adapter(Box::new(GrpcAdapter::new()));
let protocols = router.enabled_protocols();
assert_eq!(protocols.len(), 3);
assert!(protocols.contains(&"rest".to_string()));
assert!(protocols.contains(&"graphql".to_string()));
assert!(protocols.contains(&"grpc".to_string()));
}
#[tokio::test]
async fn test_can_handle_rest() {
let mut router = Router::new();
assert!(!router.can_handle_rest("test"));
router.add_adapter(Box::new(RestAdapter::new()));
assert!(router.can_handle_rest("test"));
}
/// `can_handle_graphql` flips from false to true once a GraphQL adapter is added.
#[tokio::test]
async fn test_can_handle_graphql() {
    let mut router = Router::new();
    assert!(!router.can_handle_graphql("test"));

    let adapter = GraphQLAdapter::new();
    router.add_adapter(Box::new(adapter));
    assert!(router.can_handle_graphql("test"));
}
/// `can_handle_grpc` flips from false to true once a gRPC adapter is added.
#[tokio::test]
async fn test_can_handle_grpc() {
    let mut router = Router::new();
    assert!(!router.can_handle_grpc("test"));

    let adapter = GrpcAdapter::new();
    router.add_adapter(Box::new(adapter));
    assert!(router.can_handle_grpc("test"));
}
/// A parameterised REST route (`/users/:id`) reaches the `get_user` handler.
#[tokio::test]
async fn test_integration_single_handler_rest() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User data") });

    let mut rest = RestAdapter::new();
    rest.route("GET", "/users/:id", "get_user");
    router.add_adapter(Box::new(rest));

    let body = router.route_request("rest", "GET /users/42").await.unwrap();
    assert!(body.contains("get_user"));
}
/// A GraphQL `user` query reaches the `get_user` handler.
#[tokio::test]
async fn test_integration_single_handler_graphql() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User data") });

    let mut graphql = GraphQLAdapter::new();
    graphql.query("user", "get_user");
    router.add_adapter(Box::new(graphql));

    let body = router.route_request("graphql", "query { user }").await.unwrap();
    assert!(body.contains("get_user"));
}
/// A unary gRPC call `UserService.GetUser` with a JSON payload reaches the `get_user` handler.
#[tokio::test]
async fn test_integration_single_handler_grpc() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User data") });

    let mut grpc = GrpcAdapter::new();
    grpc.unary("UserService", "GetUser", "get_user");
    router.add_adapter(Box::new(grpc));

    let body = router
        .route_request("grpc", r#"UserService.GetUser:{"id":42}"#)
        .await
        .unwrap();
    assert!(body.contains("get_user"));
}
/// One `get_user` handler is exposed through REST, GraphQL and gRPC on the same router. Every
/// protocol's response must name the handler.
#[tokio::test]
async fn test_integration_single_handler_all_protocols() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User data") });

    let mut rest = RestAdapter::new();
    rest.route("GET", "/users/:id", "get_user");
    let mut graphql = GraphQLAdapter::new();
    graphql.query("user", "get_user");
    let mut grpc = GrpcAdapter::new();
    grpc.unary("UserService", "GetUser", "get_user");

    router.add_adapter(Box::new(rest));
    router.add_adapter(Box::new(graphql));
    router.add_adapter(Box::new(grpc));

    let requests = [
        ("rest", "GET /users/42"),
        ("graphql", "query { user }"),
        ("grpc", r#"UserService.GetUser:{"id":42}"#),
    ];
    for (protocol, request) in requests {
        let response = router.route_request(protocol, request).await.unwrap();
        assert!(response.contains("get_user"));
    }
}
/// Three handlers are wired through all three protocols. The `get_user` route of each protocol
/// must reach the `get_user` handler.
#[tokio::test]
async fn test_integration_multiple_handlers_all_protocols() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User data") });
    router.register("list_users", || async { String::from("Users list") });
    router.register("create_user", || async { String::from("Created user") });

    let mut rest = RestAdapter::new();
    for (method, path, handler) in [
        ("GET", "/users/:id", "get_user"),
        ("GET", "/users", "list_users"),
        ("POST", "/users", "create_user"),
    ] {
        rest.route(method, path, handler);
    }
    router.add_adapter(Box::new(rest));

    let mut graphql = GraphQLAdapter::new();
    graphql.query("user", "get_user");
    graphql.query("users", "list_users");
    graphql.mutation("createUser", "create_user");
    router.add_adapter(Box::new(graphql));

    let mut grpc = GrpcAdapter::new();
    for (method, handler) in [
        ("GetUser", "get_user"),
        ("ListUsers", "list_users"),
        ("CreateUser", "create_user"),
    ] {
        grpc.unary("UserService", method, handler);
    }
    router.add_adapter(Box::new(grpc));

    let requests = [
        ("rest", "GET /users/42"),
        ("graphql", "query { user }"),
        ("grpc", "UserService.GetUser:{}"),
    ];
    for (protocol, request) in requests {
        let response = router.route_request(protocol, request).await.unwrap();
        assert!(response.contains("get_user"));
    }
}
/// A REST path with no matching route produces an `Ok` response containing `HTTP 404`, not an
/// `Err`.
#[tokio::test]
async fn test_integration_error_handling_rest_404() {
    let mut rest = RestAdapter::new();
    rest.route("GET", "/users/:id", "get_user");
    let mut router = Router::new();
    router.add_adapter(Box::new(rest));

    let response = router.route_request("rest", "GET /posts/42").await.unwrap();
    assert!(response.contains("HTTP 404"));
}
/// Querying an unregistered GraphQL field produces an `Ok` response whose body contains
/// `errors`.
#[tokio::test]
async fn test_integration_error_handling_graphql_not_found() {
    let mut graphql = GraphQLAdapter::new();
    graphql.query("user", "get_user");
    let mut router = Router::new();
    router.add_adapter(Box::new(graphql));

    let response = router.route_request("graphql", "query { post }").await.unwrap();
    assert!(response.contains("errors"));
}
/// Calling a gRPC method that was never registered produces an `Ok` response carrying
/// `grpc-status: 12`, the gRPC UNIMPLEMENTED status code.
#[tokio::test]
async fn test_integration_error_handling_grpc_unimplemented() {
    let mut grpc = GrpcAdapter::new();
    grpc.unary("UserService", "GetUser", "get_user");
    let mut router = Router::new();
    router.add_adapter(Box::new(grpc));

    let response = router
        .route_request("grpc", "UserService.GetPost:{}")
        .await
        .unwrap();
    assert!(response.contains("grpc-status: 12"));
}
/// An unknown protocol name is rejected with an "Adapter not found" error.
#[tokio::test]
async fn test_integration_unknown_protocol() {
    let router = Router::new();
    let error = router.route_request("unknown", "request").await.unwrap_err();
    assert!(error.contains("Adapter not found"));
}
/// GET, POST, PUT and DELETE on the REST adapter each dispatch to their own handler.
#[tokio::test]
async fn test_integration_protocol_specific_features_rest_methods() {
    let mut router = Router::new();
    router.register("get_users", || async { String::from("Users") });
    router.register("create_user", || async { String::from("Created") });
    router.register("update_user", || async { String::from("Updated") });
    router.register("delete_user", || async { String::from("Deleted") });

    let mut rest = RestAdapter::new();
    rest.route("GET", "/users", "get_users");
    rest.route("POST", "/users", "create_user");
    rest.route("PUT", "/users/:id", "update_user");
    rest.route("DELETE", "/users/:id", "delete_user");
    router.add_adapter(Box::new(rest));

    let expectations = [
        ("GET /users", "get_users"),
        ("POST /users", "create_user"),
        ("PUT /users/42", "update_user"),
        ("DELETE /users/42", "delete_user"),
    ];
    for (request, handler_name) in expectations {
        let response = router.route_request("rest", request).await.unwrap();
        assert!(response.contains(handler_name));
    }
}
/// A GraphQL query and a mutation each dispatch to their own handler.
#[tokio::test]
async fn test_integration_protocol_specific_features_graphql_types() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User") });
    router.register("create_user", || async { String::from("Created") });

    let mut graphql = GraphQLAdapter::new();
    graphql.query("user", "get_user");
    graphql.mutation("createUser", "create_user");
    router.add_adapter(Box::new(graphql));

    let expectations = [
        ("query { user }", "get_user"),
        ("mutation { createUser }", "create_user"),
    ];
    for (request, handler_name) in expectations {
        let response = router.route_request("graphql", request).await.unwrap();
        assert!(response.contains(handler_name));
    }
}
/// gRPC responses reflect the method kind: `unary` for a unary method and `server_streaming`
/// for a server-streaming one.
#[tokio::test]
async fn test_integration_protocol_specific_features_grpc_streaming() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("User") });
    router.register("list_users", || async { String::from("Users") });

    let mut grpc = GrpcAdapter::new();
    grpc.unary("UserService", "GetUser", "get_user");
    grpc.server_streaming("UserService", "ListUsers", "list_users");
    router.add_adapter(Box::new(grpc));

    let expectations = [
        ("UserService.GetUser:{}", "unary"),
        ("UserService.ListUsers:{}", "server_streaming"),
    ];
    for (request, method_kind) in expectations {
        let response = router.route_request("grpc", request).await.unwrap();
        assert!(response.contains(method_kind));
    }
}
/// `register_streaming` marks the name as streaming. Unknown names are not streaming.
#[tokio::test]
async fn test_register_streaming_handler() {
    let mut router = Router::new();
    router.register_streaming("stream_data", |tx: StreamSender| async move {
        let _ = tx.send(String::from("item")).await;
        String::from("done")
    });

    assert!(router.is_streaming("stream_data"));
    assert!(!router.is_streaming("nonexistent"));
}
/// `register_streaming_with_args` also marks the handler as streaming.
#[tokio::test]
async fn test_register_streaming_with_args() {
    #[derive(serde::Deserialize)]
    struct Input {
        count: usize,
    }

    let mut router = Router::new();
    router.register_streaming_with_args("stream_items", |args: Input, tx: StreamSender| async move {
        let mut index = 0;
        while index < args.count {
            let _ = tx.send(format!("item-{index}")).await;
            index += 1;
        }
        String::from("done")
    });

    assert!(router.is_streaming("stream_items"));
}
/// A streaming handler is kept separate and does not count toward `handlers_count`.
#[tokio::test]
async fn test_streaming_handler_not_in_regular_handlers() {
    let mut router = Router::new();
    router.register_streaming("stream", |_tx: StreamSender| async move { String::from("done") });

    let regular_count = router.handlers_count();
    assert_eq!(regular_count, 0);
}
/// `list_handlers` reports both regular and streaming handler names.
#[tokio::test]
async fn test_list_handlers_includes_streaming() {
    let mut router = Router::new();
    router.register("regular", || async { String::from("ok") });
    router.register_streaming("stream", |_tx: StreamSender| async move { String::from("ok") });

    let names = router.list_handlers();
    assert_eq!(names.len(), 2);
    for expected in ["regular", "stream"] {
        assert!(names.contains(&expected.to_string()));
    }
}
/// Calling a streaming handler yields its final value from the future. The chunks it sent are
/// then received in order.
#[tokio::test]
async fn test_call_streaming_handler() {
    let mut router = Router::new();
    router.register_streaming("stream", |tx: StreamSender| async move {
        let _ = tx.send(String::from("a")).await;
        let _ = tx.send(String::from("b")).await;
        String::from("final")
    });

    let (mut rx, fut) = router.call_streaming_handler("stream", "{}").unwrap();
    // The future runs to completion before the receiver is drained. This assumes the stream
    // channel can buffer both chunks (presumably DEFAULT_STREAM_CAPACITY >= 2; confirm).
    assert_eq!(fut.await, Ok("final".to_string()));
    for expected in ["a", "b"] {
        assert_eq!(rx.recv().await, Some(expected.to_string()));
    }
}
/// A streaming handler with JSON args emits one chunk per counted index and returns a summary.
#[tokio::test]
async fn test_call_streaming_handler_with_args() {
    #[derive(serde::Deserialize)]
    struct Input {
        n: usize,
    }

    let mut router = Router::new();
    router.register_streaming_with_args("count", |args: Input, tx: StreamSender| async move {
        for i in 0..args.n {
            let _ = tx.send(i.to_string()).await;
        }
        format!("counted to {}", args.n)
    });

    let (mut rx, fut) = router.call_streaming_handler("count", r#"{"n":3}"#).unwrap();
    assert_eq!(fut.await, Ok("counted to 3".to_string()));
    for expected in ["0", "1", "2"] {
        assert_eq!(rx.recv().await, Some(expected.to_string()));
    }
}
/// Calling an unregistered streaming handler returns an error mentioning "not found".
#[tokio::test]
async fn test_call_streaming_handler_not_found() {
    let router = Router::new();
    // `match` rather than `unwrap_err`: the `Ok` side (receiver + future) is not assumed `Debug`.
    match router.call_streaming_handler("missing", "{}") {
        Err(e) => assert!(e.contains("not found")),
        Ok(_) => panic!("expected error"),
    }
}
/// A handler registered with plain `register` is not reported as streaming.
#[tokio::test]
async fn test_is_streaming_false_for_regular() {
    let mut router = Router::new();
    router.register("regular", || async { String::from("ok") });
    let streaming = router.is_streaming("regular");
    assert!(!streaming);
}
/// Regular and streaming handlers coexist on one router. Each works through its own path, and a
/// regular handler cannot be called through the streaming path.
#[tokio::test]
async fn test_mixed_router() {
    let mut router = Router::new();
    router.register("get_user", || async { String::from("user") });
    router.register_streaming("stream_updates", |tx: StreamSender| async move {
        let _ = tx.send(String::from("update")).await;
        String::from("done")
    });

    assert_eq!(router.execute("get_user").await, Ok("user".to_string()));

    let (mut rx, fut) = router.call_streaming_handler("stream_updates", "{}").unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    assert_eq!(rx.recv().await, Some("update".to_string()));

    assert!(!router.is_streaming("get_user"));
    assert!(router.call_streaming_handler("get_user", "{}").is_err());
}
/// A streaming handler receives both shared state and deserialized args, and combines them into
/// the chunk it sends.
#[tokio::test]
async fn test_register_streaming_with_state() {
    struct AppState {
        prefix: String,
    }
    #[derive(serde::Deserialize)]
    struct Input {
        name: String,
    }

    let app_state = AppState {
        prefix: "Hello".to_string(),
    };
    let mut router = Router::new().with_state(app_state);
    router.register_streaming_with_state::<AppState, Input, _, _, _>(
        "greet_stream",
        |state: State<Arc<AppState>>, args: Input, tx: StreamSender| async move {
            let greeting = format!("{} {}", state.prefix, args.name);
            let _ = tx.send(greeting).await;
            String::from("done")
        },
    );

    let (mut rx, fut) = router
        .call_streaming_handler("greet_stream", r#"{"name":"Alice"}"#)
        .unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    assert_eq!(rx.recv().await, Some("Hello Alice".to_string()));
}
/// A state-only streaming handler sends each item from the injected state, then returns the
/// item count.
#[tokio::test]
async fn test_register_streaming_with_state_only() {
    struct AppState {
        items: Vec<String>,
    }

    let app_state = AppState {
        items: vec!["x".to_string(), "y".to_string()],
    };
    let mut router = Router::new().with_state(app_state);
    router.register_streaming_with_state_only::<AppState, _, _, _>(
        "list_stream",
        |state: State<Arc<AppState>>, tx: StreamSender| async move {
            for item in state.items.iter() {
                let _ = tx.send(item.clone()).await;
            }
            format!("listed {}", state.items.len())
        },
    );

    let (mut rx, fut) = router.call_streaming_handler("list_stream", "{}").unwrap();
    assert_eq!(fut.await, Ok("listed 2".to_string()));
    for expected in ["x", "y"] {
        assert_eq!(rx.recv().await, Some(expected.to_string()));
    }
}
/// A handler that returns a `Stream` is registered as streaming, and its items arrive on the
/// receiver in order.
#[tokio::test]
async fn test_register_stream_no_args() {
    let mut router = Router::new();
    router.register_stream("items", || async {
        tokio_stream::iter(vec!["a".to_string(), "b".to_string(), "c".to_string()])
    });
    assert!(router.is_streaming("items"));

    let (mut rx, fut) = router.call_streaming_handler("items", "{}").unwrap();
    // Only the forwarded items are checked here; the future's own result is ignored.
    let _ = fut.await;
    for expected in ["a", "b", "c"] {
        assert_eq!(rx.recv().await, Some(expected.to_string()));
    }
}
/// A `Stream`-returning handler with JSON args yields one item per index below `count`.
#[tokio::test]
async fn test_register_stream_with_args() {
    #[derive(serde::Deserialize)]
    struct Input {
        count: usize,
    }

    let mut router = Router::new();
    router.register_stream_with_args("counting", |args: Input| async move {
        tokio_stream::iter((0..args.count).map(|i| i.to_string()))
    });
    assert!(router.is_streaming("counting"));

    let (mut rx, fut) = router
        .call_streaming_handler("counting", r#"{"count":3}"#)
        .unwrap();
    let _ = fut.await;
    for expected in ["0", "1", "2"] {
        assert_eq!(rx.recv().await, Some(expected.to_string()));
    }
}
/// A `Stream`-returning handler registered with `register_stream_with_state` is reported as
/// streaming. When called, it must yield the items cloned from the injected `AppState`.
#[tokio::test]
async fn test_register_stream_with_state() {
    struct AppState {
        items: Vec<String>,
    }
    let mut router = Router::new().with_state(AppState {
        items: vec!["x".to_string(), "y".to_string()],
    });
    router.register_stream_with_state::<AppState, serde_json::Value, _, _, _, _>(
        "state_stream",
        |state: State<Arc<AppState>>, _args: serde_json::Value| {
            // Clone outside the async block so the returned future owns its data instead of
            // borrowing `state`.
            let items = state.items.clone();
            async move { tokio_stream::iter(items) }
        },
    );
    assert!(router.is_streaming("state_stream"));

    // Drive the handler end to end, as the sibling stream tests do. Otherwise only
    // registration is covered and state resolution goes untested. "{}" deserializes to a
    // `serde_json::Value` object.
    let (mut rx, fut) = router
        .call_streaming_handler("state_stream", "{}")
        .unwrap();
    let _result = fut.await;
    assert_eq!(rx.recv().await, Some("x".to_string()));
    assert_eq!(rx.recv().await, Some("y".to_string()));
}
/// `register_stream` marks the handler as streaming. Unregistered names stay non-streaming.
#[tokio::test]
async fn test_stream_adapter_shows_in_is_streaming() {
    let mut router = Router::new();
    router.register_stream("my_stream", || async {
        tokio_stream::iter(vec!["done".to_string()])
    });

    let registered = router.is_streaming("my_stream");
    let unknown = router.is_streaming("nonexistent");
    assert!(registered);
    assert!(!unknown);
}
/// Two different state types can be attached to one router. Each handler resolves the one it
/// names in its type parameters.
#[tokio::test]
async fn test_multiple_state_types() {
    struct DbPool {
        url: String,
    }
    struct AppConfig {
        name: String,
    }
    #[derive(serde::Deserialize)]
    struct Input {
        key: String,
    }

    let pool = DbPool {
        url: "postgres://localhost".to_string(),
    };
    let config = AppConfig {
        name: "MyApp".to_string(),
    };
    let mut router = Router::new().with_state(pool).with_state(config);

    router.register_with_state::<DbPool, Input, _, _>(
        "db_query",
        |state: State<Arc<DbPool>>, args: Input| async move {
            format!("{}:{}", state.url, args.key)
        },
    );
    router.register_with_state_only::<AppConfig, _, _>(
        "app_name",
        |state: State<Arc<AppConfig>>| async move { state.name.clone() },
    );

    let query_result = router.call_handler("db_query", r#"{"key":"users"}"#).await;
    assert_eq!(query_result, Ok("postgres://localhost:users".to_string()));
    let name_result = router.call_handler("app_name", "{}").await;
    assert_eq!(name_result, Ok("MyApp".to_string()));
}
/// State added with `inject_state` on an existing router is visible to state-only handlers.
#[tokio::test]
async fn test_inject_state_after_construction() {
    struct LateState {
        value: String,
    }

    let mut router = Router::new();
    let late = LateState {
        value: "injected".to_string(),
    };
    router.inject_state(late);
    router.register_with_state_only::<LateState, _, _>(
        "get_value",
        |state: State<Arc<LateState>>| async move { state.value.clone() },
    );

    assert_eq!(
        router.call_handler("get_value", "{}").await,
        Ok("injected".to_string())
    );
}
/// A state-only streaming handler can read fields of its state type while emitting chunks.
#[tokio::test]
async fn test_multiple_state_streaming() {
    struct StreamConfig {
        prefix: String,
    }

    let config = StreamConfig {
        prefix: "stream".to_string(),
    };
    let mut router = Router::new().with_state(config);
    router.register_streaming_with_state_only::<StreamConfig, _, _, _>(
        "prefixed_stream",
        |state: State<Arc<StreamConfig>>, tx: StreamSender| async move {
            let chunk = format!("{}:item", state.prefix);
            let _ = tx.send(chunk).await;
            String::from("done")
        },
    );

    let (mut rx, fut) = router.call_streaming_handler("prefixed_stream", "{}").unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    assert_eq!(rx.recv().await, Some("stream:item".to_string()));
}
/// Calling `with_state` twice with the same type keeps the most recent value.
#[tokio::test]
async fn test_with_state_duplicate_type_last_wins() {
    let mut router = Router::new()
        .with_state(String::from("first"))
        .with_state(String::from("second"));
    router.register_with_state_only::<String, _, _>(
        "get",
        |state: State<Arc<String>>| async move { state.as_str().to_owned() },
    );

    let outcome = router.call_handler("get", "{}").await;
    assert_eq!(outcome, Ok("second".to_string()));
}
mod macro_test_handlers {
    //! Plain `async fn` handlers for the `register_handlers!` tests. Together they cover every
    //! combination of args, shared `String` state, and streaming.
    use super::{State, StreamSender};
    use std::sync::Arc;

    /// Payload for [`echo`].
    #[derive(serde::Deserialize)]
    pub struct EchoArgs {
        pub message: String,
    }

    /// Payload for [`search`] and [`state_search`].
    #[derive(serde::Deserialize)]
    pub struct SearchArgs {
        pub query: String,
    }

    /// Payload for [`save_key`].
    #[derive(serde::Deserialize)]
    pub struct SaveArgs {
        pub key: String,
    }

    /// No-argument handler that always answers `"ok"`.
    pub async fn health() -> String {
        String::from("ok")
    }

    /// Returns the `message` field unchanged.
    pub async fn echo(args: EchoArgs) -> String {
        let EchoArgs { message } = args;
        message
    }

    /// Streams a single `"tick"` chunk, then returns `"done"`.
    pub async fn ticker(tx: StreamSender) -> String {
        let _ = tx.send(String::from("tick")).await;
        String::from("done")
    }

    /// Streams `found:<query>`, then returns `"complete"`.
    pub async fn search(args: SearchArgs, tx: StreamSender) -> String {
        let chunk = format!("found:{}", args.query);
        let _ = tx.send(chunk).await;
        String::from("complete")
    }

    /// Returns `status:<state>` built from the shared `String` state.
    pub async fn get_status(state: State<Arc<String>>) -> String {
        format!("status:{}", *state)
    }

    /// Returns `<state>:<key>`.
    pub async fn save_key(state: State<Arc<String>>, args: SaveArgs) -> String {
        format!("{}:{}", *state, args.key)
    }

    /// Streams `<state>:chunk`, then returns `"done"`.
    pub async fn state_stream(state: State<Arc<String>>, tx: StreamSender) -> String {
        let chunk = format!("{}:chunk", *state);
        let _ = tx.send(chunk).await;
        String::from("done")
    }

    /// Streams `<state>:<query>`, then returns `"complete"`.
    pub async fn state_search(
        state: State<Arc<String>>,
        args: SearchArgs,
        tx: StreamSender,
    ) -> String {
        let chunk = format!("{}:{}", *state, args.query);
        let _ = tx.send(chunk).await;
        String::from("complete")
    }
}
/// A bare `"name" => path` entry registers a regular no-argument handler.
#[tokio::test]
async fn test_register_handlers_basic() {
    let mut router = Router::new();
    register_handlers!(router, [
        "health" => macro_test_handlers::health,
    ]);

    assert_eq!(router.handlers_count(), 1);
    assert_eq!(
        router.call_handler("health", "{}").await,
        Ok("ok".to_string())
    );
}
/// An `args` entry registers a handler whose JSON payload is deserialized into its argument.
#[tokio::test]
async fn test_register_handlers_with_args() {
    let mut router = Router::new();
    register_handlers!(router, [
        args "echo" => macro_test_handlers::echo,
    ]);

    assert_eq!(router.handlers_count(), 1);
    let echoed = router
        .call_handler("echo", r#"{"message":"hello"}"#)
        .await;
    assert_eq!(echoed, Ok("hello".to_string()));
}
/// A `streaming` entry registers the handler on the streaming path. `ticker` sends one chunk,
/// then returns.
#[tokio::test]
async fn test_register_handlers_streaming() {
    let mut router = Router::new();
    register_handlers!(router, [
        streaming "ticker" => macro_test_handlers::ticker,
    ]);
    assert!(router.is_streaming("ticker"));

    let (mut rx, fut) = router.call_streaming_handler("ticker", "{}").unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    assert_eq!(rx.recv().await, Some("tick".to_string()));
}
/// A `streaming args` entry registers a streaming handler that also receives a JSON payload.
#[tokio::test]
async fn test_register_handlers_streaming_with_args() {
    let mut router = Router::new();
    register_handlers!(router, [
        streaming args "search" => macro_test_handlers::search,
    ]);
    assert!(router.is_streaming("search"));

    let (mut rx, fut) = router
        .call_streaming_handler("search", r#"{"query":"rust"}"#)
        .unwrap();
    assert_eq!(fut.await, Ok("complete".to_string()));
    assert_eq!(rx.recv().await, Some("found:rust".to_string()));
}
/// A single `register_handlers!` call can mix regular and streaming entries. Only the regular
/// ones count toward `handlers_count`, while `list_handlers` reports all four.
#[tokio::test]
async fn test_register_handlers_mixed() {
    let mut router = Router::new();
    register_handlers!(router, [
        "health" => macro_test_handlers::health,
        args "echo" => macro_test_handlers::echo,
        streaming "ticker" => macro_test_handlers::ticker,
        streaming args "search" => macro_test_handlers::search,
    ]);

    assert_eq!(router.handlers_count(), 2);
    assert_eq!(router.list_handlers().len(), 4);

    let expectations = [("health", "{}", "ok"), ("echo", r#"{"message":"hi"}"#, "hi")];
    for (name, payload, expected) in expectations {
        assert_eq!(
            router.call_handler(name, payload).await,
            Ok(expected.to_string())
        );
    }

    for name in ["ticker", "search"] {
        assert!(router.is_streaming(name));
    }
}
/// An empty `register_handlers!` list leaves the router with zero regular handlers.
#[tokio::test]
async fn test_register_handlers_empty() {
// `router` is not declared `mut`, so this relies on the empty-list form of the macro not
// needing mutable access (presumably it expands to nothing; confirm against the macro).
let router = Router::new();
register_handlers!(router, []);
assert_eq!(router.handlers_count(), 0);
}
/// A `state` entry injects the router's `String` state into the handler.
#[tokio::test]
async fn test_register_handlers_state_only() {
    let mut router = Router::new().with_state(String::from("active"));
    register_handlers!(router, [
        state "get_status" => macro_test_handlers::get_status,
    ]);

    let status = router.call_handler("get_status", "{}").await;
    assert_eq!(status, Ok("status:active".to_string()));
}
/// A `state args` entry gives the handler both the shared state and the deserialized payload.
#[tokio::test]
async fn test_register_handlers_state_args() {
    let mut router = Router::new().with_state(String::from("ns"));
    register_handlers!(router, [
        state args "save_key" => macro_test_handlers::save_key,
    ]);

    assert_eq!(
        router
            .call_handler("save_key", r#"{"key":"api_token"}"#)
            .await,
        Ok("ns:api_token".to_string())
    );
}
/// A `state streaming` entry registers a streaming handler that reads the shared state.
#[tokio::test]
async fn test_register_handlers_state_streaming() {
    let mut router = Router::new().with_state(String::from("ctx"));
    register_handlers!(router, [
        state streaming "state_stream" => macro_test_handlers::state_stream,
    ]);
    assert!(router.is_streaming("state_stream"));

    let (mut rx, fut) = router.call_streaming_handler("state_stream", "{}").unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    assert_eq!(rx.recv().await, Some("ctx:chunk".to_string()));
}
/// A `state streaming args` entry combines shared state, a JSON payload and a stream sender.
#[tokio::test]
async fn test_register_handlers_state_streaming_args() {
    let mut router = Router::new().with_state(String::from("db"));
    register_handlers!(router, [
        state streaming args "state_search" => macro_test_handlers::state_search,
    ]);
    assert!(router.is_streaming("state_search"));

    let (mut rx, fut) = router
        .call_streaming_handler("state_search", r#"{"query":"rust"}"#)
        .unwrap();
    assert_eq!(fut.await, Ok("complete".to_string()));
    assert_eq!(rx.recv().await, Some("db:rust".to_string()));
}
/// Every entry form of `register_handlers!`, stateful and stateless, can appear in one
/// invocation.
#[tokio::test]
async fn test_register_handlers_mixed_with_state() {
    let mut router = Router::new().with_state(String::from("app"));
    register_handlers!(router, [
        "health" => macro_test_handlers::health,
        args "echo" => macro_test_handlers::echo,
        state "get_status" => macro_test_handlers::get_status,
        state args "save_key" => macro_test_handlers::save_key,
        streaming "ticker" => macro_test_handlers::ticker,
        state streaming "state_stream" => macro_test_handlers::state_stream,
        state streaming args "state_search" => macro_test_handlers::state_search,
    ]);

    let expectations = [
        ("health", "{}", "ok"),
        ("get_status", "{}", "status:app"),
        ("save_key", r#"{"key":"secret"}"#, "app:secret"),
    ];
    for (name, payload, expected) in expectations {
        assert_eq!(
            router.call_handler(name, payload).await,
            Ok(expected.to_string())
        );
    }

    for name in ["ticker", "state_stream", "state_search"] {
        assert!(router.is_streaming(name));
    }
}
/// Each uppercase letter becomes `_` followed by its lowercase form, except at the start
/// (`ABC` -> `a_b_c`). All-lowercase input is unchanged.
#[test]
fn test_camel_to_snake_basic() {
    let cases = [
        ("workflowId", "workflow_id"),
        ("actionLabel", "action_label"),
        ("simple", "simple"),
        ("alreadySnake", "already_snake"),
        ("ABC", "a_b_c"),
    ];
    for (input, expected) in cases {
        assert_eq!(camel_to_snake(input), expected);
    }
}
/// A one-character input is lowercased with no leading underscore.
#[test]
fn test_camel_to_snake_single_char() {
    for (input, expected) in [("a", "a"), ("A", "a")] {
        assert_eq!(camel_to_snake(input), expected);
    }
}
/// The empty string maps to the empty string.
#[test]
fn test_camel_to_snake_empty() {
    let converted = camel_to_snake("");
    assert_eq!(converted, "");
}
/// Input that is already snake_case passes through unchanged.
#[test]
fn test_camel_to_snake_already_snake() {
    let input = "already_snake_case";
    assert_eq!(camel_to_snake(input), input);
}
/// Top-level camelCase keys are rewritten to snake_case, and their values are kept.
#[test]
fn test_transform_keys_flat_object() {
    let input: Value = serde_json::json!({
        "workflowId": "abc",
        "actionLabel": "run"
    });
    let expected = serde_json::json!({
        "workflow_id": "abc",
        "action_label": "run"
    });

    assert_eq!(transform_keys(input, KeyTransform::CamelToSnake), expected);
}
/// Key rewriting recurses into nested objects.
#[test]
fn test_transform_keys_nested_object() {
    let input: Value = serde_json::json!({
        "outerKey": { "innerKey": "value" }
    });
    let expected = serde_json::json!({
        "outer_key": { "inner_key": "value" }
    });

    assert_eq!(transform_keys(input, KeyTransform::CamelToSnake), expected);
}
/// Each object inside a top-level array has its keys rewritten.
#[test]
fn test_transform_keys_array_of_objects() {
    let input: Value = serde_json::json!([
        { "firstName": "Alice" },
        { "firstName": "Bob" }
    ]);
    let expected = serde_json::json!([
        { "first_name": "Alice" },
        { "first_name": "Bob" }
    ]);

    assert_eq!(transform_keys(input, KeyTransform::CamelToSnake), expected);
}
/// Non-container JSON values (null, number, string) come back unchanged.
#[test]
fn test_transform_keys_primitive_passthrough() {
    let primitives = [Value::Null, serde_json::json!(42), serde_json::json!("hello")];
    for value in primitives {
        let transformed = transform_keys(value.clone(), KeyTransform::CamelToSnake);
        assert_eq!(transformed, value);
    }
}
/// With `KeyTransform::CamelToSnake` configured, a camelCase payload deserializes into a
/// snake_case argument struct.
#[tokio::test]
async fn test_router_with_key_transform_camel_to_snake() {
    #[derive(serde::Deserialize)]
    struct Input {
        workflow_name: String,
        is_active: bool,
    }

    let mut router = Router::new().with_key_transform(KeyTransform::CamelToSnake);
    router.register_with_args("test", |args: Input| async move {
        format!("{}:{}", args.workflow_name, args.is_active)
    });

    let payload = r#"{"workflowName":"deploy","isActive":true}"#;
    assert_eq!(
        router.call_handler("test", payload).await,
        Ok("deploy:true".to_string())
    );
}
/// With the camel-to-snake transform enabled, payload keys that are already snake_case still
/// deserialize.
#[tokio::test]
async fn test_router_with_key_transform_already_snake() {
    #[derive(serde::Deserialize)]
    struct Input {
        workflow_name: String,
    }

    let mut router = Router::new().with_key_transform(KeyTransform::CamelToSnake);
    router.register_with_args("test", |args: Input| async move { args.workflow_name });

    let payload = r#"{"workflow_name":"deploy"}"#;
    assert_eq!(
        router.call_handler("test", payload).await,
        Ok("deploy".to_string())
    );
}
/// Without a key transform, a camelCase payload does not match the snake_case field, so the
/// call fails with a deserialization error.
#[tokio::test]
async fn test_router_without_key_transform() {
    #[derive(serde::Deserialize)]
    struct Input {
        workflow_name: String,
    }

    let mut router = Router::new();
    router.register_with_args("test", |args: Input| async move { args.workflow_name });

    let error = router
        .call_handler("test", r#"{"workflowName":"deploy"}"#)
        .await
        .unwrap_err();
    assert!(error.contains("Failed to deserialize"));
}
/// The router's key transform also applies to payloads of streaming handlers.
#[tokio::test]
async fn test_router_key_transform_streaming_handler() {
    #[derive(serde::Deserialize)]
    struct Input {
        item_count: usize,
    }

    let mut router = Router::new().with_key_transform(KeyTransform::CamelToSnake);
    router.register_streaming_with_args("stream", |args: Input, tx: StreamSender| async move {
        for i in 0..args.item_count {
            let _ = tx.send(i.to_string()).await;
        }
        String::from("done")
    });

    let (mut rx, fut) = router
        .call_streaming_handler("stream", r#"{"itemCount":2}"#)
        .unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    for expected in ["0", "1"] {
        assert_eq!(rx.recv().await, Some(expected.to_string()));
    }
}
/// The key transform applies to stateful handlers, which still receive their injected state.
#[tokio::test]
async fn test_router_key_transform_with_state() {
    struct AppState {
        prefix: String,
    }
    #[derive(serde::Deserialize)]
    struct Input {
        user_name: String,
    }

    let app_state = AppState {
        prefix: "Hello".to_string(),
    };
    let mut router = Router::new()
        .with_key_transform(KeyTransform::CamelToSnake)
        .with_state(app_state);
    router.register_with_state::<AppState, Input, _, _>(
        "greet",
        |state: State<Arc<AppState>>, args: Input| async move {
            format!("{} {}", state.prefix, args.user_name)
        },
    );

    let greeting = router.call_handler("greet", r#"{"userName":"Alice"}"#).await;
    assert_eq!(greeting, Ok("Hello Alice".to_string()));
}
/// Enabling a key transform does not break handlers that take no arguments.
#[tokio::test]
async fn test_router_key_transform_zero_arg_handler_unaffected() {
    let mut router = Router::new().with_key_transform(KeyTransform::CamelToSnake);
    router.register("health", || async { String::from("ok") });

    assert_eq!(
        router.call_handler("health", "{}").await,
        Ok("ok".to_string())
    );
}
/// A closure type-erased with `erase_handler!` can be registered and called like a normal
/// handler.
#[tokio::test]
async fn test_register_erased_no_args() {
    let mut router = Router::new();
    let erased = crate::erase_handler!(|| async { "ok".to_string() });
    router.register_erased("health", erased);

    let outcome = router.call_handler("health", "{}").await;
    assert_eq!(outcome, Ok("ok".to_string()));
}
/// `erase_handler_with_args!` wraps an async fn so that its JSON payload is deserialized into
/// the named args type.
#[tokio::test]
async fn test_register_erased_with_args() {
    #[derive(serde::Deserialize)]
    struct Args {
        greeting: String,
    }

    async fn greet(args: Args) -> String {
        format!("hello {}", args.greeting)
    }

    let mut router = Router::new();
    let erased = crate::erase_handler_with_args!(greet, Args);
    router.register_erased("greet", erased);

    assert_eq!(
        router.call_handler("greet", r#"{"greeting":"world"}"#).await,
        Ok("hello world".to_string())
    );
}
/// `erase_handler_with_state!` resolves the `String` state from the map returned by
/// `shared_states()` and passes it to the handler along with the deserialized args.
#[tokio::test]
async fn test_register_erased_with_state() {
    #[derive(serde::Deserialize)]
    struct Args {
        // The handler never reads `key`; the field only gives the `{"key":"v"}` payload
        // something to deserialize into.
        #[allow(dead_code)]
        key: String,
    }

    async fn with_state(state: handler::State<std::sync::Arc<String>>, _args: Args) -> String {
        format!("state={}", *state)
    }

    let mut router = Router::new().with_state(String::from("mystate"));
    let states = router.shared_states();
    let erased = crate::erase_handler_with_state!(with_state, String, Args, states);
    router.register_erased("stateful", erased);

    assert_eq!(
        router.call_handler("stateful", r#"{"key":"v"}"#).await,
        Ok("state=mystate".to_string())
    );
}
/// `erase_handler_with_state_only!` resolves a `u32` state for a handler that takes no args.
#[tokio::test]
async fn test_register_erased_with_state_only() {
    async fn check(state: handler::State<std::sync::Arc<u32>>) -> String {
        format!("n={}", *state)
    }

    let mut router = Router::new().with_state(42u32);
    let states = router.shared_states();
    let erased = crate::erase_handler_with_state_only!(check, u32, states);
    router.register_erased("check", erased);

    let outcome = router.call_handler("check", "{}").await;
    assert_eq!(outcome, Ok("n=42".to_string()));
}
/// A streaming closure type-erased with `erase_streaming_handler!` streams its chunk and
/// returns its final value.
#[tokio::test]
async fn test_register_streaming_erased() {
    let mut router = Router::new();
    let erased = crate::erase_streaming_handler!(|tx: handler::StreamSender| async move {
        tx.send("chunk".to_string()).await.ok();
        "done".to_string()
    });
    router.register_streaming_erased("stream", erased);

    let (mut rx, fut) = router.call_streaming_handler("stream", "{}").unwrap();
    assert_eq!(fut.await, Ok("done".to_string()));
    assert_eq!(rx.recv().await, Some("chunk".to_string()));
}
/// `register_handlers_erased!` registers a no-arg handler and an args handler declared with the
/// `name(args: Type)` form.
#[tokio::test]
async fn test_register_handlers_erased_macro() {
    #[derive(serde::Deserialize)]
    struct GreetArgs {
        name: String,
    }

    async fn health() -> String {
        "ok".into()
    }

    async fn greet(args: GreetArgs) -> String {
        format!("hi {}", args.name)
    }

    let mut router = Router::new();
    crate::register_handlers_erased!(router, {
        "health" => health(),
        "greet" => greet(args: GreetArgs),
    });

    let health_result = router.call_handler("health", "{}").await;
    assert_eq!(health_result, Ok("ok".to_string()));

    let greet_result = router.call_handler("greet", r#"{"name":"Alice"}"#).await;
    assert_eq!(greet_result, Ok("hi Alice".to_string()));
}
}