use coap_lite::{CoapRequest, CoapResponse, ObserveOption, Packet, RequestType, ResponseType};
use route_recognizer::Router;
use serde_json::Value;
use std::collections::HashMap;
use std::convert::Infallible;
use std::fmt::Debug;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::mpsc::{self, Sender};
use tokio::sync::{Mutex, RwLock};
use tower::Service;
use crate::handler::{ErasedHandler, Handler, HandlerFn, into_erased_handler, into_handler};
use crate::observer::{Observer, ObserverRequest, ObserverValue};
use crate::router::wrapper::IntoCoapResponse;
use self::wrapper::{RequestTypeWrapper, RouteHandler};
pub mod wrapper;
/// Boxed, thread-safe error type for router failures.
///
/// Any `std::error::Error + Send + Sync + 'static` value converts into it, as do
/// string slices through the standard library's `From` impls.
pub type RouterError = Box<dyn std::error::Error + Send + Sync + 'static>;
type StateUpdateFn<S> = Box<dyn FnOnce(&mut S) + Send + 'static>;
type StateUpdateSender<S> = mpsc::Sender<StateUpdateFn<S>>;
type StateUpdateReceiver<S> = mpsc::Receiver<StateUpdateFn<S>>;
/// Cloneable handle that pushes notification payloads into an [`Observer`]
/// backend by calling its `write` method.
#[derive(Clone)]
pub struct NotificationTrigger<O>
where
    O: Observer + Send + Sync + Clone + 'static,
{
    observer: O,
}

impl<O> NotificationTrigger<O>
where
    O: Observer + Send + Sync + Clone + 'static,
{
    /// Wraps `observer` so it can be used to emit notifications.
    pub fn new(observer: O) -> Self {
        NotificationTrigger { observer }
    }

    /// Writes `payload` for `device_id` at `path` through the wrapped observer.
    ///
    /// # Errors
    /// Returns whatever error the observer's `write` produces.
    pub async fn trigger_notification(
        &mut self,
        device_id: &str,
        path: &str,
        payload: &serde_json::Value,
    ) -> Result<(), O::Error> {
        let outcome = self.observer.write(device_id, path, payload).await;
        outcome
    }
}
#[derive(Clone)]
pub struct StateUpdateHandle<S>
where
S: Send + Sync + Clone + 'static,
{
sender: StateUpdateSender<S>,
}
impl<S> StateUpdateHandle<S>
where
S: Send + Sync + Clone + 'static,
{
pub fn new(sender: StateUpdateSender<S>) -> Self {
Self { sender }
}
pub async fn update<F>(&self, updater: F) -> Result<(), StateUpdateError>
where
F: FnOnce(&mut S) + Send + 'static,
{
self.sender
.send(Box::new(updater))
.await
.map_err(|_| StateUpdateError::ChannelClosed)
}
pub fn try_update<F>(&self, updater: F) -> Result<(), StateUpdateError>
where
F: FnOnce(&mut S) + Send + 'static,
{
self.sender
.try_send(Box::new(updater))
.map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => StateUpdateError::ChannelFull,
mpsc::error::TrySendError::Closed(_) => StateUpdateError::ChannelClosed,
})
}
}
/// Failure modes when submitting a state update through a `StateUpdateHandle`.
#[derive(Debug, Clone, PartialEq)]
pub enum StateUpdateError {
    /// The bounded queue had no free slot. Only `try_update` reports this.
    ChannelFull,
    /// The receiving side of the queue has been dropped.
    ChannelClosed,
}

impl std::fmt::Display for StateUpdateError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ChannelFull => "State update channel is full",
            Self::ChannelClosed => "State update channel is closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StateUpdateError {}
