//! public-appservice 0.2.2
//!
//! An appservice to make Matrix spaces publicly accessible.
use axum::{
    Json, Router, ServiceExt,
    extract::{Request, State},
    http::HeaderValue,
    middleware::{self},
    response::IntoResponse,
    routing::{get, post, put},
};

use std::sync::Arc;
use tracing::info;

use tower::Layer;
use tower_http::cors::{Any, CorsLayer};
use tower_http::normalize_path::NormalizePathLayer;
use tower_http::trace::TraceLayer;

use serde_json::json;

use http::header::CONTENT_TYPE;

use crate::error::AppserviceError;
use anyhow;

use crate::config::Config;
use crate::middleware::{
    add_data, authenticate_homeserver, is_admin, validate_public_room, validate_room_id,
};
use crate::rooms::{join_room, leave_room, public_rooms, room_info};

use crate::ping::ping;

use crate::api::transactions;
use crate::requests::{matrix_proxy, matrix_proxy_search};

use crate::space::{space, space_rooms, spaces};

pub struct Server {
    state: Arc<AppState>,
}

pub use crate::AppState;

impl Server {
    pub fn new(state: Arc<AppState>) -> Self {
        Self { state }
    }

    pub fn setup_cors(&self, config: &Config) -> CorsLayer {
        let mut layer = CorsLayer::new()
            .allow_origin(Any)
            .allow_headers(vec![CONTENT_TYPE]);

        layer = match &config.server.allow_origin {
            Some(origins)
                if !origins.is_empty()
                    && !origins.contains(&"".to_string())
                    && !origins.contains(&"*".to_string()) =>
            {
                let origins = origins
                    .iter()
                    .filter_map(|s| s.parse::<HeaderValue>().ok())
                    .collect::<Vec<_>>();
                layer.allow_origin(origins)
            }
            _ => layer,
        };

        layer
    }

