public_appservice/
middleware.rs

1use axum::{
2    Extension, Json,
3    body::Body,
4    extract::{MatchedPath, OriginalUri, Path, State},
5    http::{Request, StatusCode, Uri, header::AUTHORIZATION},
6    middleware::Next,
7    response::IntoResponse,
8};
9
10use ruma::{OwnedRoomId, RoomAliasId, RoomId};
11
12use serde_json::{Value, json};
13
14use std::sync::Arc;
15
16use crate::AppState;
17use crate::utils::{room_alias_like, room_id_valid};
18
19use crate::error::AppserviceError;
20
21use crate::cache::CacheKey;
22
/// Extracts the credential from an `Authorization` header value.
///
/// Returns `None` unless `header` starts with the exact, case-sensitive
/// prefix `"Bearer "`. Otherwise returns the remainder with surrounding
/// whitespace trimmed; the remainder may be empty.
pub fn extract_token(header: &str) -> Option<&str> {
    let credential = header.strip_prefix("Bearer ")?;
    Some(credential.trim())
}
26
27fn unauthorized_error() -> (StatusCode, Json<Value>) {
28    (
29        StatusCode::UNAUTHORIZED,
30        Json(json!({
31            "errcode": "M_FORBIDDEN",
32        })),
33    )
34}
35
36pub async fn authenticate_homeserver(
37    State(state): State<Arc<AppState>>,
38    req: Request<Body>,
39    next: Next,
40) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
41    let token = req
42        .headers()
43        .get(AUTHORIZATION)
44        .ok_or(unauthorized_error())?
45        .to_str()
46        .map_err(|_| unauthorized_error())?;
47
48    let token = extract_token(token).ok_or(unauthorized_error())?;
49
50    if token != state.config.appservice.hs_access_token {
51        return Err(unauthorized_error());
52    }
53
54    Ok(next.run(req).await)
55}
56
57pub async fn is_admin(
58    //State(state): State<Arc<AppState>>,
59    req: Request<Body>,
60    next: Next,
61) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
62    let token = req
63        .headers()
64        .get(AUTHORIZATION)
65        .ok_or(unauthorized_error())?
66        .to_str()
67        .map_err(|_| unauthorized_error())?;
68
69    let token = extract_token(token).ok_or(unauthorized_error())?;
70
71    if token != "test" {
72        return Err(unauthorized_error());
73    }
74
75    Ok(next.run(req).await)
76}
77
/// Coarse classification of a proxied request, derived from its URI path by
/// `parse_request_type`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ProxyRequestType {
    /// The path ends with `/state`.
    RoomState,
    /// The path ends with `/messages`.
    Messages,
    /// The path starts with `/_matrix/client/v1/media/`.
    Media,
    /// Any path that matches none of the rules above.
    Other,
}
85
86#[derive(Clone, Debug)]
87pub struct Data {
88    pub modified_path: Option<String>,
89    pub room_id: Option<String>,
90    pub is_media_request: bool,
91    pub proxy_request_type: ProxyRequestType,
92}
93
94pub fn parse_request_type(req: &Request<Body>) -> ProxyRequestType {
95    match req.uri().path() {
96        path if path.ends_with("/state") => ProxyRequestType::RoomState,
97        path if path.ends_with("/messages") => ProxyRequestType::Messages,
98        path if path.starts_with("/_matrix/client/v1/media/") => ProxyRequestType::Media,
99        _ => ProxyRequestType::Other,
100    }
101}
102
103pub async fn add_data(
104    mut req: Request<Body>,
105    next: Next,
106) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
107    let data = Data {
108        modified_path: None,
109        room_id: None,
110        is_media_request: req.uri().path().starts_with("/_matrix/client/v1/media/"),
111        proxy_request_type: parse_request_type(&req),
112    };
113
114    req.extensions_mut().insert(data);
115
116    Ok(next.run(req).await)
117}
118
119pub async fn validate_room_id(
120    Path(params): Path<Vec<(String, String)>>,
121    State(state): State<Arc<AppState>>,
122    mut req: Request<Body>,
123    next: Next,
124) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
125    let room_id = params[0].1.clone();
126
127    let server_name = state.config.matrix.server_name.clone();
128
129    let mut data = Data {
130        modified_path: None,
131        room_id: Some(room_id.clone()),
132        is_media_request: req.uri().path().starts_with("/_matrix/media/v1/download/"),
133        proxy_request_type: parse_request_type(&req),
134    };
135
136    // This is a valid room_id, so move on
137    if room_id_valid(&room_id, &server_name).is_ok() {
138        req.extensions_mut().insert(data);
139        return Ok(next.run(req).await);
140    }
141
142    // If the alias is partial like room:server.com without the leading #, we assume it's a room alias
143    let raw_alias = room_alias_like(&room_id)
144        .then_some(format!("#{room_id}"))
145        .unwrap_or_else(|| format!("#{room_id}:{server_name}"));
146
147    if let Ok(alias) = RoomAliasId::parse(&raw_alias) {
148        let id = state.appservice.room_id_from_alias(alias).await;
149        match id {
150            Ok(id) => {
151                data.room_id = Some(id.to_string());
152            }
153            Err(_) => {
154                tracing::info!("Failed to get room ID from alias: {}", raw_alias);
155            }
156        }
157    }
158
159    if let Some(path) = req.extensions().get::<MatchedPath>() {
160        let pattern = path.as_str();
161
162        // Split into segments, skipping the empty first segment
163        let pattern_segments: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();
164
165        let fullpath = if let Some(path) = req.extensions().get::<OriginalUri>() {
166            path.0.path()
167        } else {
168            req.uri().path()
169        };
170
171        let path_segments: Vec<&str> = fullpath.split('/').filter(|s| !s.is_empty()).collect();
172
173        if let Some(segment_index) = pattern_segments.iter().position(|&s| s == "{room_id}") {
174            let mut new_segments = path_segments.clone();
175            if segment_index < new_segments.len() {
176                new_segments[segment_index] = data.room_id.as_ref().unwrap_or(&room_id);
177
178                // Rebuild the path with leading slash
179                let new_path = format!("/{}", new_segments.join("/"));
180
181                // Preserve query string if it exists
182                let new_uri = if let Some(query) = req.uri().query() {
183                    format!("{new_path}?{query}")
184                        .parse::<Uri>()
185                        .unwrap_or_default()
186                } else {
187                    new_path.parse::<Uri>().unwrap_or_default()
188                };
189
190                data.modified_path = Some(new_uri.to_string());
191            }
192        }
193    }
194
195    req.extensions_mut().insert(data);
196
197    Ok(next.run(req).await)
198}
199
200pub async fn validate_public_room(
201    Extension(data): Extension<Data>,
202    //Path(params): Path<Vec<(String, String)>>,
203    State(state): State<Arc<AppState>>,
204    req: Request<Body>,
205    next: Next,
206) -> Result<impl IntoResponse, AppserviceError> {
207    let room_id = data
208        .room_id
209        .as_ref()
210        .ok_or(AppserviceError::AppserviceError(
211            "No room ID found".to_string(),
212        ))?;
213
214    let parsed_room_id = RoomId::parse(room_id)
215        .map_err(|_| AppserviceError::AppserviceError("Invalid room ID".to_string()))?;
216
217    let is_joined = check_room_membership(&state, room_id, parsed_room_id).await?;
218
219    if !is_joined {
220        return Err(AppserviceError::AppserviceError(
221            "Not a public room".to_string(),
222        ));
223    }
224
225    Ok(next.run(req).await)
226}
227
228async fn check_room_membership(
229    state: &AppState,
230    room_id: &str,
231    parsed_room_id: OwnedRoomId,
232) -> Result<bool, AppserviceError> {
233    if !state.config.cache.joined_rooms.enabled {
234        return check_membership_direct(state, room_id, parsed_room_id).await;
235    }
236
237    let cache_key = ("appservice:joined", room_id).cache_key();
238
239    match state.cache.get_cached_data::<bool>(&cache_key).await {
240        Ok(Some(cached_result)) => {
241            tracing::info!("Using cached joined status for room: {}", room_id);
242            Ok(cached_result)
243        }
244        Ok(None) | Err(_) => {
245            let joined = check_membership_direct(state, room_id, parsed_room_id).await?;
246
247            if joined {
248                cache_membership_result(state, &cache_key, room_id).await;
249            }
250
251            Ok(joined)
252        }
253    }
254}
255
256async fn check_membership_direct(
257    state: &AppState,
258    room_id: &str,
259    parsed_room_id: OwnedRoomId,
260) -> Result<bool, AppserviceError> {
261    state
262        .appservice
263        .has_joined_room(parsed_room_id)
264        .await
265        .map_err(|e| {
266            tracing::error!("Failed to check joined status for room {}: {}", room_id, e);
267            AppserviceError::AppserviceError("Failed to check joined status".to_string())
268        })
269}
270
271async fn cache_membership_result(state: &AppState, cache_key: &str, room_id: &str) {
272    match state.cache.cache_data(cache_key, &true, 300).await {
273        Ok(_) => tracing::info!("Cached joined status for room: {}", room_id),
274        Err(_) => tracing::warn!("Failed to cache joined status for room: {}", room_id),
275    }
276}