use std::net::IpAddr;
use anyhow::Result;
use pkarr::dns::Name;
use pkarr::errors::PublishError;
use pkarr::{dns::rdata::SVCB, SignedPacket};
use crate::app_context::AppContext;
use tokio::task::JoinHandle;
use tokio::time::{interval, Duration};
/// Keeps the homeserver's pkarr packet published.
///
/// Owns the handle of the background tokio task that periodically republishes
/// the packet. The task is aborted by [`HomeserverKeyRepublisher::stop`] and
/// when this value is dropped.
pub struct HomeserverKeyRepublisher {
    /// Handle of the spawned republish task.
    join_handle: JoinHandle<()>,
}
impl HomeserverKeyRepublisher {
    /// Builds the homeserver's signed packet, publishes it once, and starts the
    /// background task that republishes it every hour.
    ///
    /// # Errors
    ///
    /// Fails if the signed packet cannot be created or if the initial publish
    /// fails.
    pub async fn start(
        context: &AppContext,
        icann_http_port: u16,
        pubky_tls_port: u16,
    ) -> Result<Self> {
        let packet = create_signed_packet(context, icann_http_port, pubky_tls_port)?;
        let client = context.pkarr_client.clone();
        let join_handle = Self::start_periodic_republish(client, &packet).await?;
        Ok(Self { join_handle })
    }

    /// Publishes `signed_packet` a single time and logs the outcome.
    ///
    /// The publish result is handed back to the caller unchanged.
    async fn publish_once(
        client: &pkarr::Client,
        signed_packet: &SignedPacket,
    ) -> Result<(), PublishError> {
        let outcome = client.publish(signed_packet, None).await;
        match &outcome {
            Ok(_) => {
                tracing::info!("Published the homeserver's pkarr packet to the DHT.");
            }
            Err(e) => {
                tracing::warn!(
                    "Failed to publish the homeserver's pkarr packet to the DHT: {}",
                    e
                );
            }
        }
        outcome
    }

    /// Publishes the packet immediately, then spawns a task that republishes
    /// it once per hour.
    ///
    /// # Errors
    ///
    /// Returns an error (and spawns nothing) if the initial publish fails.
    /// Failures of later republish attempts are only logged.
    async fn start_periodic_republish(
        client: pkarr::Client,
        signed_packet: &SignedPacket,
    ) -> anyhow::Result<JoinHandle<()>> {
        Self::publish_once(&client, signed_packet).await?;

        // The spawned task needs its own copy of the packet.
        let packet = signed_packet.clone();
        let task = tokio::spawn(async move {
            let mut ticker = interval(Duration::from_secs(60 * 60));
            // A tokio interval's first tick completes right away; consume it so
            // the loop below only wakes after a full period has elapsed.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                // `publish_once` already logs failures; keep the loop running.
                let _ = Self::publish_once(&client, &packet).await;
            }
        });

