use std::io;
use domain::{
base::{
Name, Question, RelativeName, Rtype, ToName, ToRelativeName, name::LongChainError,
wire::ParseError,
},
rdata::Txt,
resolv::{StubResolver, lookup::srv::SrvError},
};
use http::{
Response, Uri,
uri::{InvalidUri, PathAndQuery, Scheme},
};
use hyper::body::Incoming;
use log::{info, warn};
use tower_service::Service;
use crate::{
CheckSupportError,
common::ServiceForUrlError,
dav::{CheckSupport, WebDavClient},
};
#[derive(thiserror::Error, Debug)]
pub enum ContextUrlError {
#[error("missing host in input URL")]
MissingHost,
#[error("host in input URL is not a valid domain: {0}")]
InvalidDomain(domain::base::name::FromStrError),
#[error("resolving DNS SRV records: {0}")]
DnsError(SrvError),
#[error("resolving context path via TXT records: {0}")]
TxtError(TxtError),
#[error("the service is decidedly not available")]
NotAvailable,
#[error("SRV records returned domain/port pair that could not be parsed: {0}")]
UnusableSrv(http::Error),
}
#[derive(thiserror::Error, Debug)]
pub enum BootstrapError {
#[error("cannot determine service for this url: {0}")]
ServiceForUrl(#[from] ServiceForUrlError),
#[error("discovering context url: {0}")]
ContextUrl(#[from] ContextUrlError),
#[error("no usable URL found for service")]
NoUsableUrl,
}
/// Outcome of [`find_context_url`].
///
/// Derives `Debug` like every other public type in this module, so callers can log or
/// `dbg!` a discovery result.
#[derive(Debug)]
pub enum FindContextUrlResult {
    /// The client's base URL itself reports support for the service.
    BaseUrl,
    /// A context URL was discovered via DNS (SRV, plus TXT or a per-host lookup).
    Found(Uri),
    /// Discovery ran to completion, but no candidate produced a URL.
    NoneFound,
    /// Discovery was aborted by an error.
    Error(ContextUrlError),
}
pub async fn find_context_url<C>(
client: &WebDavClient<C>,
service: DiscoverableService,
) -> FindContextUrlResult
where
C: Service<http::Request<String>, Response = Response<Incoming>> + Sync + Send,
<C as Service<http::Request<String>>>::Error: std::error::Error + Send + Sync,
{
match client
.request(CheckSupport::new(&client.base_url, service.access_field()))
.await
{
Ok(()) => return FindContextUrlResult::BaseUrl,
Err(err) => info!("Original URL does not report {service} capabilities: {err}"),
}
let Some(domain) = client.base_url.host() else {
return FindContextUrlResult::Error(ContextUrlError::MissingHost);
};
let port = client.base_url.port_u16().unwrap_or(service.default_port());
let dname = match Name::bytes_from_str(domain) {
Ok(d) => d,
Err(err) => return FindContextUrlResult::Error(ContextUrlError::InvalidDomain(err)),
};
let host_candidates = match resolve_srv_record(service, &dname, port).await {
Ok(Some(hc)) => hc,
Ok(None) => return FindContextUrlResult::Error(ContextUrlError::NotAvailable),
Err(err) => return FindContextUrlResult::Error(ContextUrlError::DnsError(err)),
};
let txt_record = match find_context_path_via_txt_records(service, &dname).await {
Ok(record) => record,
Err(err) => return FindContextUrlResult::Error(ContextUrlError::TxtError(err)),
};
for candidate in &host_candidates {
if let Some(ref path) = txt_record {
let test_uri = match Uri::builder()
.scheme(service.scheme())
.authority(format!("{}:{}", candidate.0, candidate.1))
.path_and_query(path.clone())
.build()
{
Ok(uri) => uri,
Err(err) => return FindContextUrlResult::Error(ContextUrlError::UnusableSrv(err)),
};
let result = client
.request(CheckSupport::new(&test_uri, service.access_field()))
.await;
match result {
Ok(()) => return FindContextUrlResult::Found(test_uri),
Err(CheckSupportError::NotAdvertised) => {
return FindContextUrlResult::Found(test_uri);
}
Err(_) => {
warn!("Found path that doesn't report {service} capabilities: {test_uri}");
}
}
} else if let Ok(Some(url)) = client
.find_context_path(service, &candidate.0, candidate.1)
.await
{
return FindContextUrlResult::Found(url);
}
}
FindContextUrlResult::NoneFound
}
/// A DAV service that can be located through DNS-based service discovery.
///
/// Variants ending in `s` are the TLS flavours of the service.
#[derive(Clone, Copy, Debug)]
pub enum DiscoverableService {
    /// CalDAV over HTTPS (`_caldavs._tcp`).
    CalDavs,
    /// CalDAV over plain HTTP (`_caldav._tcp`).
    CalDav,
    /// CardDAV over HTTPS (`_carddavs._tcp`).
    CardDavs,
    /// CardDAV over plain HTTP (`_carddav._tcp`).
    CardDav,
}
impl DiscoverableService {
    /// True for the TLS flavours (`CalDavs`, `CardDavs`).
    fn uses_tls(self) -> bool {
        matches!(self, Self::CalDavs | Self::CardDavs)
    }

    /// True for the CalDAV flavours (`CalDavs`, `CalDav`).
    fn is_calendar_service(self) -> bool {
        matches!(self, Self::CalDavs | Self::CalDav)
    }

    /// DNS owner-name prefix used for this service's SRV/TXT lookups, e.g. `_caldavs._tcp`.
    ///
    /// # Panics
    ///
    /// Only if one of the hard-coded wire-format prefixes below were malformed.
    #[must_use]
    #[allow(clippy::missing_panics_doc)] // the only panic path is on constant input
    pub fn relative_domain(self) -> &'static RelativeName<[u8]> {
        // Wire format: every label is preceded by its length byte.
        let wire: &'static [u8] = match self {
            Self::CalDavs => b"\x08_caldavs\x04_tcp",
            Self::CalDav => b"\x07_caldav\x04_tcp",
            Self::CardDavs => b"\x09_carddavs\x04_tcp",
            Self::CardDav => b"\x08_carddav\x04_tcp",
        };
        RelativeName::from_slice(wire).expect("well known relative prefix is valid")
    }