#[derive(Clone)]
pub struct ClientManager {
sender: mpsc::Sender<ClientCommand>,
}
/// Commands sent by [`ClientManager`] to the client-management consumer.
///
/// The consumer is not shown here, so each variant documents the request, not
/// how it is carried out.
#[derive(Debug)]
pub enum ClientCommand {
    /// Asks for `identity` to be added with `key` and optional metadata.
    AddClient {
        identity: String,
        key: Vec<u8>,
        metadata: Option<ClientMetadata>,
    },
    /// Asks for `identity` to be removed.
    RemoveClient { identity: String },
    /// Asks for the key of `identity` to be replaced.
    UpdateKey { identity: String, key: Vec<u8> },
    /// Asks for the metadata of `identity` to be replaced.
    UpdateMetadata {
        identity: String,
        metadata: ClientMetadata,
    },
    /// Asks for `identity` to be enabled or disabled.
    SetClientEnabled { identity: String, enabled: bool },
    /// Asks for the known identities. The consumer replies on `response`.
    ListClients {
        response: tokio::sync::oneshot::Sender<Vec<String>>,
    },
}
/// Descriptive data attached to a client identity.
///
/// NOTE(review): `Default` yields `enabled: false`. Confirm whether the consumer
/// of `ClientCommand::AddClient { metadata: None, .. }` treats such clients as
/// enabled.
#[derive(Debug, Clone, Default)]
pub struct ClientMetadata {
    /// Optional human-readable name.
    pub name: Option<String>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Enabled flag (toggled via `ClientCommand::SetClientEnabled`).
    pub enabled: bool,
    /// Arbitrary labels.
    pub tags: Vec<String>,
    /// Arbitrary key/value pairs.
    pub custom: HashMap<String, String>,
}
impl ClientManager {
pub fn new(sender: mpsc::Sender<ClientCommand>) -> Self {
Self { sender }
}
pub async fn add_client(&self, identity: &str, key: &[u8]) -> Result<(), ClientManagerError> {
self.sender
.send(ClientCommand::AddClient {
identity: identity.to_string(),
key: key.to_vec(),
metadata: None,
})
.await
.map_err(|_| ClientManagerError::ChannelClosed)
}
pub async fn add_client_with_metadata(
&self,
identity: &str,
key: &[u8],
metadata: ClientMetadata,
) -> Result<(), ClientManagerError> {
self.sender
.send(ClientCommand::AddClient {
identity: identity.to_string(),
key: key.to_vec(),
metadata: Some(metadata),
})
.await
.map_err(|_| ClientManagerError::ChannelClosed)
}
pub async fn remove_client(&self, identity: &str) -> Result<(), ClientManagerError> {
self.sender
.send(ClientCommand::RemoveClient {
identity: identity.to_string(),
})
.await
.map_err(|_| ClientManagerError::ChannelClosed)
}
pub async fn update_key(&self, identity: &str, key: &[u8]) -> Result<(), ClientManagerError> {
self.sender
.send(ClientCommand::UpdateKey {
identity: identity.to_string(),
key: key.to_vec(),
})
.await
.map_err(|_| ClientManagerError::ChannelClosed)
}
pub async fn update_metadata(
&self,
identity: &str,
metadata: ClientMetadata,
) -> Result<(), ClientManagerError> {
self.sender
.send(ClientCommand::UpdateMetadata {
identity: identity.to_string(),
metadata,
})
.await
.map_err(|_| ClientManagerError::ChannelClosed)
}
pub async fn set_client_enabled(
&self,
identity: &str,
enabled: bool,
) -> Result<(), ClientManagerError> {
self.sender
.send(ClientCommand::SetClientEnabled {
identity: identity.to_string(),
enabled,
})
.await
.map_err(|_| ClientManagerError::ChannelClosed)
}
pub async fn list_clients(&self) -> Result<Vec<String>, ClientManagerError> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.sender
.send(ClientCommand::ListClients { response: tx })
.await
.map_err(|_| ClientManagerError::ChannelClosed)?;
rx.await.map_err(|_| ClientManagerError::ResponseFailed)
}
}
/// Failure modes of [`ClientManager`] operations.
#[derive(Debug, Clone, PartialEq)]
pub enum ClientManagerError {
    /// The command channel's receiver has been dropped.
    ChannelClosed,
    /// A reply was expected but its sender was dropped before answering.
    ResponseFailed,
}

impl std::fmt::Display for ClientManagerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::ChannelClosed => "Client manager channel is closed",
            Self::ResponseFailed => "Failed to receive response from client manager",
        })
    }
}

