1use std::sync::Arc;
2
3use axum::{
4 extract::Query,
5 http::{header, Request, StatusCode},
6 middleware::{self, Next},
7 response,
8 routing::post,
9 routing::{delete, get},
10 Extension, Json, Router,
11};
12
13use crate::service::{self, model::*};
14
/// Client credentials taken from the request's `Authorization` header by the
/// `auth` middleware, stored as `(client_id, client_secret)` in the request
/// extensions for handlers to extract with `Extension<Auth>`.
///
/// NOTE(review): holds the plaintext client secret — avoid deriving `Debug`
/// or logging this value.
#[derive(Clone)]
pub struct Auth(String, String);
17
18async fn auth<B>(mut req: Request<B>, next: Next<B>) -> Result<response::Response, StatusCode> {
19 let token = req
20 .headers()
21 .get(header::AUTHORIZATION)
22 .and_then(|header| header.to_str().ok())
23 .ok_or(StatusCode::UNAUTHORIZED)?;
24 let mut tokens = token.split(" ").collect::<Vec<_>>();
25
26 let token = if tokens.len() != 2 {
27 return Err(StatusCode::UNAUTHORIZED);
28 } else {
29 tokens.pop().unwrap()
30 };
31
32 let token = String::from_utf8(base64::decode(token).map_err(|e| StatusCode::UNAUTHORIZED)?)
33 .map_err(|e| StatusCode::UNAUTHORIZED)?;
34
35 let user_pass = token.split(":").collect::<Vec<_>>();
36 if user_pass.len() != 2 {
37 return Err(StatusCode::UNAUTHORIZED);
38 }
39
40 if let [client_id, client_secret] = user_pass[..] {
41 let app = req
42 .extensions()
43 .get::<Arc<service::App>>()
44 .expect("missing extension `service::App`");
45
46 let _ = app
47 .validate_app(client_id, client_secret)
48 .await
49 .map_err(|e| StatusCode::UNAUTHORIZED)?;
50
51 req.extensions_mut()
52 .insert(Auth(user_pass[0].to_string(), user_pass[1].to_string()));
53 return Ok(next.run(req).await);
54 }
55
56 Err(StatusCode::UNAUTHORIZED)
57}
58
59pub async fn register_token(
60 Extension(app): Extension<Arc<service::App>>,
61 Extension(Auth(client_id, _)): Extension<Auth>,
62 Json(params): Json<Vec<RegisterTokenParams>>,
63) -> Json<Response<RegisterTokenResp>> {
64 let mut success = 0;
65 let mut failure = 0;
66 let mut failure_tokens = Vec::new();
67 let mut errors = Vec::new();
68 for params in params {
69 match app
70 .register_token(
71 &client_id,
72 ¶ms.group,
73 ¶ms.ch_id,
74 ¶ms.token,
75 params._override,
76 )
77 .await
78 {
79 Ok(_) => {
80 success += 1;
81 }
82 Err(e) => {
83 failure += 1;
84 failure_tokens.push(params.token);
85 errors.push(e.to_string());
86 }
87 }
88 }
89
90 let resp = ResponseBuilder::default()
91 .data(Some(RegisterTokenResp {
92 success,
93 failure,
94 failure_tokens,
95 }))
96 .errors(Some(errors))
97 .build()
98 .unwrap();
99
100 Json(resp)
101}
102
103pub async fn revoke_token(
104 Extension(app): Extension<Arc<service::App>>,
105 Extension(Auth(client_id, _)): Extension<Auth>,
106 Json(params): Json<Vec<RevokeTokenParams>>,
107) -> Json<Response<RevokeTokenResp>> {
108 let mut success = 0;
109 let mut failure = 0;
110 let mut failure_tokens = Vec::new();
111
112 let mut errors = Vec::new();
113
114 for param in params.into_iter() {
115 match app
116 .revoke_token(&client_id, ¶m.group, ¶m.ch_id, ¶m.token)
117 .await
118 {
119 Ok(_) => {
120 success += 1;
121 }
122 Err(e) => {
123 failure += 1;
124 failure_tokens.push(param.token);
125 errors.push(e.to_string());
126 }
127 }
128 }
129
130 let resp = ResponseBuilder::default()
131 .data(Some(RevokeTokenResp {
132 success,
133 failure,
134 failure_tokens,
135 }))
136 .errors(Some(errors))
137 .build()
138 .unwrap();
139
140 Json(resp)
141}
142
143pub async fn push_transparent(
144 Extension(app): Extension<Arc<service::App>>,
145 Extension(Auth(client_id, _)): Extension<Auth>,
146 Json(params): Json<PushTransparentParams>,
147) -> Json<Response> {
148 let res = app
149 .push_message(&client_id, service::Message::Transparent(params))
150 .await;
151
152 let mut resp = ResponseBuilder::default();
153 match res {
154 Ok(_) => {}
155 Err(e) => {
156 resp.errors(Some(vec![e.to_string()]));
157 }
158 }
159
160 Json(resp.build().unwrap())
161}
162
163pub async fn push_notification(
164 Extension(app): Extension<Arc<service::App>>,
165 Extension(Auth(client_id, _)): Extension<Auth>,
166 Json(params): Json<PushNotificationParams>,
167) -> Json<Response<PushResp>> {
168 let resp = match app
169 .push_message(&client_id, service::Message::Notification(params))
170 .await
171 {
172 Ok(res) => Response {
173 data: Some(res),
174 code: Default::default(),
175 msg: Default::default(),
176 errors: Default::default(),
177 },
178 Err(e) => e.into(),
179 };
180
181 Json(resp)
182}
183
184pub async fn ping() -> Json<Response<String>> {
185 let resp = ResponseBuilder::<String>::default()
186 .msg("PONG".to_string())
187 .build()
188 .unwrap();
189 Json(resp)
190}
191
192pub async fn create_channel(
193 Extension(app): Extension<Arc<service::App>>,
194 Extension(Auth(client_id, _)): Extension<Auth>,
195 Json(params): Json<PublicChannel>,
196) -> Json<Response> {
197 let resp = match app.create_channel(&client_id, params).await {
198 Ok(ch_id) => ResponseBuilder::default()
199 .data(Some(ch_id))
200 .build()
201 .unwrap(),
202 Err(e) => e.into(),
203 };
204 Json(resp)
205}
206
207pub async fn delete_channel(
208 Extension(app): Extension<Arc<service::App>>,
209 Extension(Auth(client_id, _)): Extension<Auth>,
210 Query(params): Query<DeleteChannelParams>,
211) -> Json<Response> {
212 let resp = match app.delete_channel(&client_id, ¶ms.ch_id).await {
213 Ok(_) => ResponseBuilder::default().build().unwrap(),
214 Err(e) => e.into(),
215 };
216 Json(resp)
217}
218
219pub async fn fetch_channels(
220 Extension(app): Extension<Arc<service::App>>,
221 Extension(Auth(client_id, _)): Extension<Auth>,
222) -> Json<Response<Vec<Channel>>> {
223 let resp = match app.fetch_channels(&client_id).await {
224 Ok(chans) => ResponseBuilder::default()
225 .data(Some(chans))
226 .build()
227 .unwrap(),
228 Err(e) => e.into(),
229 };
230 Json(resp)
231}
232
233pub async fn delete_app(
234 Extension(app): Extension<Arc<service::App>>,
235 Extension(Auth(client_id, client_secret)): Extension<Auth>,
236) -> Json<Response> {
237 match app.delete_app(&client_id, &client_secret).await {
238 Ok(_) => Json(ResponseBuilder::default().build().unwrap()),
239 Err(e) => Json(e.into()),
240 }
241}
242
243pub async fn create_app(
244 Extension(app): Extension<Arc<service::App>>,
245 Json(params): Json<CreateAppParams>,
246) -> Json<Response<App>> {
247 match app.create_app(¶ms.name).await {
248 Ok(app) => Json(ResponseBuilder::default().data(Some(app)).build().unwrap()),
249 Err(e) => Json(e.into()),
250 }
251}
252
253pub async fn fetch_applications(
254 Extension(app): Extension<Arc<service::App>>,
255) -> Json<Response<Vec<App>>> {
256 match app.fetch_apps().await {
257 Ok(apps) => Json(ResponseBuilder::default().data(Some(apps)).build().unwrap()),
258 Err(e) => Json(e.into()),
259 }
260}
261
262pub async fn status(
263 Extension(app): Extension<Arc<service::App>>,
264 Extension(Auth(client_id, _)): Extension<Auth>,
265) -> Json<Response<Running>> {
266 let resp = ResponseBuilder::default()
267 .data(Some(Running {
268 ch_ids: app.running_ch_ids().await,
269 }))
270 .msg(client_id)
271 .build()
272 .unwrap();
273 Json(resp)
274}
275
276#[inline]
277pub fn api_router() -> axum::Router {
278 axum::Router::new()
279 .route("/register", post(register_token))
280 .route("/revoke", post(revoke_token))
281 .route("/transparent", post(push_transparent))
282 .route("/notification", post(push_notification))
283 .route("/ping", get(ping))
284 .route("/channels", get(fetch_channels))
285 .route("/channel", post(create_channel))
286 .route("/channel", delete(delete_channel))
287 .route("/applications", get(fetch_applications))
288 .route("/application", post(create_app))
289 .route("/application", delete(delete_app))
290 .route("/status", get(status))
291}
292
/// A pluggable customization applied to the server router by `start`.
///
/// `start` calls `apply` once per supplied option, in order, after the `/api`
/// routes have been nested and the `auth` and `Extension` layers added.
/// NOTE(review): axum's `layer` presumably only wraps routes that already exist,
/// so routes an option adds here would not be behind `auth` — confirm against
/// the axum version in use.
pub trait ServiceOption {
    /// Receives the router by mutable reference so it can be modified in place.
    fn apply(&self, r: &mut Router);
}
309
310pub async fn start(
311 mut app: service::App,
312 addr: impl Into<Option<&str>>,
313 options: Option<Vec<Box<dyn ServiceOption>>>,
314) -> anyhow::Result<()> {
315 app.init().await.expect("app init error");
316 let mut router = axum::Router::new()
317 .nest("/api", api_router())
318 .layer(middleware::from_fn(auth))
319 .layer(Extension(Arc::new(app)));
320
321 for option in options.unwrap_or_default().into_iter() {
322 let option: Box<dyn ServiceOption> = option.into();
323 option.apply(&mut router);
324 }
325
326 Ok(
327 axum::Server::bind(&addr.into().unwrap_or("0.0.0.0:8080").parse().unwrap())
328 .serve(router.into_make_service())
329 .await?,
330 )
331}
332
// Unit tests for this module (none yet).
#[cfg(test)]
mod test {}