use self::{selector::Selector, timer::Timer, util::IntHasher};
use crate::{error::Error, handler::RequestHandler, task::WakerExt};
use async_channel::{Receiver, Sender};
use crossbeam_utils::{atomic::AtomicCell, sync::WaitGroup};
use curl::multi::{Events, Multi, Socket, SocketEvents};
use futures_lite::future::block_on;
use slab::Slab;
use std::{
collections::HashMap,
hash::BuildHasherDefault,
io,
sync::{Arc, Mutex},
task::Waker,
thread,
time::{Duration, Instant},
};
mod selector;
mod timer;
mod util;
/// Counter handing out a unique ID to each spawned agent; the ID appears in the
/// agent thread's name and its tracing span.
static NEXT_AGENT_ID: AtomicCell<usize> = AtomicCell::new(0);
/// Upper bound on how long a single poll of the selector may block, even when
/// curl has requested no timeout or a longer one.
const WAIT_TIMEOUT: Duration = Duration::from_millis(1000);
/// The concrete curl easy handle type that the agent executes.
type EasyHandle = curl::easy::Easy2<RequestHandler>;
/// Configuration used to spawn a new agent thread.
///
/// Each limit is applied to the curl multi handle only when it is non-zero, so
/// the `Default` value of `0` leaves curl's own default in place.
#[derive(Debug, Default)]
pub(crate) struct AgentBuilder {
    /// Limit passed to `Multi::set_max_total_connections`.
    max_connections: usize,
    /// Limit passed to `Multi::set_max_host_connections`.
    max_connections_per_host: usize,
    /// Limit passed to `Multi::set_max_connects`.
    connection_cache_size: usize,
}
impl AgentBuilder {
pub(crate) fn max_connections(mut self, max: usize) -> Self {
self.max_connections = max;
self
}
pub(crate) fn max_connections_per_host(mut self, max: usize) -> Self {
self.max_connections_per_host = max;
self
}
pub(crate) fn connection_cache_size(mut self, size: usize) -> Self {
self.connection_cache_size = size;
self
}
pub(crate) fn spawn(&self) -> io::Result<Handle> {
let create_start = Instant::now();
curl::init();
let id = NEXT_AGENT_ID.fetch_add(1);
let selector = Selector::new()?;
let (message_tx, message_rx) = async_channel::unbounded();
let wait_group = WaitGroup::new();
let wait_group_thread = wait_group.clone();
let AgentBuilder {
max_connections,
max_connections_per_host,
connection_cache_size,
} = *self;
let agent_span = tracing::debug_span!("agent_thread", id);
agent_span.follows_from(tracing::Span::current());
let waker = selector.waker();
let message_tx_clone = message_tx.clone();
let thread_main = move || {
let _enter = agent_span.enter();
let mut multi = Multi::new();
if max_connections > 0 {
multi
.set_max_total_connections(max_connections)
.map_err(Error::from_any)?;
}
if max_connections_per_host > 0 {
multi
.set_max_host_connections(max_connections_per_host)
.map_err(Error::from_any)?;
}
if connection_cache_size > 0 {
multi
.set_max_connects(connection_cache_size)
.map_err(Error::from_any)?;
}
let agent = AgentContext::new(multi, selector, message_tx_clone, message_rx)?;
drop(wait_group_thread);
tracing::debug!("agent took {:?} to start up", create_start.elapsed());
let result = agent.run();
if let Err(error) = &result {
tracing::error!(?error, "agent shut down with error");
}
result
};
let handle = Handle {
message_tx,
waker,
join_handle: Mutex::new(Some(
thread::Builder::new()
.name(format!("isahc-agent-{}", id))
.spawn(thread_main)?,
)),
};
wait_group.wait();
Ok(handle)
}
}
/// Handle to a running agent thread, used to submit requests to it.
///
/// Dropping the handle asks the agent to close and then joins its thread.
#[derive(Debug)]
pub(crate) struct Handle {
    /// Sending half of the agent's message channel.
    message_tx: Sender<Message>,
    /// Waker obtained from the agent's selector; woken after each message is
    /// queued so the agent notices it.
    waker: Waker,
    /// The agent thread. Becomes `None` once it has been joined.
    join_handle: Mutex<Option<thread::JoinHandle<Result<(), Error>>>>,
}
/// State owned by the agent thread while it drives the curl multi handle.
struct AgentContext {
    /// The curl multi handle executing this agent's requests.
    multi: curl::multi::Multi,
    /// The agent's own sender; cloned into per-request wakers so that paused
    /// transfers can ask to be resumed.
    message_tx: Sender<Message>,
    /// Incoming messages from the `Handle` and from per-request wakers.
    message_rx: Receiver<Message>,
    /// Active requests, keyed by the token stored on each curl handle.
    requests: Slab<curl::multi::Easy2Handle<RequestHandler>>,
    /// Set once a `Close` message arrives or the channel disconnects.
    close_requested: bool,
    /// Waker of `selector`; the base waker that per-request wakers chain onto.
    waker: Waker,
    /// Waits for readiness on the sockets curl asked us to watch.
    selector: Selector,
    /// Timeout most recently requested by curl's timer callback.
    timer: Arc<Timer>,
    /// Socket interest changes reported by curl's socket callback that have
    /// not yet been applied to `selector`.
    socket_updates: Arc<Mutex<HashMap<Socket, SocketEvents, BuildHasherDefault<IntHasher>>>>,
}
/// Messages processed by the agent thread.
#[derive(Debug)]
enum Message {
    /// Ask the agent to shut down; its run loop exits before polling again.
    Close,
    /// Begin executing the given request.
    Execute(EasyHandle),
    /// Unpause reading for the request with the given token.
    UnpauseRead(usize),
    /// Unpause writing for the request with the given token.
    UnpauseWrite(usize),
}
/// Outcome of attempting to join the agent thread.
#[derive(Debug)]
enum JoinResult {
    /// The thread had already been joined by an earlier call.
    AlreadyJoined,
    /// The thread exited and returned `Ok(())`.
    Ok,
    /// The thread exited and returned this error.
    Err(Error),
    /// The thread panicked.
    Panic,
}
impl Handle {
    /// Hands a request to the agent thread for execution.
    ///
    /// # Panics
    ///
    /// Panics if the agent thread is no longer accepting messages.
    pub(crate) fn submit_request(&self, request: EasyHandle) -> Result<(), Error> {
        let message = Message::Execute(request);
        self.send_message(message)
    }

