use anyhow::Result;
use caryatid_sdk::{module, Context, MessageBounds};
use config::Config;
use std::{collections::HashMap, sync::Arc};
use tracing::{error, info};
use axum::{
body::Body,
http::{Request, StatusCode},
response::Response,
Router,
};
use hyper::body;
use tower_http::cors::{Any, CorsLayer};
use std::convert::Infallible;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
pub mod messages;
use messages::{GetRESTResponse, RESTRequest, RESTResponse};
/// Address the server binds to when the `address` config key cannot be read.
const DEFAULT_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
/// Port the server listens on when the `port` config key cannot be read.
const DEFAULT_PORT: u16 = 4340;
/// Upper bound, in bytes, on how much of a response body is written to the log.
const MAX_LOG: usize = 40;
/// REST server module, registered under the name `rest-server`.
///
/// Generic over the bus message type `M`, which must be constructible from a
/// [`RESTRequest`] (so incoming HTTP requests can be published) and must be
/// able to yield a [`RESTResponse`] via [`GetRESTResponse`] (so replies can be
/// turned back into HTTP responses).
#[module(message_type(M), name = "rest-server", description = "REST server")]
pub struct RESTServer<M: From<RESTRequest> + GetRESTResponse + MessageBounds>;
impl<M: From<RESTRequest> + GetRESTResponse + MessageBounds> RESTServer<M> {
async fn init(&self, context: Arc<Context<M>>, config: Arc<Config>) -> Result<()> {
let message_bus = context.message_bus.clone();
let topic_prefix = config.get_string("topic").unwrap_or("rest".to_string());
let handle_request = |req: Request<Body>| async move {
info!(
"Received REST request {} {}{}",
req.method().as_str(),
req.uri().path(),
req.uri()
.query()
.map(|q| format!("?{}", q))
.unwrap_or_default(),
);
let method = req.method().as_str().to_string();
let path = req.uri().path().to_string();
let query_parameters: HashMap<String, String> = req
.uri()
.query()
.unwrap_or("")
.split('&')
.filter(|s| !s.is_empty())
.filter_map(|pair| {
let mut parts = pair.splitn(2, '=');
let key = parts.next()?;
let value = parts.next().unwrap_or("");
Some((key.to_string(), value.to_string()))
})
.collect();
let bytes = match body::to_bytes(req.into_body()).await {
Ok(b) => b,
Err(e) => {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(e.to_string())
.unwrap())
}
};
let body = match String::from_utf8(bytes.to_vec()) {
Ok(b) => b,
Err(e) => {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(e.to_string())
.unwrap())
}
};
let method_lower = method.to_lowercase();
let dot_path = path.strip_prefix("/").unwrap_or(&path);
let dot_path = dot_path.strip_suffix("/").unwrap_or(dot_path);
let dot_path = dot_path.replace('/', ".");
let topic = format!("{topic_prefix}.{method_lower}.{dot_path}");
info!("Sending to topic {}", topic);
let path_elements = dot_path.split('.').map(String::from).collect();
let message = RESTRequest {
method,
path,
body,
path_elements,
query_parameters,
};
let response = match message_bus.request(&topic, Arc::new(message.into())).await {
Ok(response) => match response.get_rest_response() {
Some(RESTResponse {
code,
body,
content_type,
}) => {
info!(
"Got response: {code} {}{}",
&body[..std::cmp::min(body.len(), MAX_LOG)],
if body.len() > MAX_LOG { "..." } else { "" }
);
Response::builder()
.status(
StatusCode::from_u16(code)
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
)
.header("Content-Type", content_type)
.body(body)
.unwrap()
}
_ => {
error!("Response isn't RESTResponse");
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body("".to_string())
.unwrap()
}
},
Err(_) => {
error!("No handler for {topic}");
Response::builder()
.status(StatusCode::NOT_FOUND)
.body("".to_string())
.unwrap()
}
};
Ok::<_, Infallible>(response)
};
context.run(async move {
let ip = config.get::<IpAddr>("address").unwrap_or(DEFAULT_IP);
let port: u16 = config.get::<u16>("port").unwrap_or(DEFAULT_PORT);
let addr = SocketAddr::from((ip, port));
info!("REST server listening on http://{}", addr);
let app = Router::new().fallback(handle_request).layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
});
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use caryatid_sdk::mock_bus::MockBus;
use caryatid_sdk::Module;
use config::{Config, FileFormat};
use futures::future;
use hyper::Client;
use std::net::TcpListener;
use tokio::sync::{watch::Sender, Notify};
use tokio::time::{timeout, Duration};
use tracing::{debug, Level};
/// Bus message type used by these tests.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum Message {
    /// Placeholder carrying no data; this is what `Message::default()` returns.
    None(()),
    /// A REST request published by the server.
    RESTRequest(RESTRequest),
    /// A REST response sent back by a handler.
    RESTResponse(RESTResponse),
}
impl Default for Message {
fn default() -> Self {
Message::None(())
}
}
impl From<RESTRequest> for Message {
fn from(msg: RESTRequest) -> Self {
Message::RESTRequest(msg)
}
}
impl From<RESTResponse> for Message {
fn from(msg: RESTResponse) -> Self {
Message::RESTResponse(msg)
}
}
impl GetRESTResponse for Message {
    /// Returns a copy of the wrapped response, or `None` for any other variant.
    fn get_rest_response(&self) -> Option<RESTResponse> {
        match self {
            Message::RESTResponse(response) => Some(response.clone()),
            _ => None,
        }
    }
}
/// Everything a test needs to drive an initialised `RESTServer`.
struct TestSetup {
    /// The server under test, held as a trait object.
    module: Arc<dyn Module<Message>>,
    /// Module context the server was initialised with; tests register handlers on it.
    context: Arc<Context<Message>>,
    /// Sending side of the startup watch whose receiver was handed to the context.
    startup_watch: Sender<bool>,
}
impl TestSetup {
    /// Creates a `RESTServer` over a mock bus, configured from the TOML text
    /// in `config_str`, and runs its `init`.
    ///
    /// Panics if the configuration cannot be built or `init` reports an error.
    async fn new(config_str: &str) -> Self {
        // The result is discarded, so a failed subscriber installation is ignored.
        let _ = tracing_subscriber::fmt()
            .with_max_level(Level::DEBUG)
            .with_test_writer()
            .try_init();

        let toml_source = config::File::from_str(config_str, FileFormat::Toml);
        let config = Arc::new(Config::builder().add_source(toml_source).build().unwrap());

        let mock_bus = Arc::new(MockBus::<Message>::new(&config));
        let startup_watch = Sender::new(false);
        let startup_receiver = startup_watch.subscribe();
        let context = Arc::new(Context::new(config.clone(), mock_bus, startup_receiver));

        let rest_server = RESTServer::<Message> {
            _marker: std::marker::PhantomData,
        };
        assert!(rest_server.init(context.clone(), config).await.is_ok());

        let module = Arc::new(rest_server);
        Self {
            module,
            context,
            startup_watch,
        }
    }

