use crate::handler::RequestHandler;
use crate::task::{UdpWaker, WakerExt};
use crate::Error;
use crossbeam_channel::{Receiver, Sender};
use curl::multi::WaitFd;
use futures_util::task::ArcWake;
use slab::Slab;
use std::net::UdpSocket;
use std::sync::Arc;
use std::task::Waker;
use std::thread;
use std::time::{Duration, Instant};
/// Name assigned to the background thread that drives all curl transfers.
const AGENT_THREAD_NAME: &str = "curl agent";

/// Longest time a single `Multi::wait` call in the agent loop may block
/// before the loop gets another chance to poll for messages.
const WAIT_TIMEOUT: Duration = Duration::from_millis(100);
type EasyHandle = curl::easy::Easy2<RequestHandler>;
type MultiMessage = (usize, Result<(), curl::Error>);
#[derive(Debug)]
pub(crate) struct Handle {
message_tx: Sender<Message>,
waker: Waker,
join_handle: Option<thread::JoinHandle<Result<(), Error>>>,
}
/// State owned by the agent thread while it drives curl.
struct AgentThread {
    // --- curl state ---
    /// Multi handle that all active transfers are registered with.
    multi: curl::multi::Multi,
    /// Active transfers, keyed by the id that is also used as their curl token.
    requests: Slab<curl::multi::Easy2Handle<RequestHandler>>,
    /// Channel used to buffer completion results reported by curl until the
    /// finished handles can be removed.
    multi_messages: (Sender<MultiMessage>, Receiver<MultiMessage>),

    // --- messaging ---
    /// Receiving end of the queue fed by `Handle` and by per-request wakers.
    message_rx: Receiver<Message>,
    /// Sender clone handed to per-request unpause wakers.
    message_tx: Sender<Message>,
    /// Socket polled alongside curl's descriptors so a wake-up can interrupt `Multi::wait`.
    wake_socket: UdpSocket,
    /// Waker that new per-request wakers are chained onto.
    waker: Waker,

