use std::io;
use thiserror::Error;
use url::Url;
use crate::{
coroutine::{DiscoveryCoroutine, DiscoveryCoroutineState, DiscoveryYield},
pacc::{
discover::{DiscoveryPacc, DiscoveryPaccError},
types::PaccConfig,
},
shared::pool::{Stream, StreamPool},
};
/// Size in bytes (8 KiB) of the stack buffer that each single stream read fills.
const READ_BUFFER_SIZE: usize = 8192;
/// Every failure that [`DiscoveryPaccClientStd`] can report.
///
/// All variants are `#[error(transparent)]`: their `Display` and `source`
/// forward directly to the wrapped error. Each variant also has a `From` impl,
/// so `?` converts the wrapped error types automatically.
#[derive(Debug, Error)]
pub enum DiscoveryPaccClientStdError {
    /// Failure coming from the PACC discovery coroutine
    /// (construction or final result).
    #[error(transparent)]
    Discovery(#[from] DiscoveryPaccError),

    /// Standard I/O failure while talking to a stream.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// Any `anyhow::Error` converted through `?`. The name suggests it is
    /// meant for stream-pool failures (NOTE(review): confirm the pool is the
    /// only producer).
    #[error(transparent)]
    Pool(#[from] anyhow::Error),
}
pub struct DiscoveryPaccClientStd {
dns: Url,
pool: StreamPool,
}
impl DiscoveryPaccClientStd {
pub fn new(dns: Url) -> Self {
Self {
dns,
pool: StreamPool::new(),
}
}
pub fn with_factory<F, S>(mut self, scheme: &'static str, factory: F) -> Self
where
F: FnMut(&Url) -> anyhow::Result<S> + 'static,
S: Stream + 'static,
{
self.pool = self.pool.with_factory(scheme, factory);
self
}
#[cfg(feature = "stream")]
pub fn with_tls(mut self, tls: pimalaya_stream::tls::Tls) -> Self {
self.pool = self.pool.with_http_factories(tls);
self
}
pub fn discover(&mut self, domain: &str) -> Result<PaccConfig, DiscoveryPaccClientStdError> {
let mut coroutine = DiscoveryPacc::new(domain, self.dns.clone())?;
let mut buf = [0u8; READ_BUFFER_SIZE];
let mut arg: Option<&[u8]> = None;
loop {
match coroutine.resume(arg.take()) {
DiscoveryCoroutineState::Complete(Ok(config)) => return Ok(config),
DiscoveryCoroutineState::Complete(Err(err)) => return Err(err.into()),
DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsRead { url }) => {
let stream = self.pool.get(&url)?;
let n = stream.read(&mut buf)?;
arg = Some(&buf[..n]);
}
DiscoveryCoroutineState::Yielded(DiscoveryYield::WantsWrite { url, bytes }) => {
let stream = self.pool.get(&url)?;
stream.write_all(&bytes)?;
}
}
}
}
}