Skip to main content

async_curl/
actor.rs

1use std::fmt::Debug;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use curl::easy::{Easy2, Handler};
6use curl::multi::{Multi, Socket, WaitFd};
7use log::trace;
8use std::collections::HashMap;
9use std::sync::Mutex;
10use tokio::runtime::{Builder, Runtime};
11use tokio::sync::mpsc::{self, Receiver, Sender};
12use tokio::sync::oneshot;
13use tokio::task::LocalSet;
14
15use crate::error::Error;
16
17#[async_trait]
18pub trait Actor<H>
19where
20    H: Handler + Debug + Send + 'static,
21{
22    async fn send_request(&self, easy2: Easy2<H>) -> Result<Easy2<H>, Error<H>>;
23}
24
/// CurlActor is responsible for performing
/// the constructed Easy2 object in the background
/// so that it runs asynchronously.
28/// ```
29/// use async_curl::actor::{Actor, CurlActor};
30/// use curl::easy::{Easy2, Handler, WriteError};
31///
32/// #[derive(Debug, Clone, Default)]
33/// pub struct ResponseHandler {
34///     data: Vec<u8>,
35/// }
36///
37/// impl Handler for ResponseHandler {
38///     /// This will store the response from the server
39///     /// to the data vector.
40///     fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
41///         self.data.extend_from_slice(data);
42///         Ok(data.len())
43///     }
44/// }
45///
46/// impl ResponseHandler {
47///     /// Instantiation of the ResponseHandler
48///     /// and initialize the data vector.
49///     pub fn new() -> Self {
50///         Self::default()
51///     }
52///
///     /// This consumes the object and
54///     /// give the data to the caller
55///     pub fn get_data(self) -> Vec<u8> {
56///         self.data
57///     }
58/// }
59///
60/// # #[tokio::main(flavor = "current_thread")]
61/// # async fn main() -> Result<(), Box<dyn std::error::Error>>{
62/// let curl = CurlActor::new();
63/// let mut easy2 = Easy2::new(ResponseHandler::new());
64///
65/// easy2.url("https://www.rust-lang.org").unwrap();
66/// easy2.get(true).unwrap();
67///
68/// let response = curl.send_request(easy2).await.unwrap();
69/// eprintln!("{:?}", response.get_ref());
70///
71/// Ok(())
72/// # }
73/// ```
74///
/// Example of multiple requests executed
/// concurrently.
77///
78/// ```
79/// use async_curl::actor::{Actor, CurlActor};
80/// use curl::easy::{Easy2, Handler, WriteError};
81///
82/// #[derive(Debug, Clone, Default)]
83/// pub struct ResponseHandler {
84///     data: Vec<u8>,
85/// }
86///
87/// impl Handler for ResponseHandler {
88///     /// This will store the response from the server
89///     /// to the data vector.
90///     fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
91///         self.data.extend_from_slice(data);
92///         Ok(data.len())
93///     }
94/// }
95///
96/// impl ResponseHandler {
97///     /// Instantiation of the ResponseHandler
98///     /// and initialize the data vector.
99///     pub fn new() -> Self {
100///         Self::default()
101///     }
102///
///     /// This consumes the object and
104///     /// give the data to the caller
105///     pub fn get_data(self) -> Vec<u8> {
106///         self.data
107///     }
108/// }
109///
110/// # #[tokio::main(flavor = "current_thread")]
111/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
112/// let actor = CurlActor::new();
113/// let mut easy2 = Easy2::new(ResponseHandler::new());
114/// easy2.url("https://www.rust-lang.org").unwrap();
115/// easy2.get(true).unwrap();
116///
117/// let actor1 = actor.clone();
118/// let spawn1 = tokio::spawn(async move {
119///     let response = actor1.send_request(easy2).await;
120///     let mut response = response.unwrap();
121///
122///     // Response body
123///     eprintln!(
124///         "Task 1 : {}",
125///         String::from_utf8_lossy(&response.get_ref().to_owned().get_data())
126///     );
127///     // Response status code
128///     let status_code = response.response_code().unwrap();
129///     eprintln!("Task 1 : {}", status_code);
130/// });
131///
132/// let mut easy2 = Easy2::new(ResponseHandler::new());
133/// easy2.url("https://www.rust-lang.org").unwrap();
134/// easy2.get(true).unwrap();
135///
136/// let spawn2 = tokio::spawn(async move {
137///     let response = actor.send_request(easy2).await;
138///     let mut response = response.unwrap();
139///
140///     // Response body
141///     eprintln!(
142///         "Task 2 : {}",
143///         String::from_utf8_lossy(&response.get_ref().to_owned().get_data())
144///     );
145///     // Response status code
146///     let status_code = response.response_code().unwrap();
147///     eprintln!("Task 2 : {}", status_code);
148/// });
149/// let (_, _) = tokio::join!(spawn1, spawn2);
150///
151/// Ok(())
152/// # }
153/// ```
154///
155use std::sync::Arc;
156use std::thread::JoinHandle;
157
158struct Inner<H>
159where
160    H: Handler + Debug + Send + 'static,
161{
162    request_sender: Option<Sender<Request<H>>>,
163    join_handle: Option<JoinHandle<()>>,
164}
165
166impl<H> Drop for Inner<H>
167where
168    H: Handler + Debug + Send + 'static,
169{
170    fn drop(&mut self) {
171        // Take and drop the sender so the background actor sees channel closed.
172        if let Some(sender) = self.request_sender.take() {
173            trace!("Dropping request sender to signal background actor to shut down.");
174            drop(sender);
175            trace!("Request sender dropped, signaling background actor to shut down.");
176        }
177        // Join the background thread to ensure graceful shutdown.
178        if let Some(handle) = self.join_handle.take() {
179            trace!("Attempting to join background actor thread for graceful shutdown...");
180            let _ = handle.join();
181            trace!("Background actor thread joined successfully.");
182        }
183    }
184}
185
186#[derive(Clone)]
187pub struct CurlActor<H>
188where
189    H: Handler + Debug + Send + 'static,
190{
191    inner: Arc<Inner<H>>,
192}
193
194impl<H> Default for CurlActor<H>
195where
196    H: Handler + Debug + Send + 'static,
197{
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203#[async_trait]
204impl<H> Actor<H> for CurlActor<H>
205where
206    H: Handler + Debug + Send + 'static,
207{
208    /// This will send Easy2 into the background task that will perform
209    /// curl asynchronously, await the response in the oneshot receiver and
210    /// return Easy2 back to the caller.
211    async fn send_request(&self, easy2: Easy2<H>) -> Result<Easy2<H>, Error<H>> {
212        let (oneshot_sender, oneshot_receiver) = oneshot::channel::<Result<Easy2<H>, Error<H>>>();
213        self.inner
214            .request_sender
215            .as_ref()
216            .expect("request_sender missing")
217            .send(Request(easy2, oneshot_sender))
218            .await?;
219        oneshot_receiver.await?
220    }
221}
222
223impl<H> CurlActor<H>
224where
225    H: Handler + Debug + Send + 'static,
226{
227    /// This creates the new instance of CurlActor to handle Curl perform asynchronously using Curl Multi
228    /// in a background thread to avoid blocking of other tasks.
229    pub fn new() -> Self {
230        let runtime = Builder::new_current_thread().enable_all().build().unwrap();
231        let (request_sender, request_receiver) = mpsc::channel::<Request<H>>(1);
232
233        let handle = Self::spawn_actor(runtime, request_receiver);
234
235        Self {
236            inner: Arc::new(Inner {
237                request_sender: Some(request_sender),
238                join_handle: Some(handle),
239            }),
240        }
241    }
242
243    /// This creates the new instance of CurlActor to handle Curl perform asynchronously using Curl Multi
244    /// in a background thread to avoid blocking of other tasks. The user can provide a custom runtime
245    /// to use for the background task.
246    pub fn new_runtime(runtime: Runtime) -> Self {
247        let (request_sender, request_receiver) = mpsc::channel::<Request<H>>(1);
248
249        let handle = Self::spawn_actor(runtime, request_receiver);
250
251        Self {
252            inner: Arc::new(Inner {
253                request_sender: Some(request_sender),
254                join_handle: Some(handle),
255            }),
256        }
257    }
258
259    /// Create a new CurlActor with a user-provided runtime and configurable channel capacity.
260    pub fn new_runtime_with_capacity(runtime: Runtime, capacity: usize) -> Self {
261        let (request_sender, request_receiver) = mpsc::channel::<Request<H>>(capacity);
262
263        let handle = Self::spawn_actor(runtime, request_receiver);
264
265        Self {
266            inner: Arc::new(Inner {
267                request_sender: Some(request_sender),
268                join_handle: Some(handle),
269            }),
270        }
271    }
272
273    fn spawn_actor(runtime: Runtime, mut request_receiver: Receiver<Request<H>>) -> JoinHandle<()> {
274        std::thread::spawn(move || {
275            let local = LocalSet::new();
276            local.spawn_local(async move {
277                while let Some(Request(easy2, oneshot_sender)) = request_receiver.recv().await {
278                    tokio::task::spawn_local(async move {
279                        let response = perform_curl_multi(easy2).await;
280                        if let Err(res) = oneshot_sender.send(response) {
281                            trace!("Warning! The receiver has been dropped. {:?}", res);
282                        }
283                    });
284                }
285            });
286            runtime.block_on(local);
287        })
288    }
289}
290
291async fn perform_curl_multi<H: Handler + Debug + Send + 'static>(
292    easy2: Easy2<H>,
293) -> Result<Easy2<H>, Error<H>> {
294    let mut multi = Multi::new();
295
296    // Track sockets libcurl wants us to wait on. We populate this via
297    // `socket_function` and then construct `WaitFd` entries from it before
298    // calling `multi.wait`.
299    let socket_map: std::sync::Arc<Mutex<HashMap<Socket, (bool, bool)>>> =
300        std::sync::Arc::new(Mutex::new(HashMap::new()));
301
302    {
303        let map = socket_map.clone();
304        multi
305            .socket_function(move |socket, events, _| match map.lock() {
306                Ok(mut m) => {
307                    if events.remove() {
308                        m.remove(&socket);
309                    } else {
310                        m.insert(socket, (events.input(), events.output()));
311                    }
312                }
313                Err(poison) => {
314                    trace!("socket_function: socket_map mutex poisoned, recovering");
315                    let mut m = poison.into_inner();
316                    if events.remove() {
317                        m.remove(&socket);
318                    } else {
319                        m.insert(socket, (events.input(), events.output()));
320                    }
321                }
322            })
323            .map_err(|e| Error::Multi(e))?;
324    }
325
326    let handle = multi.add2(easy2).map_err(|e| Error::Multi(e))?;
327
328    while multi.perform().map_err(|e| Error::Multi(e))? != 0 {
329        let timeout_result = multi
330            .get_timeout()
331            .map(|d| d.unwrap_or_else(|| Duration::from_secs(2)));
332
333        let timeout = match timeout_result {
334            Ok(duration) => duration,
335            Err(multi_error) => {
336                if !multi_error.is_call_perform() {
337                    return Err(Error::Multi(multi_error));
338                }
339                Duration::ZERO
340            }
341        };
342
343        if !timeout.is_zero() {
344            // Prefer libcurl's wait API to be event-driven and avoid arbitrary sleeps.
345            // This is cross-platform and should avoid the macOS SSL hang observed.
346            trace!(
347                "perform_curl_multi: waiting for IO or timeout {:?}",
348                timeout
349            );
350
351            // Snapshot the socket map while holding the mutex, then drop the
352            // guard before calling `multi.wait` to avoid deadlocks if libcurl
353            // invokes `socket_function` during the wait (which would try to
354            // lock the same mutex).
355            let sockets: Vec<(Socket, (bool, bool))> = match socket_map.lock() {
356                Ok(g) => g.iter().map(|(s, bo)| (*s, *bo)).collect(),
357                Err(poison) => {
358                    trace!("perform_curl_multi: socket_map mutex poisoned, recovering");
359                    let g = poison.into_inner();
360                    g.iter().map(|(s, bo)| (*s, *bo)).collect()
361                }
362            };
363
364            let mut waitfds: Vec<WaitFd> = Vec::with_capacity(sockets.len());
365            for (fd, (inp, out)) in sockets.into_iter() {
366                let mut w = WaitFd::new();
367                w.set_fd(fd);
368                if inp {
369                    w.poll_on_read(true);
370                }
371                if out {
372                    w.poll_on_write(true);
373                }
374                waitfds.push(w);
375            }
376
377            let ready = multi
378                .wait(&mut waitfds, timeout)
379                .map_err(|e| Error::Multi(e))?;
380            trace!(
381                "perform_curl_multi: wait completed, {} fds ready (buffered {})",
382                ready,
383                waitfds.len()
384            );
385        }
386    }
387
388    // Inspect messages for transfer-level errors.
389    let mut transfer_error: Option<Error<H>> = None;
390    multi.messages(|msg| {
391        if let Some(Err(e)) = msg.result() {
392            transfer_error = Some(Error::Curl(e));
393        }
394    });
395
396    // Always attempt to remove the handle to clean up resources. If there was
397    // a transfer error prefer returning that error, but still try to perform
398    // the removal and log any cleanup failure.
399    let cleanup = multi.remove2(handle).map_err(|e| Error::Multi(e));
400
401    if let Some(e) = transfer_error {
402        if let Err(ref clean_err) = cleanup {
403            trace!(
404                "perform_curl_multi: remove2 failed during cleanup: {:?}",
405                clean_err
406            );
407        }
408        Err(e)
409    } else {
410        cleanup
411    }
412}
413
414/// This contains the Easy2 object and a oneshot sender channel when passing into the
415/// background task to perform Curl asynchronously.
416#[derive(Debug)]
417pub struct Request<H: Handler + Debug + Send + 'static>(
418    Easy2<H>,
419    oneshot::Sender<Result<Easy2<H>, Error<H>>>,
420);