use core::mem;
use alloc::{
format, str,
string::{String, ToString},
vec::Vec,
};
use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
use log::trace;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use thiserror::Error;
use url::{ParseError, Url};
use crate::{
coroutine::{DiscoveryCoroutine, DiscoveryCoroutineState, DiscoveryYield},
pacc::types::PaccConfig,
shared::{
dns::{DiscoveryDnsTxt, DiscoveryDnsTxtError},
http::{HttpGet, HttpGetError},
},
};
#[derive(Debug, Error)]
pub enum DiscoveryPaccError {
#[error("PACC URL for domain `{1}` is not valid")]
InvalidUrl(#[source] ParseError, String),
#[error("no `_ua-auto-config` TXT record matched the configuration body")]
NoValidTxtRecord,
#[error("PACC body matched the published digest but is not valid JSON")]
Json(#[source] serde_json::Error),
#[error(transparent)]
Http(#[from] HttpGetError),
#[error(transparent)]
Dns(#[from] DiscoveryDnsTxtError),
}
/// Phase of a `DiscoveryPacc` run.
///
/// `resume` moves the current phase out with `mem::take`, which leaves the
/// default (`Done`) behind; a phase that must continue writes itself back.
enum State {
    /// Fetching the configuration document over HTTP.
    Get,
    /// Looking up the TXT records that authenticate the fetched document.
    Verify,
    /// Finished, successfully or not.
    Done,
}

impl Default for State {
    fn default() -> Self {
        Self::Done
    }
}
pub struct DiscoveryPacc {
state: State,
fetch: HttpGet,
verify: DiscoveryDnsTxt,
raw_body: Vec<u8>,
}
impl DiscoveryPacc {
pub fn url(domain: impl AsRef<str>) -> Result<Url, DiscoveryPaccError> {
let d = domain.as_ref().trim_matches('.');
let url = format!("https://ua-auto-config.{d}/.well-known/user-agent-configuration.json");
Url::parse(&url).map_err(|err| DiscoveryPaccError::InvalidUrl(err, d.to_string()))
}
pub fn new(domain: impl AsRef<str>, resolver: Url) -> Result<Self, DiscoveryPaccError> {
let url = Self::url(domain.as_ref())?;
let qname = format!("_ua-auto-config.{}", domain.as_ref().trim_matches('.'));
Ok(Self {
state: State::Get,
fetch: HttpGet::new(url),
verify: DiscoveryDnsTxt::new(qname, resolver),
raw_body: Vec::new(),
})
}
}
impl DiscoveryCoroutine for DiscoveryPacc {
type Yield = DiscoveryYield;
type Return = Result<PaccConfig, DiscoveryPaccError>;
fn resume(
&mut self,
mut arg: Option<&[u8]>,
) -> DiscoveryCoroutineState<Self::Yield, Self::Return> {
loop {
match mem::take(&mut self.state) {
State::Get => match self.fetch.resume(arg.take()) {
DiscoveryCoroutineState::Yielded(y) => {
self.state = State::Get;
return DiscoveryCoroutineState::Yielded(y);
}
DiscoveryCoroutineState::Complete(Ok(bytes)) => {
self.raw_body = bytes;
self.state = State::Verify;
}
DiscoveryCoroutineState::Complete(Err(err)) => {
return DiscoveryCoroutineState::Complete(Err(err.into()));
}
},
State::Verify => match self.verify.resume(arg.take()) {
DiscoveryCoroutineState::Yielded(y) => {
self.state = State::Verify;
return DiscoveryCoroutineState::Yielded(y);
}
DiscoveryCoroutineState::Complete(Err(err)) => {
return DiscoveryCoroutineState::Complete(Err(err.into()));
}
DiscoveryCoroutineState::Complete(Ok(records)) => {
for record in records {
let mut config = Vec::new();
for data in record.rdata.iter() {
config.extend_from_slice(&data.octets);
}
let Ok(config) = str::from_utf8(&config) else {
trace!("invalid UTF-8 TXT record, skip");
continue;
};
let mut v = None;
let mut a = None;
let mut d = None;
for tag in config.split(';') {
let Some((name, val)) = tag.split_once('=') else {
continue;
};
match name.trim() {
n if n.eq_ignore_ascii_case("v") => v = Some(val.trim()),
n if n.eq_ignore_ascii_case("a") => a = Some(val.trim()),
n if n.eq_ignore_ascii_case("d") => d = Some(val.trim()),
_ => continue,
}
}
let (Some(v), Some(a), Some(d)) = (v, a, d) else {
trace!("missing v, a or d in TXT record, skip");
continue;
};
if !v.eq_ignore_ascii_case("UAAC1") {
trace!("invalid `v`: expect `UAAC1` got `{v}`, skip");
continue;
}
if !a.eq_ignore_ascii_case("sha256") {
trace!("invalid `a`: expect `sha256` got `{a}`, skip");
continue;
}
let expected_digest = match BASE64.decode(d) {
Ok(digest) => {
trace!("expected digest: {digest:x?}");
digest
}
Err(err) => {
trace!("invalid base64 digest `{d}`, skip: {err}");
continue;
}
};
let actual_digest = Sha256::digest(&self.raw_body);
trace!("actual digest: {actual_digest:x?}");
if !bool::from(expected_digest.ct_eq(&actual_digest)) {
trace!("digest mismatch, skip");
continue;
}
return match serde_json::from_slice(&self.raw_body) {
Ok(config) => DiscoveryCoroutineState::Complete(Ok(config)),
Err(err) => DiscoveryCoroutineState::Complete(Err(
DiscoveryPaccError::Json(err),
)),
};
}
return DiscoveryCoroutineState::Complete(Err(
DiscoveryPaccError::NoValidTxtRecord,
));
}
},
State::Done => panic!("DiscoveryPacc::resume called after completion"),
}
}
}
}