1use axum::{
2 Extension, Json,
3 body::Body,
4 extract::{MatchedPath, OriginalUri, Path, State},
5 http::{Request, StatusCode, Uri, header::AUTHORIZATION},
6 middleware::Next,
7 response::IntoResponse,
8};
9
10use ruma::{OwnedRoomId, RoomAliasId, RoomId};
11
12use serde_json::{Value, json};
13
14use std::sync::Arc;
15
16use crate::AppState;
17use crate::utils::{room_alias_like, room_id_valid};
18
19use crate::error::AppserviceError;
20
21use crate::cache::CacheKey;
22
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The scheme name is matched case-insensitively (HTTP authentication schemes
/// are case-insensitive), and surrounding whitespace is trimmed from the token.
///
/// Returns `None` when the value does not start with the bearer scheme or when
/// the token part is empty, so callers never see an empty credential.
pub fn extract_token(header: &str) -> Option<&str> {
    const SCHEME: &str = "Bearer ";

    // `get` (not indexing) so a multi-byte char straddling the boundary
    // yields `None` instead of panicking.
    let prefix = header.get(..SCHEME.len())?;
    if !prefix.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    let token = header.get(SCHEME.len()..)?.trim();
    (!token.is_empty()).then_some(token)
}
26
27fn unauthorized_error() -> (StatusCode, Json<Value>) {
28 (
29 StatusCode::UNAUTHORIZED,
30 Json(json!({
31 "errcode": "M_FORBIDDEN",
32 })),
33 )
34}
35
36pub async fn authenticate_homeserver(
37 State(state): State<Arc<AppState>>,
38 req: Request<Body>,
39 next: Next,
40) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
41 let token = req
42 .headers()
43 .get(AUTHORIZATION)
44 .ok_or(unauthorized_error())?
45 .to_str()
46 .map_err(|_| unauthorized_error())?;
47
48 let token = extract_token(token).ok_or(unauthorized_error())?;
49
50 if token != state.config.appservice.hs_access_token {
51 return Err(unauthorized_error());
52 }
53
54 Ok(next.run(req).await)
55}
56
57pub async fn is_admin(
58 req: Request<Body>,
60 next: Next,
61) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
62 let token = req
63 .headers()
64 .get(AUTHORIZATION)
65 .ok_or(unauthorized_error())?
66 .to_str()
67 .map_err(|_| unauthorized_error())?;
68
69 let token = extract_token(token).ok_or(unauthorized_error())?;
70
71 if token != "test" {
72 return Err(unauthorized_error());
73 }
74
75 Ok(next.run(req).await)
76}
77
/// Coarse classification of a proxied client request, derived from the request
/// path by `parse_request_type`.
///
/// The type is a plain fieldless enum, so it is `Copy` and comparable; callers
/// can test it with `==` or `matches!` instead of cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProxyRequestType {
    /// The path ends with `/state`.
    RoomState,
    /// The path ends with `/messages`.
    Messages,
    /// The path is under `/_matrix/client/v1/media/`.
    Media,
    /// Any other path.
    Other,
}
85
86#[derive(Clone, Debug)]
87pub struct Data {
88 pub modified_path: Option<String>,
89 pub room_id: Option<String>,
90 pub is_media_request: bool,
91 pub proxy_request_type: ProxyRequestType,
92}
93
94pub fn parse_request_type(req: &Request<Body>) -> ProxyRequestType {
95 match req.uri().path() {
96 path if path.ends_with("/state") => ProxyRequestType::RoomState,
97 path if path.ends_with("/messages") => ProxyRequestType::Messages,
98 path if path.starts_with("/_matrix/client/v1/media/") => ProxyRequestType::Media,
99 _ => ProxyRequestType::Other,
100 }
101}
102
103pub async fn add_data(
104 mut req: Request<Body>,
105 next: Next,
106) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
107 let data = Data {
108 modified_path: None,
109 room_id: None,
110 is_media_request: req.uri().path().starts_with("/_matrix/client/v1/media/"),
111 proxy_request_type: parse_request_type(&req),
112 };
113
114 req.extensions_mut().insert(data);
115
116 Ok(next.run(req).await)
117}
118
119pub async fn validate_room_id(
120 Path(params): Path<Vec<(String, String)>>,
121 State(state): State<Arc<AppState>>,
122 mut req: Request<Body>,
123 next: Next,
124) -> Result<impl IntoResponse, (StatusCode, Json<Value>)> {
125 let room_id = params[0].1.clone();
126
127 let server_name = state.config.matrix.server_name.clone();
128
129 let mut data = Data {
130 modified_path: None,
131 room_id: Some(room_id.clone()),
132 is_media_request: req.uri().path().starts_with("/_matrix/media/v1/download/"),
133 proxy_request_type: parse_request_type(&req),
134 };
135
136 if room_id_valid(&room_id, &server_name).is_ok() {
138 req.extensions_mut().insert(data);
139 return Ok(next.run(req).await);
140 }
141
142 let raw_alias = room_alias_like(&room_id)
144 .then_some(format!("#{room_id}"))
145 .unwrap_or_else(|| format!("#{room_id}:{server_name}"));
146
147 if let Ok(alias) = RoomAliasId::parse(&raw_alias) {
148 let id = state.appservice.room_id_from_alias(alias).await;
149 match id {
150 Ok(id) => {
151 data.room_id = Some(id.to_string());
152 }
153 Err(_) => {
154 tracing::info!("Failed to get room ID from alias: {}", raw_alias);
155 }
156 }
157 }
158
159 if let Some(path) = req.extensions().get::<MatchedPath>() {
160 let pattern = path.as_str();
161
162 let pattern_segments: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();
164
165 let fullpath = if let Some(path) = req.extensions().get::<OriginalUri>() {
166 path.0.path()
167 } else {
168 req.uri().path()
169 };
170
171 let path_segments: Vec<&str> = fullpath.split('/').filter(|s| !s.is_empty()).collect();
172
173 if let Some(segment_index) = pattern_segments.iter().position(|&s| s == "{room_id}") {
174 let mut new_segments = path_segments.clone();
175 if segment_index < new_segments.len() {
176 new_segments[segment_index] = data.room_id.as_ref().unwrap_or(&room_id);
177
178 let new_path = format!("/{}", new_segments.join("/"));
180
181 let new_uri = if let Some(query) = req.uri().query() {
183 format!("{new_path}?{query}")
184 .parse::<Uri>()
185 .unwrap_or_default()
186 } else {
187 new_path.parse::<Uri>().unwrap_or_default()
188 };
189
190 data.modified_path = Some(new_uri.to_string());
191 }
192 }
193 }
194
195 req.extensions_mut().insert(data);
196
197 Ok(next.run(req).await)
198}
199
200pub async fn validate_public_room(
201 Extension(data): Extension<Data>,
202 State(state): State<Arc<AppState>>,
204 req: Request<Body>,
205 next: Next,
206) -> Result<impl IntoResponse, AppserviceError> {
207 let room_id = data
208 .room_id
209 .as_ref()
210 .ok_or(AppserviceError::AppserviceError(
211 "No room ID found".to_string(),
212 ))?;
213
214 let parsed_room_id = RoomId::parse(room_id)
215 .map_err(|_| AppserviceError::AppserviceError("Invalid room ID".to_string()))?;
216
217 let is_joined = check_room_membership(&state, room_id, parsed_room_id).await?;
218
219 if !is_joined {
220 return Err(AppserviceError::AppserviceError(
221 "Not a public room".to_string(),
222 ));
223 }
224
225 Ok(next.run(req).await)
226}
227
228async fn check_room_membership(
229 state: &AppState,
230 room_id: &str,
231 parsed_room_id: OwnedRoomId,
232) -> Result<bool, AppserviceError> {
233 if !state.config.cache.joined_rooms.enabled {
234 return check_membership_direct(state, room_id, parsed_room_id).await;
235 }
236
237 let cache_key = ("appservice:joined", room_id).cache_key();
238
239 match state.cache.get_cached_data::<bool>(&cache_key).await {
240 Ok(Some(cached_result)) => {
241 tracing::info!("Using cached joined status for room: {}", room_id);
242 Ok(cached_result)
243 }
244 Ok(None) | Err(_) => {
245 let joined = check_membership_direct(state, room_id, parsed_room_id).await?;
246
247 if joined {
248 cache_membership_result(state, &cache_key, room_id).await;
249 }
250
251 Ok(joined)
252 }
253 }
254}
255
256async fn check_membership_direct(
257 state: &AppState,
258 room_id: &str,
259 parsed_room_id: OwnedRoomId,
260) -> Result<bool, AppserviceError> {
261 state
262 .appservice
263 .has_joined_room(parsed_room_id)
264 .await
265 .map_err(|e| {
266 tracing::error!("Failed to check joined status for room {}: {}", room_id, e);
267 AppserviceError::AppserviceError("Failed to check joined status".to_string())
268 })
269}
270
271async fn cache_membership_result(state: &AppState, cache_key: &str, room_id: &str) {
272 match state.cache.cache_data(cache_key, &true, 300).await {
273 Ok(_) => tracing::info!("Cached joined status for room: {}", room_id),
274 Err(_) => tracing::warn!("Failed to cache joined status for room: {}", room_id),
275 }
276}