    /// Sends `true` on the startup watch; a send failure is ignored.
    fn start(&self) {
        let _ = self.startup_watch.send(true);
    }
}
/// A freshly constructed server reports the name and description given to the module macro.
#[tokio::test]
async fn construct_a_rest_server() {
    let setup = TestSetup::new("").await;
    let module = &setup.module;

    assert_eq!(module.get_name(), "rest-server");
    assert_eq!(module.get_description(), "REST server");
}
/// A `GET /test` reaches the handler registered on `rest.get.test`, and the
/// handler's 200 response is returned to the HTTP client.
#[tokio::test]
async fn rest_server_generates_request_and_returns_response() {
    // Bind to port 0 to have a free port assigned. The probe listener is a
    // temporary, so it is closed again before the server is configured.
    let port = TcpListener::bind("127.0.0.1:0")
        .expect("Failed to bind to address")
        .local_addr()
        .unwrap()
        .port();
    assert!(port > 0);

    let setup = TestSetup::new(&format!("port = {port}")).await;

    // Signalled by the handler so the test can check that it actually ran.
    let notify = Arc::new(Notify::new());
    let handler_notify = notify.clone();
    setup
        .context
        .handle("rest.get.test", move |message: Arc<Message>| {
            let response = if let Message::RESTRequest(request) = message.as_ref() {
                info!(
                    "REST hello world received {} {}",
                    request.method, request.path
                );
                RESTResponse::with_text(200, "Hello, world!")
            } else {
                error!("Unexpected message type {:?}", message);
                RESTResponse::with_text(500, "Unexpected message in REST request")
            };
            handler_notify.notify_one();
            future::ready(Arc::new(Message::RESTResponse(response)))
        });

    setup.start();
    // Pause briefly before sending the request, after signalling startup.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let client = Client::new();
    let uri = format!("http://127.0.0.1:{}/test", port).parse().unwrap();
    let response = match timeout(Duration::from_secs(1), client.get(uri)).await {
        Err(e) => panic!("HTTP request timed out: {e}"),
        Ok(Err(e)) => panic!("HTTP request failed: {e}"),
        Ok(Ok(response)) => response,
    };
    debug!("HTTP response: {:?}", response);
    assert_eq!(response.status(), 200);

    let handler_ran = timeout(Duration::from_secs(1), notify.notified()).await;
    assert!(handler_ran.is_ok(), "Didn't receive a rest.get.test message");
}
/// With no handler registered for the request topic, the client receives a 404.
#[tokio::test]
async fn rest_server_with_no_handler_generates_404() {
    // Bind to port 0 to have a free port assigned. The probe listener is a
    // temporary, so it is closed again before the server is configured.
    let port = TcpListener::bind("127.0.0.1:0")
        .expect("Failed to bind to address")
        .local_addr()
        .unwrap()
        .port();
    assert!(port > 0);

    // NOTE(review): `request-timeout` is presumably read by the mock bus to
    // bound how long an unanswered request waits - confirm.
    let setup = TestSetup::new(&format!("port = {port}\nrequest-timeout = 1")).await;
    setup.start();
    // Pause briefly before sending the request, after signalling startup.
    tokio::time::sleep(Duration::from_millis(100)).await;

    let client = Client::new();
    let uri = format!("http://127.0.0.1:{}/test", port).parse().unwrap();
    let response = match timeout(Duration::from_secs(2), client.get(uri)).await {
        Err(e) => panic!("HTTP request timed out: {e}"),
        Ok(Err(e)) => panic!("HTTP request failed: {e}"),
        Ok(Ok(response)) => response,
    };
    debug!("HTTP response: {:?}", response);
    assert_eq!(response.status(), 404);
}
}