use std::sync::Arc;
use std::time::Duration;
use futures::channel::mpsc as FuturesMpsc;
use futures::prelude::*;
use hyper::client::connect::{Connect, HttpConnector};
use hyper::{self, client::Client};
use hyper_tls::HttpsConnector;
use tokio::task::JoinHandle;
use tracing::{info, trace};
use crate::config::Config;
use crate::deliverable::Deliverable;
use crate::error::{RequestError, SpawnError};
use crate::pool::ConnectorAdaptor;
use crate::transaction::Transaction;
use raii_counter::{Counter, WeakCounter};
mod transaction_counter;
pub use self::transaction_counter::TransactionCounter;
/// State of the background task that receives transactions over a channel and
/// spawns a request for each one.
///
/// Built and started by `Executor::spawn`; the rest of the crate talks to it
/// through the `ExecutorHandle` that `spawn` returns.
pub(crate) struct Executor<D, C>
where
    D: Deliverable,
    C: 'static + Connect,
{
    /// Shared hyper client; an `Arc` clone is handed to every spawned transaction.
    client: Arc<Client<C>>,
    /// Weak view of the transaction counter, polled by `run` while it waits for
    /// outstanding transactions after the channel has closed.
    transaction_counter: WeakCounter,
    /// Timeout value passed along with each transaction when its request is spawned.
    transaction_timeout: Duration,
    /// Receiving end of the channel that `ExecutorHandle::send` writes to.
    receiver: FuturesMpsc::UnboundedReceiver<ExecutorMessage<D>>,
}
/// Owner-side handle to a spawned `Executor` task.
///
/// Used to submit transactions, query counters, and obtain the task's join handle.
pub(crate) struct ExecutorHandle<D>
where
    D: Deliverable,
{
    /// Weak counter from which a fresh `Counter` is spawned for every transaction
    /// sent; its count is compared against `max_transactions` in `is_full`.
    transaction_counter: WeakCounter,
    /// Counter owned by this handle; a downgraded clone of it is exposed through
    /// `transaction_counter()`.
    worker_counter: Counter,
    /// Transaction count at or above which `send` rejects with `RequestError::PoolFull`.
    max_transactions: usize,
    /// Sending end of the channel read by the executor task.
    sender: FuturesMpsc::UnboundedSender<ExecutorMessage<D>>,
    /// Join handle of the tokio task running the executor loop.
    join_handle: JoinHandle<()>,
}
/// Message sent from an `ExecutorHandle` to its `Executor`: the transaction to run,
/// paired with the `Counter` spawned for it when it was sent.
type ExecutorMessage<D> = (Transaction<D>, Counter);
impl<D: Deliverable> ExecutorHandle<D> {
    /// Queues `transaction` for execution on the executor task.
    ///
    /// # Errors
    ///
    /// * `RequestError::PoolFull` when the transaction count has already reached
    ///   `max_transactions`; the transaction is handed back untouched.
    /// * `RequestError::FailedSend` when the channel rejects the message; the
    ///   transaction is recovered from the send error and handed back.
    pub(crate) fn send(&mut self, transaction: Transaction<D>) -> Result<(), RequestError<D>> {
        if self.is_full() {
            return Err(RequestError::PoolFull(transaction));
        }

        // The counter travels with the transaction so it is counted for as long
        // as the message (and whatever later holds the counter) is alive.
        let counter = self.transaction_counter.spawn_upgrade();
        self.sender
            .unbounded_send((transaction, counter))
            .map_err(|send_error| {
                // Dropping the recovered counter here releases its count again.
                let (rejected, _counter) = send_error.into_inner();
                RequestError::FailedSend(rejected)
            })
    }

    /// Consumes the handle and returns the join handle of the executor task.
    ///
    /// All other fields, including the channel sender, are dropped when the
    /// handle is consumed.
    pub(crate) fn shutdown(self) -> JoinHandle<()> {
        let Self { join_handle, .. } = self;
        join_handle
    }

    /// Builds a `TransactionCounter` from a clone of the weak transaction counter
    /// and a downgraded clone of the worker counter.
    pub(crate) fn transaction_counter(&self) -> TransactionCounter {
        let transactions = WeakCounter::clone(&self.transaction_counter);
        let workers = Counter::clone(&self.worker_counter).downgrade();
        TransactionCounter::new(transactions, workers)
    }

    /// Returns `true` once the transaction count has reached `max_transactions`.
    fn is_full(&self) -> bool {
        self.max_transactions <= self.transaction_counter.count()
    }
}
impl<D: Deliverable, C: 'static + Connect + Clone + Send + Sync> Executor<D, C> {
    /// Builds a TLS-capable hyper client and starts the executor loop as a tokio task.
    ///
    /// * `config` supplies the keep-alive timeout (used both for the connector's
    ///   keep-alive setting and the client's idle-pool timeout), the per-transaction
    ///   timeout, and the maximum number of transactions for the returned handle.
    /// * `resolver` is the DNS resolver given to the underlying `HttpConnector`.
    /// * `A` wraps the assembled `HttpsConnector` into the connector type `C`
    ///   that the client is built with.
    ///
    /// Returns the `ExecutorHandle` used to feed transactions to the spawned task.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) a failure to construct the native TLS connector.
    pub fn spawn<A, R>(config: &Config, resolver: R) -> Result<ExecutorHandle<D>, SpawnError>
    where
        A: ConnectorAdaptor<R, Connect = C>,
    {
        let (tx, rx) = FuturesMpsc::unbounded();
        let weak_counter = WeakCounter::new();
        let keep_alive_timeout = config.keep_alive_timeout;

        let tls = tokio_native_tls::TlsConnector::from(native_tls::TlsConnector::new()?);

        let mut http = HttpConnector::new_with_resolver(resolver);
        // The HTTPS connector sits on top of this one, so non-`http` schemes
        // must be allowed through.
        http.enforce_http(false);
        http.set_nodelay(true);
        http.set_keepalive(Some(keep_alive_timeout));

        let connector = A::wrap(HttpsConnector::from((http, tls)));
        let client = Arc::new(
            Client::builder()
                .pool_idle_timeout(Some(keep_alive_timeout))
                .build(connector),
        );

        let executor = Self {
            receiver: rx,
            // The executor and the handle observe the same counter.
            transaction_counter: weak_counter.clone(),
            client,
            // `Duration` is `Copy`; no explicit clone is needed.
            transaction_timeout: config.transaction_timeout,
        };
        let join_handle = tokio::spawn(executor.run());

        Ok(ExecutorHandle {
            transaction_counter: weak_counter,
            worker_counter: Counter::new(),
            max_transactions: config.max_transactions_per_worker,
            sender: tx,
            join_handle,
        })
    }

    /// Executor loop: spawns a request for every transaction received, then waits
    /// for the transaction count to reach zero before exiting.
    async fn run(mut self) {
        // `next()` yields `None` once the channel is closed and drained.
        while let Some((transaction, counter)) = self.receiver.next().await {
            trace!("Executor: spawning transaction.");
            transaction.spawn_request(Arc::clone(&self.client), self.transaction_timeout, counter);
        }

        // No more transactions can arrive; poll until the outstanding ones have
        // released their counters.
        while self.transaction_counter.count() > 0 {
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        info!("Executor exited.");
    }
}