use std::{collections::HashMap, sync::Arc};
use tower::Service;
use uuid::Uuid;
use crate::{Frame, FrameFuture, Handler, MessageFrame};
use super::{path_error::PathError, Endpoint, RouteId, Routes};
/// Routes incoming frames to endpoints by URI path.
///
/// `State` is the application state handed to an endpoint when a frame is
/// dispatched to it.
#[derive(Clone)]
pub struct Router<State> {
    /// Path → [`RouteId`] table. Kept behind an [`Arc`] so cloned routers share
    /// it until one of them is mutated.
    routes: Arc<Routes>,
    /// Endpoint registered for each route id.
    actions: HashMap<RouteId, Endpoint<State>>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `State: Default` bound even though no `State` value is ever constructed here
// (an empty map of endpoints needs none).
impl<State> Default for Router<State> {
    fn default() -> Self {
        Self {
            routes: Arc::default(),
            actions: HashMap::new(),
        }
    }
}
impl<State> Router<State>
where
    State: Clone + Send + Sync + 'static,
{
    /// Creates an empty router with no routes and no endpoints.
    pub fn new() -> Self {
        Self {
            routes: Arc::new(Routes::default()),
            actions: HashMap::new(),
        }
    }

    /// Extracts path parameters from `path` into `T` by delegating to the route
    /// table's `capture`.
    ///
    /// # Errors
    ///
    /// Returns the [`PathError`] produced by `Routes::capture`.
    /// NOTE(review): presumably this covers "no matching route" and
    /// "parameters do not deserialize into `T`"; confirm against `Routes::capture`.
    pub fn capture<T>(&self, path: impl AsRef<str>) -> Result<T, PathError>
    where
        T: serde::de::DeserializeOwned,
    {
        self.routes.capture(path)
    }

    /// Dispatches `frame` to the endpoint registered for its URI, giving that
    /// endpoint `state`.
    ///
    /// Non-message frames and message frames whose URI matches no registered
    /// endpoint resolve to [`FrameFuture::empty`].
    pub fn handle_frame_with_state(&self, frame: Frame, state: State) -> FrameFuture {
        // Borrow the frame only for the lookup; the returned endpoint reference
        // borrows `self.actions`, not the frame, so `frame` can be moved below
        // without cloning it.
        let endpoint = match &frame {
            Frame::Message(MessageFrame { uri, .. }) => self
                .routes
                .at(uri)
                .and_then(|(route_id, _)| self.actions.get(route_id)),
            _ => None,
        };
        let Some(endpoint) = endpoint else {
            return FrameFuture::empty();
        };
        endpoint
            .clone()
            .into_inner()
            .into_actionable(state)
            .call(frame)
    }

    /// Registers `endpoint` under `route`. An empty route is treated as `/`.
    ///
    /// If `route` is already registered with an endpoint, that endpoint is
    /// replaced and the existing route id is kept. Otherwise a fresh id
    /// (UUID v4) is allocated and the route is added to the route table.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for the callers' signatures.
    fn insert_endpoint(
        &mut self,
        route: impl Into<String>,
        endpoint: Endpoint<State>,
    ) -> crate::Result<()> {
        let route = route.into();
        let route = if route.is_empty() {
            "/".to_string()
        } else {
            route
        };

        // Reuse the id of a route that already has an endpoint so re-inserting a
        // path replaces its endpoint instead of registering a duplicate route.
        // TODO: tell the caller when an existing endpoint was replaced.
        let existing = self
            .routes
            .get_id(route.as_str())
            .filter(|route_id| self.actions.contains_key(route_id));
        let route_id = match existing {
            Some(route_id) => route_id,
            None => {
                // NOTE(review): how `Routes::insert` treats a path that is already
                // present (but has no endpoint) is not visible here.
                let route_id: RouteId = Uuid::new_v4().into();
                self.insert_route(&route, route_id);
                route_id
            }
        };
        self.actions.insert(route_id, endpoint);
        Ok(())
    }

    /// Adds `route` → `id` to the route table. Uses `Arc::make_mut`, which
    /// clones the table first if it is shared with another router clone.
    fn insert_route(&mut self, route: &str, id: RouteId) {
        let routes = Arc::make_mut(&mut self.routes);
        routes.insert(route, id);
    }

    /// Registers `given_action` as the handler for `route`, replacing any
    /// handler previously registered for the same path.
    ///
    /// # Errors
    ///
    /// Propagates errors from registration (currently none are produced).
    pub fn insert<ActionHandler, Args>(
        &mut self,
        route: impl Into<String>,
        given_action: ActionHandler,
    ) -> crate::Result<()>
    where
        ActionHandler: Handler<Args, State> + Clone + Send + Sync + 'static,
        Args: Clone + Send + Sync + 'static,
    {
        let endpoint = Endpoint::new(given_action);
        self.insert_endpoint(route, endpoint)
    }

    /// Returns every registered path, as reported by the route table.
    pub fn routes(&self) -> Vec<String> {
        self.routes.paths()
    }

    /// Returns clones of all registered endpoints, in unspecified order
    /// (they are read out of a `HashMap`).
    pub fn endpoints(&self) -> Vec<Endpoint<State>> {
        self.actions.values().cloned().collect()
    }

    /// Mounts every endpoint of `other` under the prefix `route`.
    ///
    /// Prefix and nested path are joined with exactly one `/`, and a trailing
    /// `/` is dropped unless the result is the root path. For example, prefix
    /// `/api` with nested `/users` or `users` gives `/api/users`, and prefix `/`
    /// with nested `/x` gives `/x`.
    ///
    /// # Errors
    ///
    /// Propagates the first registration error.
    ///
    /// # Panics
    ///
    /// Panics if `other` holds an endpoint whose id has no path in its route
    /// table. `insert_endpoint` never creates such a router.
    pub fn scope(&mut self, route: impl Into<String>, other: Router<State>) -> crate::Result<()> {
        let prefix = route.into();
        let Router { routes, actions } = other;
        for (id, action) in actions {
            let path = routes
                .get_path(id)
                .expect("every endpoint id in a Router is registered in its route table");
            let full_path = Self::join_scoped_path(&prefix, &path);
            self.insert_endpoint(full_path, action)?;
        }
        Ok(())
    }

    /// Joins a scope prefix and a nested path with a single `/`, stripping one
    /// trailing `/` unless the result is just `/`.
    fn join_scoped_path(prefix: &str, path: &str) -> String {
        let prefix = prefix.trim_end_matches('/');
        let path = path.trim_start_matches('/');
        let joined = format!("{prefix}/{path}");
        match joined.strip_suffix('/') {
            Some(stripped) if !stripped.is_empty() => stripped.to_string(),
            _ => joined,
        }
    }
}
/// Formats the route table only. Endpoints are not required to be `Debug`, so
/// they are omitted, and the output ends with `..` to signal the omission.
impl<State> std::fmt::Debug for Router<State> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Router")
            .field("routes", &self.routes)
            .finish_non_exhaustive()
    }
}