    pub async fn run(&self) -> Result<(), anyhow::Error> {
        let ping_state = self.state.clone();

        let addr = format!("0.0.0.0:{}", &self.state.config.server.port);

        let service_routes = Router::new()
            .route("/_matrix/app/v1/ping", post(ping))
            .route("/_matrix/app/v1/transactions/{txn_id}", put(transactions))
            .route_layer(middleware::from_fn_with_state(
                self.state.clone(),
                authenticate_homeserver,
            ));

        let room_routes_inner = Router::new()
            .route("/state", get(matrix_proxy))
            .route("/state/{*path}", get(matrix_proxy))
            .route("/events", get(matrix_proxy))
            .route("/messages", get(matrix_proxy))
            .route("/info", get(room_info))
            .route("/joined_members", get(matrix_proxy))
            .route("/members", get(matrix_proxy))
            .route("/initialSync", get(matrix_proxy))
            .route("/aliases", get(matrix_proxy))
            .route("/event/{*path}", get(matrix_proxy))
            .route("/context/{*path}", get(matrix_proxy))
            .route("/timestamp_to_event", get(matrix_proxy));

        let room_routes = Router::new()
            .nest("/_matrix/client/v3/rooms/{room_id}", room_routes_inner)
            .route_layer(middleware::from_fn_with_state(
                self.state.clone(),
                validate_public_room,
            ))
            .route_layer(middleware::from_fn_with_state(
                self.state.clone(),
                validate_room_id,
            ));

        let more_room_routes = Router::new()
            .route(
                "/_matrix/client/v1/rooms/{room_id}/hierarchy",
                get(matrix_proxy),
            )
            .route(
                "/_matrix/client/v1/rooms/{room_id}/threads",
                get(matrix_proxy),
            )
            .route(
                "/_matrix/client/v1/rooms/{room_id}/relations/{*path}",
                get(matrix_proxy),
            )
            .route_layer(middleware::from_fn_with_state(
                self.state.clone(),
                validate_public_room,
            ))
            .route_layer(middleware::from_fn_with_state(
                self.state.clone(),
                validate_room_id,
            ));

        let public_rooms_route = Router::new().route("/publicRooms", get(public_rooms));

        let media_routes = Router::new()
            .route("/_matrix/client/v1/media/preview_url", get(matrix_proxy))
            .route(
                "/_matrix/client/v1/media/thumbnail/{*path}",
                get(matrix_proxy),
            )
            .route(
                "/_matrix/client/v1/media/download/{*path}",
                get(matrix_proxy),
            );

        let admin_routes = Router::new()
            .route("/admin/room/{room_id}/join", put(join_room))
            .route("/admin/room/{room_id}/leave", put(leave_room))
            .route_layer(middleware::from_fn_with_state(self.state.clone(), is_admin));

        let spaces_routes = Router::new()
            .route("/spaces/{space}/rooms", get(space_rooms))
            .route("/spaces/{space}", get(space))
            .route("/spaces", get(spaces));

        let search_route =
            Router::new().route("/_matrix/client/v3/search", post(matrix_proxy_search));

        let app = Router::new()
            .merge(service_routes)
            .merge(room_routes)
            .merge(more_room_routes)
            .merge(media_routes)
            .merge(public_rooms_route)
            .merge(admin_routes)
            .merge(spaces_routes);

        let app = if !self.state.config.search.disabled {
            app.merge(search_route)
        } else {
            app
        };

        let app = app
            .route("/version", get(version))
            .route("/identity", get(identity))
            .route("/health", get(health))
            .route("/", get(index))
            .layer(self.setup_cors(&self.state.config))
            .layer(middleware::from_fn_with_state(self.state.clone(), add_data))
            .layer(TraceLayer::new_for_http())
            .with_state(self.state.clone());

        let app = NormalizePathLayer::trim_trailing_slash().layer(app);

        tokio::spawn(async move {
            info!("Pinging homeserver...");
            let txn_id = ping_state.transaction_store.generate_transaction_id().await;
            let ping = ping_state.appservice.ping_homeserver(txn_id.clone()).await;
            match ping {
                Ok(_) => info!("Homeserver pinged successfully."),
                Err(e) => tracing::info!("Failed to ping homeserver: {}", e),
            }
        });

        if let Ok(listener) = tokio::net::TcpListener::bind(addr.clone()).await {
            axum::serve(listener, ServiceExt::<Request>::into_make_service(app)).await?;
        } else {
            tracing::info!("Failed to bind to address: {}", addr);
            return Err(anyhow::anyhow!("Failed to bind to address: {}", addr));
        }

        Ok(())
    }
}

/// Plain-text banner returned by the root endpoint.
const INDEX_BANNER: &str = "Commune public appservice.\n";

/// Handler for `GET /`: a short plain-text line identifying the service.
async fn index() -> &'static str {
    INDEX_BANNER
}

/// Handler for `GET /version`: reports the crate version and the git commit
/// hash baked in at compile time.
///
/// `GIT_COMMIT_HASH` must be set at build time (e.g. by a build script) or
/// compilation fails. This handler never returns `Err`.
pub async fn version() -> Result<impl IntoResponse, ()> {
    let body = json!({
        "version": env!("CARGO_PKG_VERSION"),
        "commit": env!("GIT_COMMIT_HASH"),
    });

    Ok(Json(body))
}

pub async fn identity(State(state): State<Arc<AppState>>) -> Result<impl IntoResponse, ()> {
    let user = format!(
        "@{}:{}",
        state.config.appservice.sender_localpart, state.config.matrix.server_name
    );

    Ok(Json(json!({
        "user": user,
    })))
}

pub async fn health(
    State(state): State<Arc<AppState>>,
) -> Result<impl IntoResponse, AppserviceError> {
    state.appservice.health_check().await.map_err(|e| {
        tracing::error!("Health check failed: {}", e);
        AppserviceError::HomeserverError(
            "Health check failed. Could not reach homeserver.".to_string(),
        )
    })?;

    let user = format!(
        "@{}:{}",
        state.config.appservice.sender_localpart, state.config.matrix.server_name
    );

    let search_disabled = state.config.search.disabled;

    let features = json!({
        "search_disabled": search_disabled,
    });

    Ok(Json(json!({
        "status": "ok",
        "user_id": user,
        "features": features,
    })))
}