    /// URI scheme for this service: HTTPS for the TLS flavours, HTTP otherwise.
    #[must_use]
    pub fn scheme(self) -> Scheme {
        if self.uses_tls() {
            Scheme::HTTPS
        } else {
            Scheme::HTTP
        }
    }

    /// Well-known path for this service (`/.well-known/caldav` or `/.well-known/carddav`).
    #[must_use]
    pub fn well_known_path(self) -> &'static str {
        if self.is_calendar_service() {
            "/.well-known/caldav"
        } else {
            "/.well-known/carddav"
        }
    }

    /// Default TCP port: 443 for the TLS flavours, 80 otherwise.
    #[must_use]
    pub fn default_port(self) -> u16 {
        if self.uses_tls() {
            443
        } else {
            80
        }
    }

    /// Capability token checked via `CheckSupport` for this service
    /// (`calendar-access` for CalDAV, `addressbook` for CardDAV).
    #[must_use]
    pub fn access_field(self) -> &'static str {
        if self.is_calendar_service() {
            "calendar-access"
        } else {
            "addressbook"
        }
    }
}
impl std::fmt::Display for DiscoverableService {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DiscoverableService::CalDavs => write!(f, "caldavs"),
DiscoverableService::CalDav => write!(f, "caldav"),
DiscoverableService::CardDavs => write!(f, "carddavs"),
DiscoverableService::CardDav => write!(f, "carddav"),
}
}
}
/// Resolves the SRV records for `service` under `domain`.
///
/// `fallback_port` is forwarded to `StubResolver::lookup_srv` unchanged.
///
/// Returns `Ok(None)` when the resolver reports no result. Otherwise returns the list of
/// `(target host, port)` pairs; hosts are rendered with their `Display` implementation.
///
/// # Errors
///
/// Propagates any [`SrvError`] from the lookup.
pub async fn resolve_srv_record(
    service: DiscoverableService,
    domain: &impl ToName,
    fallback_port: u16,
) -> Result<Option<Vec<(String, u16)>>, SrvError> {
    let resolver = StubResolver::new();
    let lookup = resolver
        .lookup_srv(service.relative_domain(), domain, fallback_port)
        .await?;
    let Some(found) = lookup else {
        return Ok(None);
    };
    let hosts: Vec<(String, u16)> = found
        .into_srvs()
        .map(|srv| {
            let host = srv.target().to_string();
            (host, srv.port())
        })
        .collect();
    Ok(Some(hosts))
}
#[derive(thiserror::Error, Debug)]
pub enum TxtError {
#[error("I/O error performing DNS request: {0}")]
Network(#[from] io::Error),
#[error("domain name is too long and cannot be queried: {0}")]
DomainTooLong(#[from] LongChainError),
#[error("parsing DNS response: {0}")]
ParseError(#[from] ParseError),
#[error("invalid data in response: {0}")]
InvalidData(#[from] InvalidUri),
#[error("missing expected prefix path= from TXT record.")]
BadTxt,
}
/// Looks up the context path for `service` on `domain` via DNS TXT records.
///
/// Queries the TXT record at `<service prefix>.<domain>` (e.g. `_caldavs._tcp.example.com`)
/// and returns the value of the first `path=` entry found.
///
/// Non-TXT records in the answer section (such as a CNAME chain preceding the TXT record)
/// are skipped rather than treated as "no record".
///
/// Returns `Ok(None)` if the answer section contains no TXT record at all.
///
/// # Errors
///
/// - [`TxtError::DomainTooLong`] if the prefixed name exceeds the DNS length limit.
/// - [`TxtError::Network`] if the DNS query fails.
/// - [`TxtError::ParseError`] if the response cannot be parsed.
/// - [`TxtError::BadTxt`] if TXT records exist but none starts with `path=`.
/// - [`TxtError::InvalidData`] if the `path=` value is not a valid path.
pub async fn find_context_path_via_txt_records(
    service: DiscoverableService,
    domain: impl ToName,
) -> Result<Option<PathAndQuery>, TxtError> {
    let resolver = StubResolver::new();
    let full_domain = service.relative_domain().chain(domain)?;
    let question = Question::new_in(full_domain, Rtype::TXT);
    let response = resolver.query(question).await?;

    let mut saw_txt = false;
    for record in response.answer()? {
        // `into_record` yields `None` for other record types (e.g. CNAME); keep looking
        // instead of giving up on the first non-TXT record.
        let Some(txt) = record?.into_record::<Txt<_>>()? else {
            continue;
        };
        saw_txt = true;
        // NOTE(review): `text()` joins all character-strings of the record; a record carrying
        // several key=value strings would need per-string parsing.
        let text = txt.data().text::<Vec<u8>>();
        if let Some(path) = text.strip_prefix(b"path=") {
            return Ok(Some(PathAndQuery::try_from(path)?));
        }
    }

    if saw_txt {
        Err(TxtError::BadTxt)
    } else {
        Ok(None)
    }
}