use axum::{
middleware::{self},
routing::{get, put, post},
http::HeaderValue,
extract::{Request, State},
Router,
ServiceExt,
response::IntoResponse,
Json,
};
use std::sync::Arc;
use tracing::info;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tower_http::normalize_path::NormalizePathLayer;
use tower::Layer;
use serde_json::json;
use http::header::CONTENT_TYPE;
use anyhow;
use crate::error::AppserviceError;
use crate::config::Config;
use crate::rooms::{public_rooms, room_info, join_room, leave_room};
use crate::middleware::{
is_admin,
authenticate_homeserver,
validate_public_room,
validate_room_id,
add_data
};
use crate::ping::ping;
use crate::api::{
transactions,
matrix_proxy,
matrix_proxy_search
};
use crate::space::{
spaces,
space_summary,
};
/// HTTP front-end of the appservice.
///
/// Holds the shared application state that is handed to every handler and
/// middleware when the router is built in [`Server::run`].
pub struct Server {
    /// Shared application state (config, appservice client, transaction store, ...).
    state: Arc<AppState>,
}
pub use crate::AppState;
impl Server {
pub fn new(state: Arc<AppState>) -> Self {
Self {
state
}
}
pub fn setup_cors(&self, config: &Config) -> CorsLayer {
let mut layer = CorsLayer::new()
.allow_origin(Any)
.allow_headers(vec![CONTENT_TYPE]);
layer = match &config.server.allow_origin {
Some(origins) if !origins.is_empty() &&
!origins.contains(&"".to_string()) &&
!origins.contains(&"*".to_string()) => {
let origins = origins.iter().filter_map(|s| s.parse::<HeaderValue>().ok()).collect::<Vec<_>>();
layer.allow_origin(origins)
},
_ => layer,
};
layer
}
pub async fn run(&self) -> Result<(), anyhow::Error> {
let ping_state = self.state.clone();
let addr = format!("0.0.0.0:{}", &self.state.config.server.port);
let service_routes = Router::new()
.route("/_matrix/app/v1/ping", post(ping))
.route("/_matrix/app/v1/transactions/{txn_id}", put(transactions))
.route_layer(middleware::from_fn_with_state(self.state.clone(), authenticate_homeserver));
let room_routes_inner = Router::new()
.route("/state", get(matrix_proxy))
.route("/state/{*path}", get(matrix_proxy))
.route("/events", get(matrix_proxy))
.route("/messages", get(matrix_proxy))
.route("/info", get(room_info))
.route("/joined_members", get(matrix_proxy))
.route("/members", get(matrix_proxy))
.route("/initialSync", get(matrix_proxy))
.route("/aliases", get(matrix_proxy))
.route("/event/{*path}", get(matrix_proxy))
.route("/context/{*path}", get(matrix_proxy))
.route("/timestamp_to_event", get(matrix_proxy));
let room_routes = Router::new()
.nest("/_matrix/client/v3/rooms/{room_id}", room_routes_inner)
.route_layer(middleware::from_fn_with_state(self.state.clone(), validate_public_room))
.route_layer(middleware::from_fn_with_state(self.state.clone(), validate_room_id));
let more_room_routes = Router::new()
.route("/_matrix/client/v1/rooms/{room_id}/hierarchy", get(matrix_proxy))
.route("/_matrix/client/v1/rooms/{room_id}/threads", get(matrix_proxy))
.route("/_matrix/client/v1/rooms/{room_id}/relations/{*path}", get(matrix_proxy))
.route_layer(middleware::from_fn_with_state(self.state.clone(), validate_public_room))
.route_layer(middleware::from_fn_with_state(self.state.clone(), validate_room_id));
let public_rooms_route = Router::new()
.route("/publicRooms", get(public_rooms));
let media_routes = Router::new()
.route("/_matrix/client/v1/media/preview_url", get(matrix_proxy))
.route("/_matrix/client/v1/media/thumbnail/{*path}", get(matrix_proxy))
.route("/_matrix/client/v1/media/download/{*path}", get(matrix_proxy));
let admin_routes = Router::new()
.route("/admin/room/{room_id}/join", put(join_room))
.route("/admin/room/{room_id}/leave", put(leave_room))
.route_layer(middleware::from_fn_with_state(self.state.clone(), is_admin));
let spaces_routes = Router::new()
.route("/spaces/{space}", get(space_summary))
.route("/spaces", get(spaces));
let search_route = Router::new()
.route("/_matrix/client/v3/search", post(matrix_proxy_search));
let app = Router::new()
.merge(service_routes)
.merge(room_routes)
.merge(more_room_routes)
.merge(media_routes)
.merge(public_rooms_route)
.merge(admin_routes)
.merge(spaces_routes)
.merge(search_route)
.route("/version", get(version))
.route("/identity", get(identity))
.route("/health", get(health))
.route("/", get(index))
.layer(self.setup_cors(&self.state.config))
.layer(middleware::from_fn_with_state(self.state.clone(), add_data))
.layer(TraceLayer::new_for_http())
.with_state(self.state.clone());
let app = NormalizePathLayer::trim_trailing_slash()
.layer(app);
tokio::spawn(async move {
info!("Pinging homeserver...");
let txn_id = ping_state.transaction_store.generate_transaction_id().await;
let ping = ping_state.appservice.ping_homeserver(txn_id.clone()).await;
match ping {
Ok(_) => info!("Homeserver pinged successfully."),
Err(e) => tracing::info!("Failed to ping homeserver: {}", e),
}
});
if let Ok(listener) = tokio::net::TcpListener::bind(addr.clone()).await {
axum::serve(listener, ServiceExt::<Request>::into_make_service(app)).await?;
} else {
tracing::info!("Failed to bind to address: {}", addr);
std::process::exit(1);
}
Ok(())
}
}
/// Root handler: returns a short plain-text banner identifying the service.
async fn index() -> &'static str {
    const BANNER: &str = "Commune public appservice.\n";
    BANNER
}
/// Reports the crate version and the git commit this binary was built from.
///
/// Both values are embedded at compile time via `env!`. `GIT_COMMIT_HASH` must
/// be present in the build environment (presumably set by a build script) or
/// compilation fails.
pub async fn version() -> Result<impl IntoResponse, ()> {
    let body = json!({
        "version": env!("CARGO_PKG_VERSION"),
        "commit": env!("GIT_COMMIT_HASH"),
    });
    Ok(Json(body))
}
/// Returns the Matrix user ID the appservice acts as, in the form
/// `@<sender_localpart>:<server_name>`, built from the configuration.
pub async fn identity(
    State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, ()> {
    let localpart = &state.config.appservice.sender_localpart;
    let server_name = &state.config.matrix.server_name;
    Ok(Json(json!({
        "user": format!("@{localpart}:{server_name}"),
    })))
}
/// Health endpoint: runs the appservice health check and reports `{"status": "ok"}`.
///
/// # Errors
/// Returns `AppserviceError::AppserviceError("Health check failed")` if the
/// check fails. The underlying error is not propagated.
pub async fn health(
    State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, AppserviceError> {
    if state.appservice.health_check().await.is_err() {
        return Err(AppserviceError::AppserviceError(
            "Health check failed".to_string(),
        ));
    }
    Ok(Json(json!({
        "status": "ok",
    })))
}