use std::fmt::Debug;
use std::time::Duration;
use async_trait::async_trait;
use curl::easy::{Easy2, Handler};
use curl::multi::{Multi, Socket, WaitFd};
use log::trace;
use std::collections::HashMap;
use std::sync::Mutex;
use tokio::runtime::{Builder, Runtime};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::oneshot;
use tokio::task::LocalSet;
use crate::error::Error;
/// Asynchronous interface for executing a configured curl [`Easy2`] handle.
///
/// `H` is the curl [`Handler`] attached to the handle; it must be `Debug`,
/// `Send` and `'static`.
#[async_trait]
pub trait Actor<H>
where
H: Handler + Debug + Send + 'static,
{
/// Takes ownership of `easy2`, runs it, and resolves to an `Easy2<H>` on
/// success or an `Error<H>` on failure.
async fn send_request(&self, easy2: Easy2<H>) -> Result<Easy2<H>, Error<H>>;
}
use std::sync::Arc;
use std::thread::JoinHandle;
/// State held behind the `Arc` of a `CurlActor`: the sending half of the
/// request channel and the join handle of the background actor thread.
///
/// Both fields are `Option`s so that the `Drop` implementation can `take()`
/// them out of `&mut self`.
struct Inner<H>
where
H: Handler + Debug + Send + 'static,
{
// Sending half of the request channel; `Some` until `drop` takes it.
request_sender: Option<Sender<Request<H>>>,
// Handle of the thread started by `spawn_actor`; `Some` until `drop` joins it.
join_handle: Option<JoinHandle<()>>,
}
/// Shuts the background actor down when the last owner of `Inner` goes away.
///
/// The request sender is dropped first, then the background thread is joined.
/// Joining blocks the dropping thread until the background thread has
/// returned.
impl<H> Drop for Inner<H>
where
    H: Handler + Debug + Send + 'static,
{
    fn drop(&mut self) {
        if let Some(sender) = self.request_sender.take() {
            trace!("Dropping request sender to signal background actor to shut down.");
            drop(sender);
            trace!("Request sender dropped, signaling background actor to shut down.");
        }
        if let Some(handle) = self.join_handle.take() {
            trace!("Attempting to join background actor thread for graceful shutdown...");
            // `join` returns `Err` when the thread panicked. Report that case
            // instead of claiming success; panicking here (inside `drop`)
            // would risk an abort, so the failure is only logged.
            match handle.join() {
                Ok(()) => trace!("Background actor thread joined successfully."),
                Err(payload) => {
                    let reason = payload
                        .downcast_ref::<&str>()
                        .copied()
                        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
                        .unwrap_or("<non-string panic payload>");
                    trace!("Background actor thread panicked: {}", reason);
                }
            }
        }
    }
}
/// Handle to a background curl actor.
///
/// All clones share the same `Inner` through an `Arc`, so they feed the same
/// request channel and the same background thread.
pub struct CurlActor<H>
where
    H: Handler + Debug + Send + 'static,
{
    inner: Arc<Inner<H>>,
}