    /// Queues `message` for the agent and wakes it so the message is noticed.
    ///
    /// # Panics
    ///
    /// If the channel rejects the message, the agent thread is joined and this
    /// panics with a description of how that thread ended.
    fn send_message(&self, message: Message) -> Result<(), Error> {
        let outcome = self.message_tx.try_send(message);

        if outcome.is_ok() {
            self.waker.wake_by_ref();
            return Ok(());
        }

        match self.try_join() {
            JoinResult::Err(e) => panic!("agent thread terminated with error: {:?}", e),
            JoinResult::Panic => panic!("agent thread panicked"),
            JoinResult::Ok | JoinResult::AlreadyJoined => {
                panic!("agent thread terminated prematurely")
            }
        }
    }

    /// Joins the agent thread if no earlier call has done so.
    ///
    /// The mutex stays locked for the whole join, so a concurrent caller waits
    /// for it to finish and then sees `AlreadyJoined`.
    fn try_join(&self) -> JoinResult {
        let mut slot = self.join_handle.lock().unwrap();

        match slot.take().map(|join_handle| join_handle.join()) {
            None => JoinResult::AlreadyJoined,
            Some(Ok(Ok(()))) => JoinResult::Ok,
            Some(Ok(Err(e))) => JoinResult::Err(e),
            Some(Err(_)) => JoinResult::Panic,
        }
    }
}
impl Drop for Handle {
fn drop(&mut self) {
if self.send_message(Message::Close).is_err() {
tracing::error!("agent thread terminated prematurely");
}
match self.try_join() {
JoinResult::Ok => tracing::trace!("agent thread joined cleanly"),
JoinResult::Err(error) => tracing::error!(?error, "agent thread terminated with error"),
JoinResult::Panic => tracing::error!("agent thread panicked"),
_ => {}
}
}
}
impl AgentContext {
    /// Builds the agent state around an already-configured curl multi handle.
    ///
    /// Installs two curl callbacks:
    /// - the socket callback, which records socket interest changes in
    ///   `socket_updates` for `apply_socket_updates` to hand to the selector;
    /// - the timer callback, which starts or stops the shared `timer`.
    ///
    /// # Errors
    ///
    /// Returns an error if curl rejects either callback.
    fn new(
        mut multi: Multi,
        selector: Selector,
        message_tx: Sender<Message>,
        message_rx: Receiver<Message>,
    ) -> Result<Self, Error> {
        let timer = Arc::new(Timer::new());
        let socket_updates = Arc::new(Mutex::new(HashMap::with_hasher(Default::default())));
        let socket_updates_clone = socket_updates.clone();

        multi
            .socket_function(move |socket, events, _| {
                // `apply_socket_updates` makes no curl calls while holding this
                // lock, so contention here would indicate a bug.
                let mut socket_updates = socket_updates_clone
                    .try_lock()
                    .expect("unexpected socket lock contention");
                // A later update for the same socket replaces an earlier one.
                socket_updates.insert(socket, events);
            })
            .map_err(Error::from_any)?;

        multi
            .timer_function({
                let timer = timer.clone();
                move |timeout| match timeout {
                    Some(timeout) => {
                        timer.start(timeout);
                        true
                    }
                    None => {
                        timer.stop();
                        true
                    }
                }
            })
            .map_err(Error::from_any)?;

        Ok(Self {
            multi,
            message_tx,
            message_rx,
            requests: Slab::new(),
            close_requested: false,
            waker: selector.waker(),
            selector,
            timer,
            socket_updates,
        })
    }

