public_appservice/
server.rs

1use axum::{
2    Json, Router, ServiceExt,
3    extract::{Request, State},
4    http::HeaderValue,
5    middleware::{self},
6    response::IntoResponse,
7    routing::{get, post, put},
8};
9
10use std::sync::Arc;
11use tracing::info;
12
13use tower::Layer;
14use tower_http::cors::{Any, CorsLayer};
15use tower_http::normalize_path::NormalizePathLayer;
16use tower_http::trace::TraceLayer;
17
18use serde_json::json;
19
20use http::header::CONTENT_TYPE;
21
22use crate::error::AppserviceError;
23use anyhow;
24
25use crate::config::Config;
26use crate::middleware::{
27    add_data, authenticate_homeserver, is_admin, validate_public_room, validate_room_id,
28};
29use crate::rooms::{join_room, leave_room, public_rooms, room_info};
30
31use crate::ping::ping;
32
33use crate::api::transactions;
34use crate::requests::{matrix_proxy, matrix_proxy_search};
35
36use crate::space::{space, space_rooms, spaces};
37
38pub struct Server {
39    state: Arc<AppState>,
40}
41
42pub use crate::AppState;
43
44impl Server {
45    pub fn new(state: Arc<AppState>) -> Self {
46        Self { state }
47    }
48
49    pub fn setup_cors(&self, config: &Config) -> CorsLayer {
50        let mut layer = CorsLayer::new()
51            .allow_origin(Any)
52            .allow_headers(vec![CONTENT_TYPE]);
53
54        layer = match &config.server.allow_origin {
55            Some(origins)
56                if !origins.is_empty()
57                    && !origins.contains(&"".to_string())
58                    && !origins.contains(&"*".to_string()) =>
59            {
60                let origins = origins
61                    .iter()
62                    .filter_map(|s| s.parse::<HeaderValue>().ok())
63                    .collect::<Vec<_>>();
64                layer.allow_origin(origins)
65            }
66            _ => layer,
67        };
68
69        layer
70    }
71
72    pub async fn run(&self) -> Result<(), anyhow::Error> {
73        let ping_state = self.state.clone();
74
75        let addr = format!("0.0.0.0:{}", &self.state.config.server.port);
76
77        let service_routes = Router::new()
78            .route("/_matrix/app/v1/ping", post(ping))
79            .route("/_matrix/app/v1/transactions/{txn_id}", put(transactions))
80            .route_layer(middleware::from_fn_with_state(
81                self.state.clone(),
82                authenticate_homeserver,
83            ));
84
85        let room_routes_inner = Router::new()
86            .route("/state", get(matrix_proxy))
87            .route("/state/{*path}", get(matrix_proxy))
88            .route("/events", get(matrix_proxy))
89            .route("/messages", get(matrix_proxy))
90            .route("/info", get(room_info))
91            .route("/joined_members", get(matrix_proxy))
92            .route("/members", get(matrix_proxy))
93            .route("/initialSync", get(matrix_proxy))
94            .route("/aliases", get(matrix_proxy))
95            .route("/event/{*path}", get(matrix_proxy))
96            .route("/context/{*path}", get(matrix_proxy))
97            .route("/timestamp_to_event", get(matrix_proxy));
98
99        let room_routes = Router::new()
100            .nest("/_matrix/client/v3/rooms/{room_id}", room_routes_inner)
101            .route_layer(middleware::from_fn_with_state(
102                self.state.clone(),
103                validate_public_room,
104            ))
105            .route_layer(middleware::from_fn_with_state(
106                self.state.clone(),
107                validate_room_id,
108            ));
109
110        let more_room_routes = Router::new()
111            .route(
112                "/_matrix/client/v1/rooms/{room_id}/hierarchy",
113                get(matrix_proxy),
114            )
115            .route(
116                "/_matrix/client/v1/rooms/{room_id}/threads",
117                get(matrix_proxy),
118            )
119            .route(
120                "/_matrix/client/v1/rooms/{room_id}/relations/{*path}",
121                get(matrix_proxy),
122            )
123            .route_layer(middleware::from_fn_with_state(
124                self.state.clone(),
125                validate_public_room,
126            ))
127            .route_layer(middleware::from_fn_with_state(
128                self.state.clone(),
129                validate_room_id,
130            ));
131
132        let public_rooms_route = Router::new().route("/publicRooms", get(public_rooms));
133
134        let media_routes = Router::new()
135            .route("/_matrix/client/v1/media/preview_url", get(matrix_proxy))
136            .route(
137                "/_matrix/client/v1/media/thumbnail/{*path}",
138                get(matrix_proxy),
139            )
140            .route(
141                "/_matrix/client/v1/media/download/{*path}",
142                get(matrix_proxy),
143            );
144
145        let admin_routes = Router::new()
146            .route("/admin/room/{room_id}/join", put(join_room))
147            .route("/admin/room/{room_id}/leave", put(leave_room))
148            .route_layer(middleware::from_fn_with_state(self.state.clone(), is_admin));
149
150        let spaces_routes = Router::new()
151            .route("/spaces/{space}/rooms", get(space_rooms))
152            .route("/spaces/{space}", get(space))
153            .route("/spaces", get(spaces));
154
155        let search_route =
156            Router::new().route("/_matrix/client/v3/search", post(matrix_proxy_search));
157
158        let app = Router::new()
159            .merge(service_routes)
160            .merge(room_routes)
161            .merge(more_room_routes)
162            .merge(media_routes)
163            .merge(public_rooms_route)
164            .merge(admin_routes)
165            .merge(spaces_routes);
166
167        let app = if !self.state.config.search.disabled {
168            app.merge(search_route)
169        } else {
170            app
171        };
172
173        let app = app
174            .route("/version", get(version))
175            .route("/identity", get(identity))
176            .route("/health", get(health))
177            .route("/", get(index))
178            .layer(self.setup_cors(&self.state.config))
179            .layer(middleware::from_fn_with_state(self.state.clone(), add_data))
180            .layer(TraceLayer::new_for_http())
181            .with_state(self.state.clone());
182
183        let app = NormalizePathLayer::trim_trailing_slash().layer(app);
184
185        tokio::spawn(async move {
186            info!("Pinging homeserver...");
187            let txn_id = ping_state.transaction_store.generate_transaction_id().await;
188            let ping = ping_state.appservice.ping_homeserver(txn_id.clone()).await;
189            match ping {
190                Ok(_) => info!("Homeserver pinged successfully."),
191                Err(e) => tracing::info!("Failed to ping homeserver: {}", e),
192            }
193        });
194
195        if let Ok(listener) = tokio::net::TcpListener::bind(addr.clone()).await {
196            axum::serve(listener, ServiceExt::<Request>::into_make_service(app)).await?;
197        } else {
198            tracing::info!("Failed to bind to address: {}", addr);
199            return Err(anyhow::anyhow!("Failed to bind to address: {}", addr));
200        }
201
202        Ok(())
203    }
204}
205
/// Plain-text banner served at the root path `/`.
async fn index() -> &'static str {
    const BANNER: &str = "Commune public appservice.\n";
    BANNER
}
209
210pub async fn version() -> Result<impl IntoResponse, ()> {
211    let version = env!("CARGO_PKG_VERSION");
212    let hash = env!("GIT_COMMIT_HASH");
213
214    Ok(Json(json!({
215        "version": version,
216        "commit": hash,
217    })))
218}
219
220pub async fn identity(State(state): State<Arc<AppState>>) -> Result<impl IntoResponse, ()> {
221    let user = format!(
222        "@{}:{}",
223        state.config.appservice.sender_localpart, state.config.matrix.server_name
224    );
225
226    Ok(Json(json!({
227        "user": user,
228    })))
229}
230
231pub async fn health(
232    State(state): State<Arc<AppState>>,
233) -> Result<impl IntoResponse, AppserviceError> {
234    state.appservice.health_check().await.map_err(|e| {
235        tracing::error!("Health check failed: {}", e);
236        AppserviceError::HomeserverError(
237            "Health check failed. Could not reach homeserver.".to_string(),
238        )
239    })?;
240
241    let user = format!(
242        "@{}:{}",
243        state.config.appservice.sender_localpart, state.config.matrix.server_name
244    );
245
246    let search_disabled = state.config.search.disabled;
247
248    let features = json!({
249        "search_disabled": search_disabled,
250    });
251
252    Ok(Json(json!({
253        "status": "ok",
254        "user_id": user,
255        "features": features,
256    })))
257}