Skip to main content

onebot_api/communication/
http_post.rs

1use super::utils::*;
2use async_trait::async_trait;
3use axum::Router;
4use axum::extract::State;
5use axum::response::IntoResponse;
6use axum::routing::any;
7use hmac::{Hmac, Mac};
8use http::{HeaderMap, StatusCode};
9use sha1::Sha1;
10use std::sync::Arc;
11use tokio::net::{TcpListener, ToSocketAddrs};
12use tokio::sync::broadcast;
13
14type HmacSha1 = Hmac<Sha1>;
15
16pub struct HttpPostService<T: ToSocketAddrs + Clone + Send + Sync> {
17	addr: T,
18	hmac: Option<HmacSha1>,
19	event_sender: Option<EventSender>,
20	close_signal_sender: broadcast::Sender<()>,
21	prefix: String,
22}
23
24impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for HttpPostService<T> {
25	fn drop(&mut self) {
26		let _ = self.close_signal_sender.send(());
27	}
28}
29
30impl<T: ToSocketAddrs + Clone + Send + Sync> HttpPostService<T> {
31	pub fn new(addr: T, prefix: Option<String>, secret: Option<String>) -> anyhow::Result<Self> {
32		let (close_signal_sender, _) = broadcast::channel(1);
33		let hmac = if let Some(secret) = secret {
34			Some(HmacSha1::new_from_slice(secret.as_ref())?)
35		} else {
36			None
37		};
38		let mut prefix = prefix.unwrap_or("/".to_string());
39		if !prefix.starts_with("/") {
40			prefix = "/".to_string() + &prefix;
41		}
42		Ok(Self {
43			addr,
44			hmac,
45			event_sender: None,
46			close_signal_sender,
47			prefix,
48		})
49	}
50}
51
52struct AppState {
53	hmac: Option<HmacSha1>,
54	event_sender: EventSender,
55}
56
57pub fn get_sig(mut hmac: HmacSha1, content: &[u8]) -> String {
58	hmac.update(content);
59	let result = hmac.finalize().into_bytes();
60	hex::encode(result)
61}
62
63async fn processor(
64	headers: HeaderMap,
65	State(state): State<Arc<AppState>>,
66	body: String,
67) -> impl IntoResponse {
68	if state.hmac.is_some() {
69		let received_sig = headers.get("X-Signature").map(|v| v.to_str().unwrap());
70		if received_sig.is_none() {
71			return StatusCode::UNAUTHORIZED;
72		}
73		let received_sig = received_sig.unwrap();
74		let hmac = state.hmac.clone().unwrap();
75		let sig = get_sig(hmac, body.as_ref());
76		if received_sig != "sha1=".to_string() + sig.as_str() {
77			return StatusCode::FORBIDDEN;
78		}
79	}
80	let event = serde_json::from_str(&body).unwrap();
81	let _ = state.event_sender.send(Arc::new(event));
82	StatusCode::NO_CONTENT
83}
84
85#[async_trait]
86impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for HttpPostService<T> {
87	fn inject(&mut self, _api_receiver: APIReceiver, event_sender: EventSender) {
88		self.event_sender = Some(event_sender);
89	}
90
91	async fn start_service(&self) -> anyhow::Result<()> {
92		if self.event_sender.is_none() {
93			return Err(anyhow::anyhow!("event sender is none"));
94		}
95
96		let event_sender = self.event_sender.clone().unwrap();
97
98		let state = Arc::new(AppState {
99			event_sender,
100			hmac: self.hmac.clone(),
101		});
102
103		let listener = TcpListener::bind(self.addr.clone()).await?;
104		let router = Router::new()
105			.route(&self.prefix, any(processor))
106			.with_state(state);
107		let mut close_signal = self.close_signal_sender.subscribe();
108
109		tokio::spawn(
110			axum::serve(listener, router)
111				.with_graceful_shutdown(async move {
112					let _ = close_signal.recv().await;
113				})
114				.into_future(),
115		);
116
117		Ok(())
118	}
119}