use std::{fmt, io::Read};
use async_recursion::async_recursion;
use futures::{stream::FuturesUnordered, StreamExt};
use http::ureq::{
http::{Response, Uri},
Body,
};
use sha1::{Digest, Sha1};
use tracing::debug;
use crate::{
native::{Deserializable, SignedPublicKey},
utils::spawn,
Error, Result,
};
struct EmailAddress {
pub local_part: String,
pub domain: String,
}
impl EmailAddress {
pub fn from(email_address: impl AsRef<str>) -> Result<Self> {
let email_address = email_address.as_ref();
let v: Vec<&str> = email_address.split('@').collect();
if v.len() != 2 {
return Err(Error::ParseEmailAddressError(email_address.into()));
};
let email = EmailAddress {
local_part: v[0].to_string(),
domain: v[1].to_lowercase(),
};
Ok(email)
}
}
/// Which WKD lookup URL layout `Url::build` produces.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum Variant {
    /// `https://openpgpkey.<domain>/.well-known/openpgpkey/<domain>/hu/...`
    /// (the default).
    #[default]
    Advanced,
    /// `https://<domain>/.well-known/openpgpkey/hu/...`
    Direct,
}
/// Precomputed pieces of a WKD lookup URL for one email address.
#[derive(Clone, Debug)]
struct Url {
    /// Domain part of the address, used as host (and path segment for the
    /// advanced variant).
    domain: String,
    /// z-base-32 encoded SHA-1 hash of the lowercased local part.
    local_encoded: String,
    /// Original local part, sent as the `l` query parameter.
    local_part: String,
}
impl fmt::Display for Url {
    /// Writes the lookup URL for the default variant (`Variant::Advanced`).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let rendered = self.build(None);
        f.write_str(&rendered)
    }
}
impl Url {
pub fn from(email_address: impl AsRef<str>) -> Result<Self> {
let email = EmailAddress::from(email_address)?;
let local_encoded = encode_local_part(email.local_part.to_lowercase());
let url = Url {
domain: email.domain,
local_encoded,
local_part: email.local_part,
};
Ok(url)
}
pub fn build<V>(&self, variant: V) -> String
where
V: Into<Option<Variant>>,
{
let variant = variant.into().unwrap_or_default();
if variant == Variant::Direct {
format!(
"https://{}/.well-known/openpgpkey/hu/{}?l={}",
self.domain, self.local_encoded, self.local_part
)
} else {
format!(
"https://openpgpkey.{}/.well-known/openpgpkey/{}/hu/{}\
?l={}",
self.domain, self.domain, self.local_encoded, self.local_part
)
}
}
pub fn to_uri<V>(&self, variant: V) -> Result<Uri>
where
V: Into<Option<Variant>>,
{
let url_string = self.build(variant);
let uri = url_string
.as_str()
.parse::<Uri>()
.map_err(|err| Error::ParseUriError(err.into(), url_string.clone()))?;
Ok(uri)
}
}
/// Hashes `local_part` with SHA-1 and returns the digest in z-base-32.
///
/// No case mapping happens here; `Url::from` lowercases the local part
/// before calling this.
fn encode_local_part<S: AsRef<str>>(local_part: S) -> String {
    let digest = Sha1::digest(local_part.as_ref().as_bytes());
    zbase32::encode(&digest[..])
}
#[async_recursion]
async fn get_following_redirects(
client: &http::Client,
url: Uri,
depth: i32,
) -> Result<Response<Body>> {
let response = client.send(move |agent| agent.get(url).call()).await;
if depth < 0 {
return Err(Error::RedirectOverflowError);
}
if let Ok(ref resp) = response {
if resp.status().is_redirection() {
let url = resp
.headers()
.get("Location")
.and_then(|value| value.to_str().ok())
.map(|value| value.parse::<Uri>());
if let Some(Ok(url)) = url {
return get_following_redirects(client, url, depth - 1).await;
}
}
}
Ok(response?)
}
async fn get(client: &http::Client, email: &String) -> Result<SignedPublicKey> {
let wkd_url = Url::from(email)?;
let uri = wkd_url.to_uri(Variant::Advanced)?;
const REDIRECT_LIMIT: i32 = 10;
let res = match get_following_redirects(client, uri.clone(), REDIRECT_LIMIT).await {
Ok(res) => Ok(res),
Err(_) => {
let uri = wkd_url.to_uri(Variant::Direct)?;
get_following_redirects(client, uri.clone(), REDIRECT_LIMIT).await
}
}?;
let status = res.status();
let mut body = res.into_body();
let mut body = body.as_reader();
if !status.is_success() {
let mut err = String::new();
body.read_to_string(&mut err)
.map_err(|err| Error::ReadHttpError(err, uri.clone(), status))?;
return Err(Error::GetPublicKeyError(err, uri, status));
}
let pkey = SignedPublicKey::from_bytes(body).map_err(Error::ParseCertError)?;
Ok(pkey)
}
/// Looks up the WKD public key for a single `email` address.
///
/// A new `http::Client` is created just for this lookup.
pub async fn get_one(email: String) -> Result<SignedPublicKey> {
    self::get(&http::Client::new(), &email).await
}
pub async fn get_all(emails: Vec<String>) -> Vec<(String, Result<SignedPublicKey>)> {
let client = http::Client::new();
FuturesUnordered::from_iter(emails.into_iter().map(|email| {
let client = client.clone();
spawn(async move { (email.clone(), self::get(&client, &email).await) })
}))
.filter_map(|res| async {
match res {
Ok(res) => {
return Some(res);
}
Err(err) => {
debug!(?err, "skipping failed task");
None
}
}
})
.collect()
.await
}