    // --- lifecycle ---
    /// Set once the agent has been asked to (or must) shut down.
    close_requested: bool,
}
#[derive(Debug)]
enum Message {
Close,
Execute(EasyHandle),
UnpauseRead(usize),
UnpauseWrite(usize),
}
/// Spawns a new agent thread and returns a [`Handle`] for talking to it.
///
/// A UDP socket bound to an ephemeral loopback port is the agent's wake-up
/// channel. The returned handle's waker is connected to that socket, and the
/// agent polls it alongside curl's own file descriptors.
///
/// # Errors
///
/// Fails if the wake-up socket or waker cannot be set up, or if the thread
/// cannot be spawned.
pub(crate) fn new() -> Result<Handle, Error> {
    let create_start = Instant::now();

    // Loopback socket that the agent watches for wake-up notifications.
    let wake_socket = UdpSocket::bind("127.0.0.1:0")?;
    wake_socket.set_nonblocking(true)?;
    let wake_addr = wake_socket.local_addr()?;
    let waker = Arc::new(UdpWaker::connect(wake_addr)?).into_waker();
    log::debug!("agent waker listening on {}", wake_addr);

    let (message_tx, message_rx) = crossbeam_channel::unbounded();

    // The handle keeps its own copies; the originals move into the thread.
    let handle_tx = message_tx.clone();
    let handle_waker = waker.clone();

    let agent_thread = thread::Builder::new()
        .name(String::from(AGENT_THREAD_NAME))
        .spawn(move || {
            let agent = AgentThread {
                multi: curl::multi::Multi::new(),
                multi_messages: crossbeam_channel::unbounded(),
                message_tx,
                message_rx,
                wake_socket,
                requests: Slab::new(),
                close_requested: false,
                waker,
            };
            log::debug!("agent took {:?} to start up", create_start.elapsed());
            agent.run()
        })?;

    Ok(Handle {
        message_tx: handle_tx,
        waker: handle_waker,
        join_handle: Some(agent_thread),
    })
}
impl Handle {
pub(crate) fn submit_request(&self, request: EasyHandle) -> Result<(), Error> {
self.send_message(Message::Execute(request))
}
fn send_message(&self, message: Message) -> Result<(), Error> {
match self.message_tx.send(message) {
Ok(()) => {
self.waker.wake_by_ref();
Ok(())
}
Err(_) => {
log::error!("agent thread terminated prematurely");
Err(Error::Internal)
}
}
}
}
impl Drop for Handle {
    /// Asks the agent thread to shut down, then waits for it to exit.
    ///
    /// Any error returned by the agent, or a panic on the agent thread, is
    /// logged rather than propagated, because `drop` cannot fail.
    fn drop(&mut self) {
        // `send_message` already logs when the agent is unreachable. Logging
        // again here would only duplicate that message, so the result is
        // deliberately ignored. A failed send means the agent's receiver has
        // been dropped, so the `join` below should not have to wait on a live
        // event loop.
        let _ = self.send_message(Message::Close);

        if let Some(join_handle) = self.join_handle.take() {
            match join_handle.join() {
                Ok(Ok(())) => {}
                Ok(Err(e)) => log::error!("agent thread terminated with error: {}", e),
                Err(_) => log::error!("agent thread panicked"),
            }
        }
    }
}
impl AgentThread {
fn begin_request(&mut self, mut request: EasyHandle) -> Result<(), Error> {
let entry = self.requests.vacant_entry();
let id = entry.key();
request.get_mut().init(
id,
{
let tx = self.message_tx.clone();
self.waker
.chain(move |inner| match tx.send(Message::UnpauseRead(id)) {
Ok(()) => inner.wake_by_ref(),
Err(_) => log::warn!(
"agent went away while resuming read for request [id={}]",
id
),
})
},
{
let tx = self.message_tx.clone();
self.waker
.chain(move |inner| match tx.send(Message::UnpauseWrite(id)) {
Ok(()) => inner.wake_by_ref(),
Err(_) => log::warn!(
"agent went away while resuming write for request [id={}]",
id
),
})
},
);
let mut handle = self.multi.add2(request)?;
handle.set_token(id)?;
entry.insert(handle);
Ok(())
}
fn get_wait_fds(&self) -> [WaitFd; 1] {
let mut fd = WaitFd::new();
#[cfg(unix)]
{
use std::os::unix::io::AsRawFd;
fd.set_fd(self.wake_socket.as_raw_fd());
}
#[cfg(windows)]
{
use std::os::windows::io::AsRawSocket;
fd.set_fd(self.wake_socket.as_raw_socket());
}
fd.poll_on_read(true);
[fd]
}
fn poll_messages(&mut self) -> Result<(), Error> {
loop {
if !self.close_requested && self.requests.is_empty() {
match self.message_rx.recv() {
Ok(message) => self.handle_message(message)?,
_ => {
log::warn!("agent handle disconnected without close message");
self.close_requested = true;
break;
}
}
} else {
match self.message_rx.try_recv() {
Ok(message) => self.handle_message(message)?,
Err(crossbeam_channel::TryRecvError::Empty) => break,
Err(crossbeam_channel::TryRecvError::Disconnected) => {
log::warn!("agent handle disconnected without close message");
self.close_requested = true;
break;
}
}
}
}
Ok(())
}
fn handle_message(&mut self, message: Message) -> Result<(), Error> {
log::trace!("received message from agent handle: {:?}", message);
match message {
Message::Close => self.close_requested = true,
Message::Execute(request) => self.begin_request(request)?,
Message::UnpauseRead(token) => {
if let Some(request) = self.requests.get(token) {
request.unpause_read()?;
} else {
log::warn!(
"received unpause request for unknown request token: {}",
token
);
}
}
Message::UnpauseWrite(token) => {
if let Some(request) = self.requests.get(token) {
request.unpause_write()?;
} else {
log::warn!(
"received unpause request for unknown request token: {}",
token
);
}
}
}
Ok(())
}
fn dispatch(&mut self) -> Result<(), Error> {
self.multi.perform()?;
self.multi.messages(|message| {
if let Some(result) = message.result() {
if let Ok(token) = message.token() {
self.multi_messages.0.send((token, result)).unwrap();
}
}
});
loop {
match self.multi_messages.1.try_recv() {
Ok((token, result)) => {
let handle = self.requests.remove(token);
let mut handle = self.multi.remove2(handle)?;
handle.get_mut().on_result(result);
}
Err(crossbeam_channel::TryRecvError::Empty) => break,
Err(crossbeam_channel::TryRecvError::Disconnected) => panic!(),
}
}
Ok(())
}
fn waker_drain(&self) -> bool {
let mut woke = false;
while let Ok(_) = self.wake_socket.recv_from(&mut [0; 32]) {
woke = true;
}
woke
}
fn run(mut self) -> Result<(), Error> {
let mut wait_fds = self.get_wait_fds();
log::debug!("agent ready");
loop {
self.poll_messages()?;
self.dispatch()?;
if self.close_requested {
break;
}
self.multi.wait(&mut wait_fds, WAIT_TIMEOUT)?;
if self.waker_drain() {
log::trace!("woke up from waker");
}
}
log::debug!("agent shutting down");
self.requests.clear();
self.multi.close()?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when `T: Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only when `T: Sync`.
    fn assert_sync<T: Sync>() {}

    /// The handle must be shareable across threads, and messages must be
    /// transferable to the agent thread.
    #[test]
    fn traits() {
        assert_send::<Handle>();
        assert_sync::<Handle>();

        assert_send::<Message>();
    }
}