use std::io;
use thiserror::Error;
use url::Url;
use crate::{
coroutine::{DiscoveryCoroutine, DiscoveryCoroutineState, DiscoveryYield},
rfc6764::{
discover::{DiscoveryRfc6764, DiscoveryRfc6764Error},
types::Rfc6764Report,
},
shared::pool::{Stream, StreamPool},
};
/// Capacity, in bytes, of the buffer `discover` reads stream data into; a single
/// read therefore hands the coroutine at most this many bytes (8 KiB).
const READ_BUFFER_SIZE: usize = 8192;
#[derive(Debug, Error)]
pub enum DiscoveryRfc6764ClientStdError {
#[error(transparent)]
Discovery(#[from] DiscoveryRfc6764Error),
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Pool(#[from] anyhow::Error),
}
/// Blocking, `std`-based client that runs RFC 6764 discovery.
///
/// It drives the `DiscoveryRfc6764` coroutine and performs the reads and
/// writes the coroutine asks for on streams taken from a [`StreamPool`].
pub struct DiscoveryRfc6764ClientStd {
    /// URL passed as the `dns` argument to `DiscoveryRfc6764::new` on every
    /// `discover` call (presumably the DNS resolver endpoint — confirm).
    dns: Url,
    /// Pool that hands out the streams used for the coroutine's I/O.
    pool: StreamPool,
}
impl DiscoveryRfc6764ClientStd {
pub fn new(dns: Url) -> Self {
Self {
dns,
pool: StreamPool::new(),
}
}
pub fn with_factory<F, S>(mut self, scheme: &'static str, factory: F) -> Self
where
F: FnMut(&Url) -> anyhow::Result<S> + 'static,
S: Stream + 'static,
{
self.pool = self.pool.with_factory(scheme, factory);
self
}
pub fn discover(
&mut self,
domain: &str,
) -> Result<Rfc6764Report, DiscoveryRfc6764ClientStdError> {
let mut coroutine = DiscoveryRfc6764::new(domain, self.dns.clone());
let mut buf = [0u8; READ_BUFFER_SIZE];
let mut arg: Option<&[u8]> = None;
loop {
match coroutine.resume(arg.take()) {
DiscoveryCoroutineState::Complete(Ok(report)) => return Ok(report),
DiscoveryCoroutineState::Complete(Err(err)) => return Err(err.into()),
DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsRead { url }) => {
let stream = self.pool.get(&url)?;
let n = stream.read(&mut buf)?;
arg = Some(&buf[..n]);
}
DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsWrite { url, bytes }) => {
let stream = self.pool.get(&url)?;
stream.write_all(&bytes)?;
}
}
}
}
}