use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
use crate::{
config::{HttpClient, HttpServer},
error::{WalleError, WalleResult},
util::{AuthReqHeaderExt, Echo, ProtocolItem, SelfId},
ActionHandler, EventHandler, OneBot,
};
use hyper::{
body::Buf,
client::HttpConnector,
header::{AUTHORIZATION, CONTENT_TYPE},
server::conn::Http,
service::service_fn,
Body, Client as HyperClient, Method, Request, Response,
};
use tokio::{net::TcpListener, sync::mpsc, task::JoinHandle};
use tracing::{info, warn};
use super::{AppOBC, BotMapExt, EchoMap};
impl<A, R> AppOBC<A, R>
where
A: ProtocolItem,
R: ProtocolItem,
{
pub(crate) async fn webhook<E, AH, EH>(
&self,
ob: &Arc<OneBot<AH, EH, 12>>,
config: Vec<HttpServer>,
tasks: &mut Vec<JoinHandle<()>>,
) -> WalleResult<()>
where
E: ProtocolItem + SelfId + Clone,
AH: ActionHandler<E, A, R, 12> + Send + Sync + 'static,
EH: EventHandler<E, A, R, 12> + Send + Sync + 'static,
{
for webhook in config {
let bot_map = self.bots.clone();
let echo_map = self.echos.clone();
let access_token = webhook.access_token.clone();
let mut signal_rx = ob.get_signal_rx()?;
let ob = ob.clone();
let addr = std::net::SocketAddr::new(webhook.host, webhook.port);
info!(
target: crate::WALLE_CORE,
"Starting HTTP Webhook server on http://{}", addr
);
let listener = TcpListener::bind(&addr).await.map_err(WalleError::from)?;
let serv = service_fn(move |req: Request<Body>| {
let access_token = access_token.clone();
let ob = ob.clone();
let bot_map = bot_map.clone();
let echo_map = echo_map.clone();
async move {
if let Some(token) = access_token.as_ref() {
if let Some(header_token) = req
.headers()
.get(AUTHORIZATION)
.and_then(|v| v.to_str().ok())
{
if header_token != format!("Bearer {}", token) {
return Ok(Response::builder()
.status(403)
.body("Authorization Header is invalid".into())
.unwrap());
}
} else {
return Ok(Response::builder()
.status(403)
.body("Missing Authorization Header".into())
.unwrap());
}
}
let body = String::from_utf8(
hyper::body::to_bytes(req.into_body())
.await
.unwrap()
.to_vec(),
)
.unwrap();
match E::json_decode(&body) {
Ok(event) => {
let (action_tx, mut action_rx) = mpsc::unbounded_channel();
let self_id = event.self_id();
bot_map.ensure_bot(&self_id, &action_tx);
if let Err(e) = ob.event_handler.call(event).await {
warn!(target: super::OBC, "{}", e);
}
if let Ok(Some(a)) = tokio::time::timeout(
std::time::Duration::from_secs(8),
action_rx.recv(),
)
.await
{
let echo_s = a.get_echo();
echo_map.remove(&echo_s);
bot_map.remove_bot(&self_id, &action_tx);
return Ok(Response::new(a.json_encode().into()));
}
}
Err(s) => warn!(target: crate::WALLE_CORE, "Webhook json error: {}", s),
}
Ok::<Response<Body>, Infallible>(Response::new("".into()))
}
});
tasks.push(tokio::spawn(async move {
loop {
let service = serv.clone();
tokio::select! {
_ = signal_rx.recv() => break,
Ok((tcp_stream, _)) = listener.accept() => {
tokio::spawn(async move {
Http::new()
.serve_connection(tcp_stream, service)
.await
.unwrap();
});
}
}
}
}));
}
Ok(())
}
pub(crate) async fn http<E, AH, EH>(
&self,
ob: &Arc<OneBot<AH, EH, 12>>,
config: HashMap<String, HttpClient>,
tasks: &mut Vec<JoinHandle<()>>,
) -> WalleResult<()>
where
E: ProtocolItem + SelfId + Clone,
AH: ActionHandler<E, A, R, 12> + Send + Sync + 'static,
EH: EventHandler<E, A, R, 12> + Send + Sync + 'static,
{
let client = Arc::new(HyperClient::new());
for (bot_id, http) in config {
let (tx, mut rx) = mpsc::unbounded_channel();
self.bots.ensure_bot(&bot_id, &tx);
let ob = ob.clone();
let cli = client.clone();
let echo_map = self.echos.clone();
let mut signal_rx = ob.get_signal_rx()?;
tasks.push(tokio::spawn(async move {
loop {
tokio::select! {
_ = signal_rx.recv() => break,
Some(action) = rx.recv() => {
tokio::spawn(http_push(
action,
cli.clone(),
http.clone(),
echo_map.clone(),
));
}
}
}
}));
}
Ok(())
}
}
async fn http_push<A, R>(
action: Echo<A>,
client: Arc<HyperClient<HttpConnector, Body>>,
http: HttpClient,
echo_map: EchoMap<R>,
) where
A: ProtocolItem,
R: ProtocolItem,
{
let (action, echo_s) = action.unpack();
let req = Request::builder()
.method(Method::POST)
.uri(&http.url)
.header_auth_token(&http.access_token)
.header(CONTENT_TYPE, crate::util::ContentType::Json.to_string())
.body(action.to_body(&crate::util::ContentType::Json)) .unwrap();
match tokio::time::timeout(Duration::from_secs(http.timeout), client.request(req)).await {
Ok(Ok(resp)) => {
let body = hyper::body::aggregate(resp).await.unwrap(); let r: R = serde_json::from_reader(body.reader()).unwrap();
if let Some((_, r_tx)) = echo_map.remove(&echo_s) {
r_tx.send(r).ok();
}
}
Ok(Err(e)) => {
warn!(target: crate::WALLE_CORE, "HTTP push error: {}", e);
}
Err(e) => {
warn!(target: crate::WALLE_CORE, "HTTP push timeout: {}", e);
}
}
}