impl std::error::Error for ClientManagerError {}
/// Stored record for one client: its key bytes plus descriptive metadata.
#[derive(Debug, Clone)]
pub struct ClientEntry {
    /// Key bytes associated with the client.
    pub key: Vec<u8>,
    /// Descriptive metadata for the client.
    pub metadata: ClientMetadata,
}

/// Shared map of [`ClientEntry`] values behind an async read/write lock.
/// Presumably keyed by client identity; confirm against its users.
pub type ClientStore = Arc<RwLock<HashMap<String, ClientEntry>>>;
/// CoAP request router.
///
/// Maps path patterns (via `route_recognizer`) to per-method handlers. It also
/// carries the shared state passed to every handler call and the observer
/// backend.
#[derive(Clone)]
pub struct CoapRouter<O, S>
where
    S: Clone + Debug + Send + Sync + 'static,
    O: Observer,
{
    /// Path pattern -> (method -> handler) table.
    inner: Router<HashMap<RequestTypeWrapper, RouteHandler<S>>>,
    /// Application state shared with handler invocations.
    state: Arc<Mutex<S>>,
    /// Observer/storage backend.
    db: O,
    /// Sender for queued state updates; `None` until updates are enabled.
    state_update_sender: Option<StateUpdateSender<S>>,
}
impl<O, S> CoapRouter<O, S>
where
    S: Send + Sync + Clone + Debug + 'static,
    O: Observer + Send + Sync + Clone + 'static,
{
    /// Creates an empty router.
    ///
    /// `state` is stored behind an async mutex and `db` is the observer backend.
    /// State updates stay disabled until [`Self::enable_state_updates`] is called.
    pub fn new(state: S, db: O) -> Self {
        Self {
            inner: Router::new(),
            state: Arc::new(Mutex::new(state)),
            db,
            state_update_sender: None,
        }
    }

    /// Starts a [`RouterBuilder`] around `state` and `observer`.
    pub fn builder(state: S, observer: O) -> RouterBuilder<O, S> {
        RouterBuilder::new(state, observer)
    }

    /// Registers `sender` with the observer backend for `device_id` / `path`.
    ///
    /// # Errors
    /// Propagates the backend's `register` error.
    pub async fn register_observer(
        &mut self,
        device_id: &str,
        path: &str,
        sender: Arc<Sender<ObserverValue>>,
    ) -> Result<(), O::Error> {
        self.db.register(device_id, path, sender).await
    }

    /// Removes the observer registration for `device_id` / `path`.
    ///
    /// # Errors
    /// Propagates the backend's `unregister` error.
    pub async fn unregister_observer(
        &mut self,
        device_id: &str,
        path: &str,
    ) -> Result<(), O::Error> {
        self.db.unregister(device_id, path).await
    }

    /// Calls the backend's `unregister_all`.
    ///
    /// NOTE(review): `_device_id` is ignored because the backend call takes no
    /// device argument. This may therefore clear registrations for every device,
    /// not just this one. Confirm against the `Observer` implementation.
    pub async fn unregister_all(&mut self, _device_id: &str) -> Result<(), O::Error> {
        self.db.unregister_all().await
    }

    /// Writes `payload` for `device_id` / `path` through the observer backend.
    ///
    /// # Errors
    /// Propagates the backend's `write` error.
    pub async fn backend_write(
        &mut self,
        device_id: &str,
        path: &str,
        payload: &Value,
    ) -> Result<(), O::Error> {
        self.db.write(device_id, path, payload).await
    }

    /// Same as [`Self::backend_write`]. Presumably the backend write is what
    /// notifies registered observers; see the `Observer` implementation.
    pub async fn trigger_notification(
        &mut self,
        device_id: &str,
        path: &str,
        payload: &Value,
    ) -> Result<(), O::Error> {
        self.backend_write(device_id, path, payload).await
    }

    /// Reads the value stored for `device_id` / `path`, if any.
    ///
    /// # Errors
    /// Propagates the backend's `read` error.
    pub async fn backend_read(
        &mut self,
        device_id: &str,
        path: &str,
    ) -> Result<Option<Value>, O::Error> {
        self.db.read(device_id, path).await
    }

    /// Creates a bounded state-update queue and spawns a task that drains it.
    ///
    /// The task applies each queued closure to the shared state while holding
    /// its mutex. Calling this again replaces the stored sender. Handles from
    /// earlier calls keep feeding their own task, which runs while any of their
    /// senders is alive.
    ///
    /// # Panics
    /// Per tokio's documentation, `mpsc::channel` panics if `buffer_size` is 0.
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    pub fn enable_state_updates(&mut self, buffer_size: usize) -> StateUpdateHandle<S> {
        let (sender, receiver) = mpsc::channel(buffer_size);
        self.state_update_sender = Some(sender.clone());
        let state = Arc::clone(&self.state);
        tokio::spawn(async move {
            Self::process_state_updates(state, receiver).await;
        });
        StateUpdateHandle::new(sender)
    }

    /// Applies queued updates one at a time until the channel yields `None`.
    async fn process_state_updates(state: Arc<Mutex<S>>, mut receiver: StateUpdateReceiver<S>) {
        while let Some(update) = receiver.recv().await {
            let mut state_guard = state.lock().await;
            update(&mut *state_guard);
        }
    }

    /// Returns a new handle onto the current state-update queue.
    ///
    /// Returns `None` if [`Self::enable_state_updates`] has not been called.
    pub fn state_update_handle(&self) -> Option<StateUpdateHandle<S>> {
        self.state_update_sender
            .as_ref()
            .map(|sender| StateUpdateHandle::new(sender.clone()))
    }

    /// Registers `handler` for its method on `route`.
    ///
    /// Handlers already registered on the *same* route pattern for other methods
    /// are kept. A registration for the same method replaces the previous one.
    pub fn add(&mut self, route: &str, handler: RouteHandler<S>) {
        let mut methods = match self.inner.recognize(route) {
            // `recognize` treats `route` as a concrete path, so a hit can belong
            // to a *different*, more general pattern. For example, "/users/me"
            // matches a registered "/users/:id". Merge only when every captured
            // parameter is itself a `:name` / `*name` placeholder from `route`,
            // i.e. the same pattern is gaining another method.
            Ok(matched)
                if matched
                    .params()
                    .iter()
                    .all(|(_, value)| value.starts_with(':') || value.starts_with('*')) =>
            {
                (**matched.handler()).clone()
            }
            _ => HashMap::new(),
        };
        methods.insert(handler.method.into(), handler);
        self.inner.add(route, methods);
    }

    /// Returns a clone of the observe handler registered for GET on `path`.
    ///
    /// Returns `None` if the path is unknown, has no GET route, or its GET route
    /// has no observe handler.
    pub fn lookup_observer_handler(&self, path: &str) -> Option<Box<dyn ErasedHandler<S>>> {
        log::debug!("Looking up observer handler for path: '{}'", path);
        match self.inner.recognize(path) {
            Ok(matched) => {
                let handler = matched.handler();
                let reqtype: RequestTypeWrapper = RequestType::Get.into();
                log::debug!("Matched route: {:?}", matched);
                match handler.get(&reqtype) {
                    Some(h) => {
                        log::debug!(
                            "Matched handler, has observe_handler: {}",
                            h.observe_handler.is_some()
                        );
                        h.observe_handler
                            .as_ref()
                            .map(|handler| handler.clone_erased())
                    }
                    None => {
                        log::debug!("No handler found for GET method");
                        None
                    }
                }
            }
            Err(e) => {
                log::warn!(
                    "Unable to recognize observer handler path '{}'. Err: {}",
                    path,
                    e
                );
                None
            }
        }
    }

    /// Returns a clone of the handler for the request's path and method.
    ///
    /// If the route has no handler for that exact method, the handler stored
    /// under `RequestType::UnKnown` is used. That is where `RouterBuilder::any`
    /// registers, so such routes act as catch-alls.
    pub fn lookup(&self, r: &CoapumRequest<SocketAddr>) -> Option<Box<dyn ErasedHandler<S>>> {
        match self.inner.recognize(r.get_path()) {
            Ok(matched) => {
                let handler = matched.handler();
                let reqtype: RequestTypeWrapper = r.get_method().into();
                let fallback: RequestTypeWrapper = RequestType::UnKnown.into();
                log::debug!("Matched route: {:?}", matched);
                match handler.get(&reqtype).or_else(|| handler.get(&fallback)) {
                    Some(h) => {
                        log::debug!("Matched handler: {:?}", h);
                        Some(h.handler.clone_erased())
                    }
                    None => {
                        log::debug!("No handler found");
                        None
                    }
                }
            }
            Err(e) => {
                log::warn!("Unable to recognize. Err: {}", e);
                None
            }
        }
    }
}
/// Fluent builder that registers routes on a [`CoapRouter`].
///
/// [`RouterBuilder::build`] returns the finished router.
pub struct RouterBuilder<O, S>
where
    S: Clone + Debug + Send + Sync + 'static,
    O: Observer + Send + Sync + Clone + 'static,
{
    router: CoapRouter<O, S>,
}
impl<O, S> RouterBuilder<O, S>
where
S: Clone + Debug + Send + Sync + 'static,
O: Observer + Send + Sync + Clone + 'static,
{
pub fn new(state: S, observer: O) -> Self {
Self {
router: CoapRouter::new(state, observer),
}
}
fn add_route<F, T>(&mut self, path: &str, method: RequestType, handler: F)
where
HandlerFn<F, S>: Handler<T, S>,
F: Send + Sync + Clone,
T: Send + Sync + 'static,
{
let route_handler = RouteHandler {
handler: into_erased_handler(into_handler(handler)),
observe_handler: None,
method,
};
self.router.add(path, route_handler);
}
pub fn get<F, T>(mut self, path: &str, handler: F) -> Self
where
HandlerFn<F, S>: Handler<T, S>,
F: Send + Sync + Clone,
T: Send + Sync + 'static,
{
self.add_route(path, RequestType::Get, handler);
self
}
pub fn post<F, T>(mut self, path: &str, handler: F) -> Self
where
HandlerFn<F, S>: Handler<T, S>,
F: Send + Sync + Clone,
T: Send + Sync + 'static,
{
self.add_route(path, RequestType::Post, handler);
self
}
pub fn put<F, T>(mut self, path: &str, handler: F) -> Self
where
HandlerFn<F, S>: Handler<T, S>,
F: Send + Sync + Clone,
T: Send + Sync + 'static,
{
self.add_route(path, RequestType::Put, handler);
self
}
pub fn delete<F, T>(mut self, path: &str, handler: F) -> Self
where
HandlerFn<F, S>: Handler<T, S>,
F: Send + Sync + Clone,
T: Send + Sync + 'static,
{
self.add_route(path, RequestType::Delete, handler);
self
}
pub fn any<F, T>(mut self, path: &str, handler: F) -> Self
where
HandlerFn<F, S>: Handler<T, S>,
F: Send + Sync + Clone,
T: Send + Sync + 'static,
{
self.add_route(path, RequestType::UnKnown, handler);
self
}
pub fn observe<F1, T1, F2, T2>(
mut self,
path: &str,
get_handler: F1,
notify_handler: F2,
) -> Self
where
HandlerFn<F1, S>: Handler<T1, S>,
HandlerFn<F2, S>: Handler<T2, S>,
F1: Send + Sync + Clone,
F2: Send + Sync + Clone,
T1: Send + Sync + 'static,
T2: Send + Sync + 'static,
{
let route_handler = RouteHandler {
handler: into_erased_handler(into_handler(get_handler)),
observe_handler: Some(into_erased_handler(into_handler(notify_handler))),
method: RequestType::Get,
};
self.router.add(path, route_handler);
self
}
pub fn build(self) -> CoapRouter<O, S> {
self.router
}
pub fn notification_trigger(&self) -> NotificationTrigger<O> {
NotificationTrigger::new(self.router.db.clone())
}
pub fn enable_state_updates(&mut self, buffer_size: usize) -> StateUpdateHandle<S> {
self.router.enable_state_updates(buffer_size)
}
pub fn state_update_handle(&self) -> Option<StateUpdateHandle<S>> {
self.router.state_update_handle()
}
pub fn router_mut(&mut self) -> &mut CoapRouter<O, S> {
&mut self.router
}
}
/// Router-side view of an incoming CoAP request.
///
/// It holds the raw `Packet` and caches the path, method and observe option
/// read when converting from a `coap_lite::CoapRequest`.
#[derive(Debug, Clone)]
pub struct CoapumRequest<Endpoint> {
    /// Raw CoAP packet.
    pub message: Packet,
    /// Request method, cached at conversion time.
    code: RequestType,
    /// Request path, cached at conversion time.
    path: String,
    /// Parsed observe option; `None` when absent or unparsable.
    observe_flag: Option<ObserveOption>,
    /// Response slot carried over from the source request.
    pub response: Option<CoapResponse>,
    /// Peer endpoint carried over from the source request.
    pub source: Option<Endpoint>,
    /// Client identity; empty unless a caller sets it.
    pub identity: String,
}
impl<Endpoint> From<CoapRequest<Endpoint>> for CoapumRequest<Endpoint> {
    /// Caches the path, method and valid observe option of `req`.
    ///
    /// It then moves the packet, response and source across. `identity` starts
    /// out empty.
    fn from(req: CoapRequest<Endpoint>) -> Self {
        // Read everything derived from `req` before its fields are moved out.
        let path = req.get_path();
        let code = *req.get_method();
        // An observe option that fails to parse is treated like a missing one.
        let observe_flag = req.get_observe_flag().and_then(Result::ok);

        Self {
            path,
            code,
            observe_flag,
            identity: String::new(),
            message: req.message,
            response: req.response,
            source: req.source,
        }
    }
}
impl<Endpoint> CoapumRequest<Endpoint> {
    /// Path cached when the request was converted.
    pub fn get_path(&self) -> &String {
        let cached = &self.path;
        cached
    }