    /// Adds a request to the multi handle and starts tracking it in `requests`.
    ///
    /// The request's handler is initialized with its slab token, its raw curl
    /// handle, and two wakers. When woken, each waker queues an `UnpauseRead`
    /// or `UnpauseWrite` message for this token and, on success, wakes the
    /// chained agent waker.
    #[tracing::instrument(level = "trace", skip(self))]
    fn begin_request(&mut self, mut request: EasyHandle) -> Result<(), Error> {
        let entry = self.requests.vacant_entry();
        let id = entry.key();
        let handle = request.raw();

        request.get_mut().init(
            id,
            handle,
            {
                let tx = self.message_tx.clone();
                self.waker
                    .chain(move |inner| match tx.try_send(Message::UnpauseRead(id)) {
                        Ok(()) => inner.wake_by_ref(),
                        Err(_) => {
                            tracing::warn!(id, "agent went away while resuming read for request")
                        }
                    })
            },
            {
                let tx = self.message_tx.clone();
                self.waker
                    .chain(move |inner| match tx.try_send(Message::UnpauseWrite(id)) {
                        Ok(()) => inner.wake_by_ref(),
                        Err(_) => {
                            tracing::warn!(id, "agent went away while resuming write for request")
                        }
                    })
            },
        );

        let mut handle = self.multi.add2(request).map_err(Error::from_any)?;
        // The token lets `run` map curl's completion messages back to the slab.
        handle.set_token(id).map_err(Error::from_any)?;
        entry.insert(handle);

        Ok(())
    }

    /// Removes a finished request from the multi handle and hands it its result.
    ///
    /// Tokens that are not currently tracked are logged and ignored, because
    /// `Slab::remove` would otherwise panic and take the agent thread, along
    /// with every other in-flight request, down with it.
    #[tracing::instrument(level = "trace", skip(self))]
    fn complete_request(
        &mut self,
        token: usize,
        result: Result<(), curl::Error>,
    ) -> Result<(), Error> {
        if !self.requests.contains(token) {
            tracing::warn!(id = token, "received completion for unknown request token");
            return Ok(());
        }

        let handle = self.requests.remove(token);
        let mut handle = self.multi.remove2(handle).map_err(Error::from_any)?;
        handle.get_mut().set_result(result.map_err(Error::from_any));

        Ok(())
    }

