Skip to main content

onebot_api/communication/
http_post.rs

1use super::utils::*;
2use crate::error::{ServiceStartError, ServiceStartResult};
3use async_trait::async_trait;
4use axum::Router;
5use axum::extract::State;
6use axum::response::IntoResponse;
7use axum::routing::any;
8use hmac::{Hmac, Mac};
9use http::{HeaderMap, StatusCode};
10use sha1::Sha1;
11use std::sync::Arc;
12use tokio::net::{TcpListener, ToSocketAddrs};
13use tokio::sync::broadcast;
14
15type HmacSha1 = Hmac<Sha1>;
16
17pub struct HttpPostService<T: ToSocketAddrs + Clone + Send + Sync> {
18	addr: T,
19	hmac: Option<HmacSha1>,
20	event_sender: Option<EventSender>,
21	close_signal_sender: broadcast::Sender<()>,
22	prefix: String,
23}
24
25impl<T: ToSocketAddrs + Clone + Send + Sync> Drop for HttpPostService<T> {
26	fn drop(&mut self) {
27		let _ = self.close_signal_sender.send(());
28	}
29}
30
31impl<T: ToSocketAddrs + Clone + Send + Sync> HttpPostService<T> {
32	pub fn new(addr: T, prefix: Option<String>, secret: Option<String>) -> anyhow::Result<Self> {
33		let (close_signal_sender, _) = broadcast::channel(1);
34		let hmac = if let Some(secret) = secret {
35			Some(HmacSha1::new_from_slice(secret.as_ref())?)
36		} else {
37			None
38		};
39		let mut prefix = prefix.unwrap_or("/".to_string());
40		if !prefix.starts_with("/") {
41			prefix = "/".to_string() + &prefix;
42		}
43		Ok(Self {
44			addr,
45			hmac,
46			event_sender: None,
47			close_signal_sender,
48			prefix,
49		})
50	}
51}
52
53struct AppState {
54	hmac: Option<HmacSha1>,
55	event_sender: EventSender,
56}
57
58pub fn get_sig(mut hmac: HmacSha1, content: &[u8]) -> String {
59	hmac.update(content);
60	let result = hmac.finalize().into_bytes();
61	hex::encode(result)
62}
63
64async fn processor(
65	headers: HeaderMap,
66	State(state): State<Arc<AppState>>,
67	body: String,
68) -> impl IntoResponse {
69	if state.hmac.is_some() {
70		let received_sig = headers.get("X-Signature").map(|v| v.to_str().unwrap());
71		if received_sig.is_none() {
72			return StatusCode::UNAUTHORIZED;
73		}
74		let received_sig = received_sig.unwrap();
75		let hmac = state.hmac.clone().unwrap();
76		let sig = get_sig(hmac, body.as_ref());
77		if received_sig != "sha1=".to_string() + sig.as_str() {
78			return StatusCode::FORBIDDEN;
79		}
80	}
81	let event = serde_json::from_str(&body).unwrap();
82	let _ = state.event_sender.send(Arc::new(event));
83	StatusCode::NO_CONTENT
84}
85
86#[async_trait]
87impl<T: ToSocketAddrs + Clone + Send + Sync> CommunicationService for HttpPostService<T> {
88	fn inject(&mut self, _api_receiver: APIReceiver, event_sender: EventSender) {
89		self.event_sender = Some(event_sender);
90	}
91
92	async fn start_service(&self) -> ServiceStartResult<()> {
93		if self.event_sender.is_none() {
94			return Err(ServiceStartError::NotInjectedEventSender);
95		}
96
97		let event_sender = self.event_sender.clone().unwrap();
98
99		let state = Arc::new(AppState {
100			event_sender,
101			hmac: self.hmac.clone(),
102		});
103
104		let listener = TcpListener::bind(self.addr.clone()).await?;
105		let router = Router::new()
106			.route(&self.prefix, any(processor))
107			.with_state(state);
108		let mut close_signal = self.close_signal_sender.subscribe();
109
110		tokio::spawn(
111			axum::serve(listener, router)
112				.with_graceful_shutdown(async move {
113					let _ = close_signal.recv().await;
114				})
115				.into_future(),
116		);
117
118		Ok(())
119	}
120}