async_curl/actor.rs
1use std::fmt::Debug;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use curl::easy::{Easy2, Handler};
6use curl::multi::{Multi, Socket, WaitFd};
7use log::trace;
8use std::collections::HashMap;
9use std::sync::Mutex;
10use tokio::runtime::{Builder, Runtime};
11use tokio::sync::mpsc::{self, Receiver, Sender};
12use tokio::sync::oneshot;
13use tokio::task::LocalSet;
14
15use crate::error::Error;
16
17#[async_trait]
18pub trait Actor<H>
19where
20 H: Handler + Debug + Send + 'static,
21{
22 async fn send_request(&self, easy2: Easy2<H>) -> Result<Easy2<H>, Error<H>>;
23}
24
25/// CurlActor is responsible for performing
26/// the contructed Easy2 object at the background
27/// to perform it asynchronously.
28/// ```
29/// use async_curl::actor::{Actor, CurlActor};
30/// use curl::easy::{Easy2, Handler, WriteError};
31///
32/// #[derive(Debug, Clone, Default)]
33/// pub struct ResponseHandler {
34/// data: Vec<u8>,
35/// }
36///
37/// impl Handler for ResponseHandler {
38/// /// This will store the response from the server
39/// /// to the data vector.
40/// fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
41/// self.data.extend_from_slice(data);
42/// Ok(data.len())
43/// }
44/// }
45///
46/// impl ResponseHandler {
47/// /// Instantiation of the ResponseHandler
48/// /// and initialize the data vector.
49/// pub fn new() -> Self {
50/// Self::default()
51/// }
52///
53/// /// This will consumed the object and
54/// /// give the data to the caller
55/// pub fn get_data(self) -> Vec<u8> {
56/// self.data
57/// }
58/// }
59///
60/// # #[tokio::main(flavor = "current_thread")]
61/// # async fn main() -> Result<(), Box<dyn std::error::Error>>{
62/// let curl = CurlActor::new();
63/// let mut easy2 = Easy2::new(ResponseHandler::new());
64///
65/// easy2.url("https://www.rust-lang.org").unwrap();
66/// easy2.get(true).unwrap();
67///
68/// let response = curl.send_request(easy2).await.unwrap();
69/// eprintln!("{:?}", response.get_ref());
70///
71/// Ok(())
72/// # }
73/// ```
74///
75/// Example for multiple request executed
76/// at the same time.
77///
78/// ```
79/// use async_curl::actor::{Actor, CurlActor};
80/// use curl::easy::{Easy2, Handler, WriteError};
81///
82/// #[derive(Debug, Clone, Default)]
83/// pub struct ResponseHandler {
84/// data: Vec<u8>,
85/// }
86///
87/// impl Handler for ResponseHandler {
88/// /// This will store the response from the server
89/// /// to the data vector.
90/// fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
91/// self.data.extend_from_slice(data);
92/// Ok(data.len())
93/// }
94/// }
95///
96/// impl ResponseHandler {
97/// /// Instantiation of the ResponseHandler
98/// /// and initialize the data vector.
99/// pub fn new() -> Self {
100/// Self::default()
101/// }
102///
103/// /// This will consumed the object and
104/// /// give the data to the caller
105/// pub fn get_data(self) -> Vec<u8> {
106/// self.data
107/// }
108/// }
109///
110/// # #[tokio::main(flavor = "current_thread")]
111/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
112/// let actor = CurlActor::new();
113/// let mut easy2 = Easy2::new(ResponseHandler::new());
114/// easy2.url("https://www.rust-lang.org").unwrap();
115/// easy2.get(true).unwrap();
116///
117/// let actor1 = actor.clone();
118/// let spawn1 = tokio::spawn(async move {
119/// let response = actor1.send_request(easy2).await;
120/// let mut response = response.unwrap();
121///
122/// // Response body
123/// eprintln!(
124/// "Task 1 : {}",
125/// String::from_utf8_lossy(&response.get_ref().to_owned().get_data())
126/// );
127/// // Response status code
128/// let status_code = response.response_code().unwrap();
129/// eprintln!("Task 1 : {}", status_code);
130/// });
131///
132/// let mut easy2 = Easy2::new(ResponseHandler::new());
133/// easy2.url("https://www.rust-lang.org").unwrap();
134/// easy2.get(true).unwrap();
135///
136/// let spawn2 = tokio::spawn(async move {
137/// let response = actor.send_request(easy2).await;
138/// let mut response = response.unwrap();
139///
140/// // Response body
141/// eprintln!(
142/// "Task 2 : {}",
143/// String::from_utf8_lossy(&response.get_ref().to_owned().get_data())
144/// );
145/// // Response status code
146/// let status_code = response.response_code().unwrap();
147/// eprintln!("Task 2 : {}", status_code);
148/// });
149/// let (_, _) = tokio::join!(spawn1, spawn2);
150///
151/// Ok(())
152/// # }
153/// ```
154///
155use std::sync::Arc;
156use std::thread::JoinHandle;
157
158struct Inner<H>
159where
160 H: Handler + Debug + Send + 'static,
161{
162 request_sender: Option<Sender<Request<H>>>,
163 join_handle: Option<JoinHandle<()>>,
164}
165
166impl<H> Drop for Inner<H>
167where
168 H: Handler + Debug + Send + 'static,
169{
170 fn drop(&mut self) {
171 // Take and drop the sender so the background actor sees channel closed.
172 if let Some(sender) = self.request_sender.take() {
173 trace!("Dropping request sender to signal background actor to shut down.");
174 drop(sender);
175 trace!("Request sender dropped, signaling background actor to shut down.");
176 }
177 // Join the background thread to ensure graceful shutdown.
178 if let Some(handle) = self.join_handle.take() {
179 trace!("Attempting to join background actor thread for graceful shutdown...");
180 let _ = handle.join();
181 trace!("Background actor thread joined successfully.");
182 }
183 }
184}
185
186#[derive(Clone)]
187pub struct CurlActor<H>
188where
189 H: Handler + Debug + Send + 'static,
190{
191 inner: Arc<Inner<H>>,
192}
193
194impl<H> Default for CurlActor<H>
195where
196 H: Handler + Debug + Send + 'static,
197{
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203#[async_trait]
204impl<H> Actor<H> for CurlActor<H>
205where
206 H: Handler + Debug + Send + 'static,
207{
208 /// This will send Easy2 into the background task that will perform
209 /// curl asynchronously, await the response in the oneshot receiver and
210 /// return Easy2 back to the caller.
211 async fn send_request(&self, easy2: Easy2<H>) -> Result<Easy2<H>, Error<H>> {
212 let (oneshot_sender, oneshot_receiver) = oneshot::channel::<Result<Easy2<H>, Error<H>>>();
213 self.inner
214 .request_sender
215 .as_ref()
216 .expect("request_sender missing")
217 .send(Request(easy2, oneshot_sender))
218 .await?;
219 oneshot_receiver.await?
220 }
221}
222
223impl<H> CurlActor<H>
224where
225 H: Handler + Debug + Send + 'static,
226{
227 /// This creates the new instance of CurlActor to handle Curl perform asynchronously using Curl Multi
228 /// in a background thread to avoid blocking of other tasks.
229 pub fn new() -> Self {
230 let runtime = Builder::new_current_thread().enable_all().build().unwrap();
231 let (request_sender, request_receiver) = mpsc::channel::<Request<H>>(1);
232
233 let handle = Self::spawn_actor(runtime, request_receiver);
234
235 Self {
236 inner: Arc::new(Inner {
237 request_sender: Some(request_sender),
238 join_handle: Some(handle),
239 }),
240 }
241 }
242
243 /// This creates the new instance of CurlActor to handle Curl perform asynchronously using Curl Multi
244 /// in a background thread to avoid blocking of other tasks. The user can provide a custom runtime
245 /// to use for the background task.
246 pub fn new_runtime(runtime: Runtime) -> Self {
247 let (request_sender, request_receiver) = mpsc::channel::<Request<H>>(1);
248
249 let handle = Self::spawn_actor(runtime, request_receiver);
250
251 Self {
252 inner: Arc::new(Inner {
253 request_sender: Some(request_sender),
254 join_handle: Some(handle),
255 }),
256 }
257 }
258
259 /// Create a new CurlActor with a user-provided runtime and configurable channel capacity.
260 pub fn new_runtime_with_capacity(runtime: Runtime, capacity: usize) -> Self {
261 let (request_sender, request_receiver) = mpsc::channel::<Request<H>>(capacity);
262
263 let handle = Self::spawn_actor(runtime, request_receiver);
264
265 Self {
266 inner: Arc::new(Inner {
267 request_sender: Some(request_sender),
268 join_handle: Some(handle),
269 }),
270 }
271 }
272
273 fn spawn_actor(runtime: Runtime, mut request_receiver: Receiver<Request<H>>) -> JoinHandle<()> {
274 std::thread::spawn(move || {
275 let local = LocalSet::new();
276 local.spawn_local(async move {
277 while let Some(Request(easy2, oneshot_sender)) = request_receiver.recv().await {
278 tokio::task::spawn_local(async move {
279 let response = perform_curl_multi(easy2).await;
280 if let Err(res) = oneshot_sender.send(response) {
281 trace!("Warning! The receiver has been dropped. {:?}", res);
282 }
283 });
284 }
285 });
286 runtime.block_on(local);
287 })
288 }
289}
290
291async fn perform_curl_multi<H: Handler + Debug + Send + 'static>(
292 easy2: Easy2<H>,
293) -> Result<Easy2<H>, Error<H>> {
294 let mut multi = Multi::new();
295
296 // Track sockets libcurl wants us to wait on. We populate this via
297 // `socket_function` and then construct `WaitFd` entries from it before
298 // calling `multi.wait`.
299 let socket_map: std::sync::Arc<Mutex<HashMap<Socket, (bool, bool)>>> =
300 std::sync::Arc::new(Mutex::new(HashMap::new()));
301
302 {
303 let map = socket_map.clone();
304 multi
305 .socket_function(move |socket, events, _| match map.lock() {
306 Ok(mut m) => {
307 if events.remove() {
308 m.remove(&socket);
309 } else {
310 m.insert(socket, (events.input(), events.output()));
311 }
312 }
313 Err(poison) => {
314 trace!("socket_function: socket_map mutex poisoned, recovering");
315 let mut m = poison.into_inner();
316 if events.remove() {
317 m.remove(&socket);
318 } else {
319 m.insert(socket, (events.input(), events.output()));
320 }
321 }
322 })
323 .map_err(|e| Error::Multi(e))?;
324 }
325
326 let handle = multi.add2(easy2).map_err(|e| Error::Multi(e))?;
327
328 while multi.perform().map_err(|e| Error::Multi(e))? != 0 {
329 let timeout_result = multi
330 .get_timeout()
331 .map(|d| d.unwrap_or_else(|| Duration::from_secs(2)));
332
333 let timeout = match timeout_result {
334 Ok(duration) => duration,
335 Err(multi_error) => {
336 if !multi_error.is_call_perform() {
337 return Err(Error::Multi(multi_error));
338 }
339 Duration::ZERO
340 }
341 };
342
343 if !timeout.is_zero() {
344 // Prefer libcurl's wait API to be event-driven and avoid arbitrary sleeps.
345 // This is cross-platform and should avoid the macOS SSL hang observed.
346 trace!(
347 "perform_curl_multi: waiting for IO or timeout {:?}",
348 timeout
349 );
350
351 // Snapshot the socket map while holding the mutex, then drop the
352 // guard before calling `multi.wait` to avoid deadlocks if libcurl
353 // invokes `socket_function` during the wait (which would try to
354 // lock the same mutex).
355 let sockets: Vec<(Socket, (bool, bool))> = match socket_map.lock() {
356 Ok(g) => g.iter().map(|(s, bo)| (*s, *bo)).collect(),
357 Err(poison) => {
358 trace!("perform_curl_multi: socket_map mutex poisoned, recovering");
359 let g = poison.into_inner();
360 g.iter().map(|(s, bo)| (*s, *bo)).collect()
361 }
362 };
363
364 let mut waitfds: Vec<WaitFd> = Vec::with_capacity(sockets.len());
365 for (fd, (inp, out)) in sockets.into_iter() {
366 let mut w = WaitFd::new();
367 w.set_fd(fd);
368 if inp {
369 w.poll_on_read(true);
370 }
371 if out {
372 w.poll_on_write(true);
373 }
374 waitfds.push(w);
375 }
376
377 let ready = multi
378 .wait(&mut waitfds, timeout)
379 .map_err(|e| Error::Multi(e))?;
380 trace!(
381 "perform_curl_multi: wait completed, {} fds ready (buffered {})",
382 ready,
383 waitfds.len()
384 );
385 }
386 }
387
388 // Inspect messages for transfer-level errors.
389 let mut transfer_error: Option<Error<H>> = None;
390 multi.messages(|msg| {
391 if let Some(Err(e)) = msg.result() {
392 transfer_error = Some(Error::Curl(e));
393 }
394 });
395
396 // Always attempt to remove the handle to clean up resources. If there was
397 // a transfer error prefer returning that error, but still try to perform
398 // the removal and log any cleanup failure.
399 let cleanup = multi.remove2(handle).map_err(|e| Error::Multi(e));
400
401 if let Some(e) = transfer_error {
402 if let Err(ref clean_err) = cleanup {
403 trace!(
404 "perform_curl_multi: remove2 failed during cleanup: {:?}",
405 clean_err
406 );
407 }
408 Err(e)
409 } else {
410 cleanup
411 }
412}
413
414/// This contains the Easy2 object and a oneshot sender channel when passing into the
415/// background task to perform Curl asynchronously.
416#[derive(Debug)]
417pub struct Request<H: Handler + Debug + Send + 'static>(
418 Easy2<H>,
419 oneshot::Sender<Result<Easy2<H>, Error<H>>>,
420);