    /// Processes messages from the channel until none are pending or a close
    /// has been requested.
    ///
    /// With no active requests there is nothing else to wait on, so this blocks
    /// until a message arrives. Otherwise it only consumes messages that are
    /// already queued. A disconnected channel is treated as a close request.
    #[tracing::instrument(level = "trace", skip(self))]
    fn poll_messages(&mut self) -> Result<(), Error> {
        while !self.close_requested {
            if self.requests.is_empty() {
                match block_on(self.message_rx.recv()) {
                    Ok(message) => self.handle_message(message)?,
                    _ => {
                        tracing::warn!("agent handle disconnected without close message");
                        self.close_requested = true;
                        break;
                    }
                }
            } else {
                match self.message_rx.try_recv() {
                    Ok(message) => self.handle_message(message)?,
                    Err(async_channel::TryRecvError::Empty) => break,
                    Err(async_channel::TryRecvError::Closed) => {
                        tracing::warn!("agent handle disconnected without close message");
                        self.close_requested = true;
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    /// Applies a single message to the agent state.
    ///
    /// Unpause failures and unknown tokens are logged rather than treated as
    /// fatal.
    #[tracing::instrument(level = "trace", skip(self))]
    fn handle_message(&mut self, message: Message) -> Result<(), Error> {
        tracing::trace!("received message from agent handle");

        match message {
            Message::Close => self.close_requested = true,
            Message::Execute(request) => self.begin_request(request)?,
            Message::UnpauseRead(token) => {
                if let Some(request) = self.requests.get(token) {
                    if let Err(error) = request.unpause_read() {
                        tracing::debug!(id = token, ?error, "error unpausing read for request");
                    }
                } else {
                    tracing::warn!(
                        id = token,
                        "received unpause request for unknown request token",
                    );
                }
            }
            Message::UnpauseWrite(token) => {
                if let Some(request) = self.requests.get(token) {
                    if let Err(error) = request.unpause_write() {
                        tracing::debug!(id = token, ?error, "error unpausing write for request");
                    }
                } else {
                    tracing::warn!(
                        id = token,
                        "received unpause request for unknown request token",
                    );
                }
            }
        }

        Ok(())
    }

    /// Runs the agent's event loop until a close is requested.
    ///
    /// Each iteration handles queued messages, waits for socket activity or a
    /// timeout via `poll`, then completes every transfer curl reports as done.
    /// Any requests still tracked at shutdown are dropped.
    fn run(mut self) -> Result<(), Error> {
        // Reused across iterations. Completion messages are collected first
        // because `complete_request` needs `&mut self` while `multi.messages`
        // borrows the multi handle.
        let mut multi_messages = Vec::new();

        loop {
            self.poll_messages()?;

            if self.close_requested {
                break;
            }

            self.poll()?;

            self.multi.messages(|message| {
                if let Some(result) = message.result() {
                    if let Ok(token) = message.token() {
                        multi_messages.push((token, result));
                    }
                }
            });

            for (token, result) in multi_messages.drain(..) {
                self.complete_request(token, result)?;
            }
        }

        tracing::debug!("agent shutting down");
        self.requests.clear();

        Ok(())
    }

    /// Waits for socket activity or curl's timeout, then drives curl.
    ///
    /// The wait is capped at `WAIT_TIMEOUT`. Readiness events are forwarded to
    /// `Multi::action`, and an expired timer triggers `Multi::timeout`.
    fn poll(&mut self) -> Result<(), Error> {
        // Apply interest changes curl reported since the last wait, including
        // ones made outside this method (e.g. by `remove2` when a request
        // completes), so we never block on a stale socket set.
        self.apply_socket_updates()?;

        let timeout = self.timer.get_remaining(Instant::now());
        let poll_timeout = timeout.map(|t| t.min(WAIT_TIMEOUT)).unwrap_or(WAIT_TIMEOUT);

        if self.selector.poll(poll_timeout)? {
            for (socket, readable, writable) in self.selector.events() {
                tracing::trace!(socket, readable, writable, "socket event");

                let mut events = Events::new();
                events.input(readable);
                events.output(writable);

                self.multi
                    .action(socket, &events)
                    .map_err(Error::from_any)?;
            }
        }

        // Read the clock again: we may have just slept until the deadline, and
        // a timestamp from before the wait would postpone curl's timeout
        // handling by a whole extra loop iteration.
        if self.timer.is_expired(Instant::now()) {
            self.timer.stop();
            self.multi.timeout().map_err(Error::from_any)?;
        }

        Ok(())
    }

    /// Drains the pending socket interest changes and applies them to the
    /// selector: removed sockets are deregistered, and all others are
    /// (re)registered for the requested readability/writability.
    fn apply_socket_updates(&mut self) -> Result<(), Error> {
        let mut updates = self
            .socket_updates
            .try_lock()
            .expect("unexpected socket lock contention");

        for (socket, events) in updates.drain() {
            if events.remove() {
                self.selector.deregister(socket)?;
            } else {
                let readable = events.input() || events.input_and_output();
                let writable = events.output() || events.input_and_output();
                self.selector.register(socket, readable, writable)?;
            }
        }

        Ok(())
    }
}
impl Drop for AgentContext {
    /// Drops every outstanding request handle before the remaining fields.
    ///
    /// Struct fields drop in declaration order, and `multi` is declared ahead
    /// of `requests`. Emptying the slab here releases the easy handles first,
    /// presumably so they are detached while the multi handle is still alive.
    fn drop(&mut self) {
        self.requests.drain().for_each(drop);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time checks: `Handle` must be usable from multiple threads, and
    // `Message` must be sendable over the channel to the agent thread.
    static_assertions::assert_impl_all!(Handle: Send, Sync);
    static_assertions::assert_impl_all!(Message: Send);
}