    /// Method cached when the request was converted.
    pub fn get_method(&self) -> &RequestType {
        let cached = &self.code;
        cached
    }

    /// Observe option cached at conversion; `None` if absent or invalid.
    pub fn get_observe_flag(&self) -> &Option<ObserveOption> {
        let cached = &self.observe_flag;
        cached
    }
}
impl<O, S> Service<CoapumRequest<SocketAddr>> for CoapRouter<O, S>
where
S: Debug + Send + Clone + Sync + 'static,
O: Observer + Send + Sync + Clone + 'static,
{
type Response = CoapResponse;
type Error = Infallible;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, request: CoapumRequest<SocketAddr>) -> Self::Future {
let state = self.state.clone();
match self.lookup(&request) {
Some(handler) => {
let path = request.get_path();
log::debug!("Handler found for route: {:?}", &path);
Box::pin(async move { handler.call_erased(request, state).await })
}
None => {
log::info!(
"No handler found for method: {:#?} to: {:?}",
request.get_method(),
request.get_path()
);
Box::pin(async move { (ResponseType::BadRequest, &request).into_response() })
}
}
}
}
impl<O, S> Service<ObserverRequest<SocketAddr>> for CoapRouter<O, S>
where
S: Debug + Send + Clone + Sync + 'static,
O: Observer + Send + Sync + Clone + 'static,
{
type Response = CoapResponse;
type Error = Infallible;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, request: ObserverRequest<SocketAddr>) -> Self::Future {
let state = self.state.clone();
log::debug!("Processing ObserverRequest for path: {}", request.path);
match self.lookup_observer_handler(&request.path) {
Some(handler) => {
log::debug!("Handler found for route: {:?}", &request.path);
let packet = Packet::default();
let mut raw = CoapRequest::from_packet(packet, request.source);
raw.set_path(&request.path);
let mut coap_request: CoapumRequest<SocketAddr> = raw.into();
coap_request.identity = String::new();
Box::pin(async move { handler.call_erased(coap_request, state).await })
}
None => {
log::debug!("No observer handler found for: {}", request.path);
Box::pin(async move { (ResponseType::BadRequest).into_response() })
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extract::{Identity, StatusCode};

    #[derive(Clone, Debug)]
    struct TestState {
        // Gives the state some content; the tests never read it.
        #[allow(dead_code)]
        counter: i32,
    }

    impl AsRef<TestState> for TestState {
        fn as_ref(&self) -> &TestState {
            self
        }
    }

    /// Builds a request for `method` on `path` by setting the cached fields
    /// directly instead of encoding them into the packet.
    fn make_request(path: &str, method: RequestType) -> CoapumRequest<SocketAddr> {
        let addr: SocketAddr = "127.0.0.1:5683".parse().unwrap();
        let raw = CoapRequest::from_packet(Packet::new(), addr);
        let mut req: CoapumRequest<SocketAddr> = raw.into();
        req.path = path.to_string();
        req.code = method;
        req
    }

    #[tokio::test]
    async fn test_register_observer() {
        let state = TestState { counter: 0 };
        let mut router = CoapRouter::new(state, ());
        let (sender, _receiver) = tokio::sync::mpsc::channel(10);
        let sender = Arc::new(sender);
        let result = router
            .register_observer("device123", "/temperature", sender)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_unregister_observer() {
        let state = TestState { counter: 0 };
        let mut router = CoapRouter::new(state, ());
        let result = router
            .unregister_observer("device123", "/temperature")
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_backend_write_and_read() {
        let state = TestState { counter: 0 };
        let mut router = CoapRouter::new(state, ());
        let payload = serde_json::json!({"value": 25});
        let write_result = router
            .backend_write("device123", "/temperature", &payload)
            .await;
        assert!(write_result.is_ok());
    }

    #[tokio::test]
    async fn test_add_and_lookup() {
        let state = TestState { counter: 0 };
        let mut router = CoapRouter::new(state, ());
        let handler = RouteHandler {
            handler: into_erased_handler(into_handler(|| async { StatusCode::Valid })),
            observe_handler: None,
            method: RequestType::Get,
        };
        router.add("/test", handler);
        let result = router.lookup(&make_request("/test", RequestType::Get));
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_add_and_lookup_observer_handler() {
        let state = TestState { counter: 0 };
        let mut router = CoapRouter::new(state, ());
        let handler = RouteHandler {
            handler: into_erased_handler(into_handler(|| async { StatusCode::Valid })),
            observe_handler: Some(into_erased_handler(into_handler(|| async {
                StatusCode::Content
            }))),
            method: RequestType::Get,
        };
        router.add("/observable", handler);
        let result = router.lookup_observer_handler("/observable");
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_router_builder() {
        async fn test_handler() -> StatusCode {
            StatusCode::Valid
        }
        let state = TestState { counter: 0 };
        let router = RouterBuilder::new(state, ())
            .get("/test", test_handler)
            .post("/test", test_handler)
            .build();
        // Registering POST after GET on the same path must keep both methods.
        assert!(router.lookup(&make_request("/test", RequestType::Get)).is_some());
        assert!(router.lookup(&make_request("/test", RequestType::Post)).is_some());
        assert!(router.lookup(&make_request("/test", RequestType::Put)).is_none());
        assert!(router.lookup(&make_request("/missing", RequestType::Get)).is_none());
    }

    #[tokio::test]
    async fn test_handler_with_extractor() {
        async fn identity_handler(Identity(_id): Identity) -> StatusCode {
            StatusCode::Valid
        }
        let state = TestState { counter: 0 };
        let router = RouterBuilder::new(state, ())
            .get("/user", identity_handler)
            .build();
        assert!(router.lookup(&make_request("/user", RequestType::Get)).is_some());
    }

    #[tokio::test]
    async fn test_observe_handler() {
        async fn get_handler() -> StatusCode {
            StatusCode::Content
        }
        async fn notify_handler() -> StatusCode {
            StatusCode::Valid
        }
        let state = TestState { counter: 0 };
        let router = RouterBuilder::new(state, ())
            .observe("/observable", get_handler, notify_handler)
            .build();
        assert!(router
            .lookup(&make_request("/observable", RequestType::Get))
            .is_some());
        assert!(router.lookup_observer_handler("/observable").is_some());
    }

    #[tokio::test]
    async fn test_builder_convenience_method() {
        async fn test_handler() -> StatusCode {
            StatusCode::Valid
        }
        let state = TestState { counter: 0 };
        let router = CoapRouter::builder(state, ())
            .get("/test", test_handler)
            .build();
        assert!(router.lookup(&make_request("/test", RequestType::Get)).is_some());
    }
}