use axum::body::{Body, to_bytes};
use axum::extract::Request;
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use base64::{Engine, engine::general_purpose::STANDARD};
use md5::{Digest, Md5};
use rsa::pkcs8::DecodePublicKey;
use rsa::signature::hazmat::PrehashVerifier;
use rsa::{RsaPublicKey, pkcs1v15};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use tower::{Layer, Service};
#[derive(Debug, thiserror::Error)]
enum OssVerifyError<'a> {
#[error("missing required header `{0}`")]
MissingHeader(&'a str),
#[error("invalid header `{0}`")]
InvalidHeader(&'a str),
#[error("invalid oss callback signature")]
InvalidSignature,
#[error("failed to read request body: {0}")]
BodyRead(#[from] axum::Error),
#[error("http error when verifying oss public key: {0}")]
Http(#[from] reqwest::Error),
#[error("base64 decode error: {0}")]
Base64(#[from] base64::DecodeError),
#[error("utf-8 error: {0}")]
Utf8(#[from] std::string::FromUtf8Error),
#[error("error: {0}")]
Common(&'a str),
}
impl IntoResponse for OssVerifyError<'_> {
    /// Renders the error as a plain-text response whose body is the error's `Display` text.
    ///
    /// A failure to download the OSS public key is reported as `502 Bad Gateway`; every other
    /// variant is a `400 Bad Request`. The match is exhaustive on purpose so that adding a
    /// variant forces a decision about its status code.
    fn into_response(self) -> Response {
        let status = match &self {
            OssVerifyError::Http(_) => StatusCode::BAD_GATEWAY,
            OssVerifyError::MissingHeader(_)
            | OssVerifyError::InvalidHeader(_)
            | OssVerifyError::InvalidSignature
            | OssVerifyError::BodyRead(_)
            | OssVerifyError::Base64(_)
            | OssVerifyError::Utf8(_)
            | OssVerifyError::Common(_) => StatusCode::BAD_REQUEST,
        };
        (status, self.to_string()).into_response()
    }
}
/// Callback body text whose OSS signature has been verified.
///
/// The verification middleware inserts this into the request extensions only after the
/// signature check succeeds, so handlers can read the body as a `String` and rely on its
/// presence as proof of verification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifiedOssCallbackBody(pub String);
/// Tower layer that verifies Aliyun OSS callback signatures before the inner service runs.
///
/// Every service produced by this layer shares the same public-key cache (it is behind an
/// `Arc`), so a key downloaded for one request is reused by later ones.
#[derive(Clone, Debug)]
pub struct OssCallbackVerifyLayer {
    /// Client used to download OSS public keys.
    client: reqwest::Client,
    /// Public-key URL -> PEM bytes, shared with every service this layer creates.
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    /// Configured callback path; used as the path part of the signed string.
    callback_path: String,
}
impl OssCallbackVerifyLayer {
    /// Creates a layer for callbacks delivered to `callback_url_path`.
    ///
    /// `callback_url_path` should be the path OSS calls back (presumably URL-decoded, e.g.
    /// `/oss/callback`) without a query string: the verifier appends `?query` from the request
    /// itself when rebuilding the signed string.
    ///
    /// Public keys are downloaded with a client that has a 5-second request timeout, so a
    /// stalled key server cannot hold callback requests open indefinitely. Use
    /// [`Self::with_client`] to supply a differently configured client.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be built (for example, if the TLS backend fails to
    /// initialise). This is the same condition under which `reqwest::Client::new` panics.
    pub fn new(callback_url_path: &str) -> Self {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build()
            .expect("failed to build HTTP client for OSS public-key downloads");
        Self::with_client(callback_url_path, client)
    }

    /// Creates a layer that downloads OSS public keys with the provided `client`.
    ///
    /// The caller is responsible for configuring timeouts and other client options.
    pub fn with_client(callback_url_path: &str, client: reqwest::Client) -> Self {
        Self {
            client,
            cache: Arc::new(RwLock::new(HashMap::new())),
            callback_path: callback_url_path.to_owned(),
        }
    }
}
impl<S> Layer<S> for OssCallbackVerifyLayer {
    type Service = OssCallbackVerifyService<S>;

    /// Wraps `inner` in a verifying service that shares this layer's client and key cache.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            client,
            cache,
            callback_path,
        } = self;
        OssCallbackVerifyService {
            inner,
            // Sharing the `Arc` (not deep-copying the map) keeps one cache for all services.
            cache: Arc::clone(cache),
            client: client.clone(),
            callback_path: callback_path.clone(),
        }
    }
}
/// Service produced by [`OssCallbackVerifyLayer`]; verifies each request before forwarding it.
///
/// Requests that fail verification are answered directly with the error response and never
/// reach `inner`.
#[derive(Clone, Debug)]
pub struct OssCallbackVerifyService<S> {
    /// The wrapped service that handles verified requests.
    inner: S,
    /// Client used to download OSS public keys.
    client: reqwest::Client,
    /// Public-key URL -> PEM bytes, shared with the layer and its other services.
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    /// Configured callback path; used as the path part of the signed string.
    callback_path: String,
}
impl<S> Service<Request<Body>> for OssCallbackVerifyService<S>
where
    S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Error: Into<axum::BoxError>,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is entirely that of the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Verifies the request and forwards it to the inner service.
    ///
    /// A verification failure is rendered as a response and returned as `Ok`, so the inner
    /// service is not called.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        // `poll_ready` was driven on `self.inner`, so that instance is the one moved into the
        // future; a fresh clone is left behind for the next `poll_ready`/`call` cycle.
        let fresh = self.inner.clone();
        let mut ready = std::mem::replace(&mut self.inner, fresh);
        let (client, cache, callback_path) = (
            self.client.clone(),
            Arc::clone(&self.cache),
            self.callback_path.clone(),
        );

        Box::pin(async move {
            let verified = match verify_oss_request(req, &client, &cache, callback_path).await {
                Ok(checked) => checked,
                Err(rejection) => return Ok(rejection.into_response()),
            };
            ready.call(verified).await
        })
    }
}
/// Verifies an Aliyun OSS callback request and returns it with the body re-attached.
///
/// Steps, in order:
/// 1. `x-oss-pub-key-url` is base64-decoded and must point at `gosspublic.alicdn.com/`
///    (given as `http` or `https`). The key itself is always downloaded over HTTPS.
/// 2. `authorization` is base64-decoded into the RSA signature.
/// 3. The body (at most 5 MiB) is buffered and must be valid UTF-8.
/// 4. The PEM public key is loaded (cached per URL). The PKCS#1 v1.5 signature is then checked
///    against the MD5 digest of `"{callback_path}[?{query}]\n{body}"`.
///
/// The signed path is taken from `callback_path` (configuration), not from the request URI.
/// Only the query string comes from the request.
///
/// On success the returned request carries the original body plus a
/// [`VerifiedOssCallbackBody`] extension holding the same text.
///
/// # Errors
/// - `MissingHeader` / `InvalidHeader` if either header is absent or not a readable string.
/// - `Base64` if a header is not valid base64; `Utf8` if the key URL, the body, or the
///   downloaded key is not valid UTF-8.
/// - `Common` if the key URL is not on the OSS key host, or the key or signature cannot be
///   parsed.
/// - `BodyRead` if the body cannot be read or exceeds the size limit.
/// - `Http` if downloading the public key fails.
/// - `InvalidSignature` if the signature does not match.
async fn verify_oss_request<'a>(
    req: Request<Body>,
    client: &reqwest::Client,
    cache: &Arc<RwLock<HashMap<String, Vec<u8>>>>,
    callback_path: String,
) -> Result<Request<Body>, OssVerifyError<'a>> {
    /// Largest callback body that will be buffered for verification.
    const MAX_CALLBACK_BODY_BYTES: usize = 5 * 1024 * 1024;
    /// Host (plus the slash that terminates it) serving OSS callback public keys.
    const PUB_KEY_HOST: &str = "gosspublic.alicdn.com/";

    let (parts, body) = req.into_parts();

    // Validate both headers before buffering the body or making any outbound request, so
    // malformed requests are rejected cheaply.
    let pub_key_url_b64 = header_required(&parts.headers, "x-oss-pub-key-url")?;
    let pub_key_url = String::from_utf8(STANDARD.decode(pub_key_url_b64.as_bytes())?)?;
    // The trailing `/` in PUB_KEY_HOST pins the host, so look-alikes such as
    // `gosspublic.alicdn.com.evil.example` are rejected.
    let Some(key_path) = pub_key_url
        .strip_prefix("https://")
        .or_else(|| pub_key_url.strip_prefix("http://"))
        .and_then(|rest| rest.strip_prefix(PUB_KEY_HOST))
    else {
        return Err(OssVerifyError::Common("invalid oss public key url"));
    };
    // Always fetch over HTTPS: downloading the verification key over plain HTTP would let an
    // on-path attacker substitute their own key and forge callbacks.
    let pub_key_url = format!("https://{PUB_KEY_HOST}{key_path}");

    let auth_b64 = header_required(&parts.headers, "authorization")?;
    let auth_bytes = STANDARD.decode(auth_b64.as_bytes())?;

    let body_bytes = to_bytes(body, MAX_CALLBACK_BODY_BYTES).await?;
    let body_str = String::from_utf8(body_bytes.to_vec())?;

    let pub_key_pem = String::from_utf8(get_or_fetch_pub_key(&pub_key_url, client, cache).await?)?;
    let rsa_pub_key = RsaPublicKey::from_public_key_pem(&pub_key_pem)
        .map_err(|_| OssVerifyError::Common("failed to parse oss public key pem"))?;
    let verifying_key = pkcs1v15::VerifyingKey::<Md5>::new(rsa_pub_key);
    let signature = pkcs1v15::Signature::try_from(auth_bytes.as_slice())
        .map_err(|_| OssVerifyError::Common("failed to parse oss signature"))?;

    // Signed message: `path[?query]` + "\n" + raw body bytes.
    let auth_path = match parts.uri.query() {
        Some(q) => format!("{callback_path}?{q}"),
        None => callback_path,
    };
    let mut hasher = Md5::new();
    hasher.update(auth_path.as_bytes());
    hasher.update(b"\n");
    hasher.update(&body_bytes);
    let digest = hasher.finalize();

    verifying_key
        .verify_prehash(&digest, &signature)
        .map_err(|_| OssVerifyError::InvalidSignature)?;

    let mut verified_req = Request::from_parts(parts, Body::from(body_bytes));
    verified_req
        .extensions_mut()
        .insert(VerifiedOssCallbackBody(body_str));
    Ok(verified_req)
}
/// Returns the value of header `name` as an owned string.
///
/// # Errors
/// `MissingHeader(name)` if the header is absent; `InvalidHeader(name)` if its value cannot be
/// converted to `&str` by `HeaderValue::to_str`.
fn header_required<'a>(headers: &HeaderMap, name: &'a str) -> Result<String, OssVerifyError<'a>> {
    let Some(raw) = headers.get(name) else {
        return Err(OssVerifyError::MissingHeader(name));
    };
    match raw.to_str() {
        Ok(text) => Ok(text.to_owned()),
        Err(_) => Err(OssVerifyError::InvalidHeader(name)),
    }
}
/// Returns the PEM bytes of the public key at `url`, downloading and caching them on first use.
///
/// The cache is checked under a read lock first. On a miss the key is downloaded with no lock
/// held, because a `std::sync` guard must not live across `.await`. The result is then inserted
/// under a write lock. If another request filled the entry meanwhile, the existing value wins,
/// so all callers see the same bytes. Nothing in this function evicts entries.
///
/// # Errors
/// Returns `OssVerifyError::Http` if the request fails or the server answers with a 4xx/5xx
/// status. Error responses are never cached, so a transient failure or a non-existent URL does
/// not permanently poison the cache for that URL.
async fn get_or_fetch_pub_key<'a>(
    url: &str,
    client: &reqwest::Client,
    cache: &Arc<RwLock<HashMap<String, Vec<u8>>>>,
) -> Result<Vec<u8>, OssVerifyError<'a>> {
    {
        // A poisoned lock only means another thread panicked while holding it; the map holds
        // plain owned data with no cross-entry invariants, so recover it instead of panicking.
        let cache_read = cache
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(pem) = cache_read.get(url) {
            return Ok(pem.clone());
        }
    }

    let bytes = client
        .get(url)
        .send()
        .await?
        .error_for_status()?
        .bytes()
        .await?;

    let mut cache_write = cache
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let pem = cache_write
        .entry(url.to_owned())
        .or_insert_with(|| bytes.to_vec());
    Ok(pem.clone())
}