// `Clone` is implemented by hand rather than derived: `#[derive(Clone)]`
// would add an `H: Clone` bound, making the handle non-cloneable for handlers
// that are not `Clone`, although cloning only bumps the `Arc` refcount.
impl<H> Clone for CurlActor<H>
where
    H: Handler + Debug + Send + 'static,
{
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<H> Default for CurlActor<H>
where
H: Handler + Debug + Send + 'static,
{
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl<H> Actor<H> for CurlActor<H>
where
    H: Handler + Debug + Send + 'static,
{
    /// Queues `easy2` on the background actor and waits for its outcome.
    ///
    /// A fresh oneshot channel is created per request: its sender travels to
    /// the actor inside a `Request`, and the receiver is awaited here.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) a failure to enqueue the request, a failure to
    /// receive the reply, and whatever error the actor sends back.
    ///
    /// # Panics
    ///
    /// Panics if the request sender has already been taken out of `Inner`.
    async fn send_request(&self, easy2: Easy2<H>) -> Result<Easy2<H>, Error<H>> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let queue = self
            .inner
            .request_sender
            .as_ref()
            .expect("request_sender missing");
        queue.send(Request(easy2, reply_tx)).await?;
        // The received value is already the `Result` produced by the actor.
        reply_rx.await?
    }
}
impl<H> CurlActor<H>
where
    H: Handler + Debug + Send + 'static,
{
    /// Request-queue capacity used by `new` and `new_runtime`.
    const DEFAULT_QUEUE_CAPACITY: usize = 1;

    /// Creates an actor backed by a freshly built current-thread tokio runtime
    /// (all drivers enabled) and a request queue of capacity 1.
    ///
    /// # Panics
    ///
    /// Panics if the runtime cannot be built.
    pub fn new() -> Self {
        let runtime = Builder::new_current_thread().enable_all().build().unwrap();
        Self::new_runtime_with_capacity(runtime, Self::DEFAULT_QUEUE_CAPACITY)
    }

    /// Creates an actor that drives its work on the supplied `runtime`, with a
    /// request queue of capacity 1.
    pub fn new_runtime(runtime: Runtime) -> Self {
        Self::new_runtime_with_capacity(runtime, Self::DEFAULT_QUEUE_CAPACITY)
    }

    /// Creates an actor that drives its work on the supplied `runtime`, with a
    /// request queue holding up to `capacity` pending requests.
    ///
    /// # Panics
    ///
    /// `capacity` is passed unchanged to `tokio::sync::mpsc::channel`, which
    /// is documented to panic when the capacity is 0.
    pub fn new_runtime_with_capacity(runtime: Runtime, capacity: usize) -> Self {
        let (request_sender, request_receiver) = mpsc::channel(capacity);
        let join_handle = Self::spawn_actor(runtime, request_receiver);
        let inner = Inner {
            request_sender: Some(request_sender),
            join_handle: Some(join_handle),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Starts the background thread that serves requests.
    ///
    /// The thread blocks on a `LocalSet` whose dispatcher task pulls requests
    /// off `request_receiver` and spawns one local task per request. Each task
    /// runs `perform_curl_multi` and sends the outcome back over the request's
    /// oneshot sender. The dispatcher loop ends once `recv` yields `None`.
    fn spawn_actor(runtime: Runtime, mut request_receiver: Receiver<Request<H>>) -> JoinHandle<()> {
        std::thread::spawn(move || {
            let dispatcher = async move {
                while let Some(Request(easy2, reply_to)) = request_receiver.recv().await {
                    tokio::task::spawn_local(async move {
                        let outcome = perform_curl_multi(easy2).await;
                        if let Err(undelivered) = reply_to.send(outcome) {
                            trace!(
                                "Warning! The receiver has been dropped. {:?}",
                                undelivered
                            );
                        }
                    });
                }
            };
            let local_set = LocalSet::new();
            local_set.spawn_local(dispatcher);
            runtime.block_on(local_set);
        })
    }
}
async fn perform_curl_multi<H: Handler + Debug + Send + 'static>(
easy2: Easy2<H>,
) -> Result<Easy2<H>, Error<H>> {
let mut multi = Multi::new();
let socket_map: std::sync::Arc<Mutex<HashMap<Socket, (bool, bool)>>> =
std::sync::Arc::new(Mutex::new(HashMap::new()));
{
let map = socket_map.clone();
multi
.socket_function(move |socket, events, _| match map.lock() {
Ok(mut m) => {
if events.remove() {
m.remove(&socket);
} else {
m.insert(socket, (events.input(), events.output()));
}
}
Err(poison) => {
trace!("socket_function: socket_map mutex poisoned, recovering");
let mut m = poison.into_inner();
if events.remove() {
m.remove(&socket);
} else {
m.insert(socket, (events.input(), events.output()));
}
}
})
.map_err(|e| Error::Multi(e))?;
}
let handle = multi.add2(easy2).map_err(|e| Error::Multi(e))?;
while multi.perform().map_err(|e| Error::Multi(e))? != 0 {
let timeout_result = multi
.get_timeout()
.map(|d| d.unwrap_or_else(|| Duration::from_secs(2)));
let timeout = match timeout_result {
Ok(duration) => duration,
Err(multi_error) => {
if !multi_error.is_call_perform() {
return Err(Error::Multi(multi_error));
}
Duration::ZERO
}
};
if !timeout.is_zero() {
trace!(
"perform_curl_multi: waiting for IO or timeout {:?}",
timeout
);
let sockets: Vec<(Socket, (bool, bool))> = match socket_map.lock() {
Ok(g) => g.iter().map(|(s, bo)| (*s, *bo)).collect(),
Err(poison) => {
trace!("perform_curl_multi: socket_map mutex poisoned, recovering");
let g = poison.into_inner();
g.iter().map(|(s, bo)| (*s, *bo)).collect()
}
};
let mut waitfds: Vec<WaitFd> = Vec::with_capacity(sockets.len());
for (fd, (inp, out)) in sockets.into_iter() {
let mut w = WaitFd::new();
w.set_fd(fd);
if inp {
w.poll_on_read(true);
}
if out {
w.poll_on_write(true);
}
waitfds.push(w);
}
let ready = multi
.wait(&mut waitfds, timeout)
.map_err(|e| Error::Multi(e))?;
trace!(
"perform_curl_multi: wait completed, {} fds ready (buffered {})",
ready,
waitfds.len()
);
}
}
let mut transfer_error: Option<Error<H>> = None;
multi.messages(|msg| {
if let Some(Err(e)) = msg.result() {
transfer_error = Some(Error::Curl(e));
}
});
let cleanup = multi.remove2(handle).map_err(|e| Error::Multi(e));
if let Some(e) = transfer_error {
if let Err(ref clean_err) = cleanup {
trace!(
"perform_curl_multi: remove2 failed during cleanup: {:?}",
clean_err
);
}
Err(e)
} else {
cleanup
}
}
/// Message sent from `CurlActor::send_request` to the background actor.
///
/// Field 0 is the `Easy2` handle to execute. Field 1 is the oneshot sender on
/// which the actor delivers the transfer's `Result`.
#[derive(Debug)]
pub struct Request<H: Handler + Debug + Send + 'static>(
Easy2<H>,
oneshot::Sender<Result<Easy2<H>, Error<H>>>,
);