        Ok(task)
    }

    /// Aborts the background republish task.
    pub fn stop(&self) {
        self.join_handle.abort();
    }
}
/// Stops the republisher when it goes out of scope.
impl Drop for HomeserverKeyRepublisher {
    /// Calls [`HomeserverKeyRepublisher::stop`], which aborts the background task.
    fn drop(&mut self) {
        self.stop();
    }
}
/// Builds the signed pkarr packet that advertises this homeserver.
///
/// All records are attached to the root name `.` with a TTL of `60 * 60`:
///
/// * an HTTPS record (`SVCB::new(1, ..)`) carrying the public Pubky TLS port
///   and an IPv4 or IPv6 hint derived from the configured public IP;
/// * if `pkdns.icann_domain` is configured, a second HTTPS record
///   (`SVCB::new(10, ..)`) whose target is that domain; the `HTTP_PORT`
///   parameter is attached only when the domain is exactly `localhost`;
/// * an address record for the configured public IP.
///
/// The publicly advertised ports come from the `pkdns` configuration and fall
/// back to `local_pubky_tls_port` / `local_icann_http_port` when unset.
///
/// # Errors
///
/// Returns an error if setting an SVCB hint or parameter fails, if the ICANN
/// domain cannot be converted into a DNS name, or if building the packet with
/// `context.keypair` fails.
pub fn create_signed_packet(
    context: &AppContext,
    local_icann_http_port: u16,
    local_pubky_tls_port: u16,
) -> Result<SignedPacket> {
    let pkdns = &context.config_toml.pkdns;

    let root: Name = "."
        .try_into()
        .expect(". is the root domain and always valid");
    // Same TTL for every record in the packet.
    let ttl = 60 * 60;

    let public_ip = pkdns.public_ip;
    let tls_port = pkdns.public_pubky_tls_port.unwrap_or(local_pubky_tls_port);
    let http_port = pkdns
        .public_icann_http_port
        .unwrap_or(local_icann_http_port);

    // HTTPS record for the Pubky TLS endpoint, with an IP hint for the public address.
    let mut pubky_svcb = SVCB::new(1, root.clone());
    pubky_svcb.set_port(tls_port);
    match public_ip {
        IpAddr::V4(v4) => pubky_svcb.set_ipv4hint([v4.to_bits()])?,
        IpAddr::V6(v6) => pubky_svcb.set_ipv6hint([v6.to_bits()])?,
    }
    let mut builder = SignedPacket::builder().https(root.clone(), pubky_svcb, ttl);

    // Optional HTTPS record targeting the ICANN domain.
    if let Some(domain) = &pkdns.icann_domain {
        let port_bytes = http_port.to_be_bytes();
        let mut icann_svcb = SVCB::new(10, root.clone());
        // The HTTP port parameter is only included for `localhost`.
        if domain.0 == "localhost" {
            icann_svcb.set_param(
                pubky_common::constants::reserved_param_keys::HTTP_PORT,
                &port_bytes,
            )?;
        }
        icann_svcb.target = domain.0.as_str().try_into()?;
        builder = builder.https(root.clone(), icann_svcb, ttl);
    }

    // Address record for the public IP, then build with the homeserver keypair.
    let packet = builder
        .address(root.clone(), public_ip, ttl)
        .build(&context.keypair)?;
    Ok(packet)
}
#[cfg(test)]
mod tests {
    use futures_lite::StreamExt;
    use pkarr::extra::endpoints::Endpoint;
    use std::net::{Ipv4Addr, SocketAddr};

    use super::*;

    /// After starting the republisher, the homeserver key must resolve through
    /// the context's pkarr client to an HTTPS endpoint at `127.0.0.1:8080`.
    #[tokio::test]
    async fn test_resolve_https_endpoint_with_pkarr_client() {
        let context = AppContext::test();
        let _republisher = HomeserverKeyRepublisher::start(&context, 8080, 8080)
            .await
            .unwrap();

        let client = context.pkarr_client.clone();
        let public_key = context.keypair.public_key();
        // The packet itself must be resolvable before looking up endpoints.
        let _packet = client.resolve(&public_key).await.unwrap();

        let qname = public_key.to_string();
        let endpoint = client
            .resolve_https_endpoint(qname.as_str())
            .await
            .unwrap();

        let first_addr = endpoint.to_socket_addrs().first().copied().unwrap();
        assert_eq!(
            first_addr,
            SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)
        );
    }

    /// With a fresh random keypair, the published packet must contain three
    /// resource records and resolve to two HTTPS endpoints.
    #[tokio::test]
    async fn test_endpoints() {
        let mut context = AppContext::test();
        context.keypair = pkarr::Keypair::random();
        let _republisher = HomeserverKeyRepublisher::start(&context, 8080, 8080)
            .await
            .unwrap();

        let public_key = context.keypair.public_key();
        // Resolve with an independently built client rather than the context's one.
        let client = pkarr::Client::builder().build().unwrap();
        let packet = client.resolve(&public_key).await.unwrap();

        let record_count = packet.all_resource_records().count();
        assert_eq!(record_count, 3);

        let endpoints = client
            .resolve_https_endpoints(&public_key.to_z32())
            .collect::<Vec<Endpoint>>()
            .await;
        assert_eq!(endpoints.len(), 2);
    }
}