use std::io;
use thiserror::Error;
use url::Url;
use crate::{
coroutine::{DiscoveryCoroutine, DiscoveryCoroutineState, DiscoveryYield},
rfc6186::{
discover::{DiscoverySrv, DiscoverySrvError},
types::SrvReport,
},
shared::pool::{Stream, StreamPool},
};
/// Capacity, in bytes, of the buffer filled by a single stream read during discovery (8 KiB).
const READ_BUFFER_SIZE: usize = 8192;
#[derive(Debug, Error)]
pub enum DiscoverySrvClientStdError {
#[error(transparent)]
Discovery(#[from] DiscoverySrvError),
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Pool(#[from] anyhow::Error),
}
/// Client that performs RFC 6186 SRV discovery by driving a [`DiscoverySrv`]
/// coroutine over streams obtained from a [`StreamPool`].
pub struct DiscoverySrvClientStd {
    /// DNS endpoint handed to every discovery coroutine this client creates.
    dns: Url,
    /// Source of the streams the coroutine reads from and writes to, looked up
    /// by URL. It lives on the client, so it persists across `discover` calls.
    pool: StreamPool,
}
impl DiscoverySrvClientStd {
pub fn new(dns: Url) -> Self {
Self {
dns,
pool: StreamPool::new(),
}
}
pub fn with_factory<F, S>(mut self, scheme: &'static str, factory: F) -> Self
where
F: FnMut(&Url) -> anyhow::Result<S> + 'static,
S: Stream + 'static,
{
self.pool = self.pool.with_factory(scheme, factory);
self
}
pub fn discover(&mut self, domain: &str) -> Result<SrvReport, DiscoverySrvClientStdError> {
let mut coroutine = DiscoverySrv::new(domain, self.dns.clone());
let mut buf = [0u8; READ_BUFFER_SIZE];
let mut arg: Option<&[u8]> = None;
loop {
match coroutine.resume(arg.take()) {
DiscoveryCoroutineState::Complete(Ok(report)) => return Ok(report),
DiscoveryCoroutineState::Complete(Err(err)) => return Err(err.into()),
DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsRead { url }) => {
let stream = self.pool.get(&url)?;
let n = stream.read(&mut buf)?;
arg = Some(&buf[..n]);
}
DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsWrite { url, bytes }) => {
let stream = self.pool.get(&url)?;
stream.write_all(&bytes)?;
}
}
}
}
}