isahc/agent/
mod.rs

1//! Curl agent that executes multiple requests simultaneously.
2//!
3//! The agent is implemented as a single background thread attached to a
4//! "handle". The handle communicates with the agent thread by using message
5//! passing. The agent executes multiple curl requests simultaneously by using a
6//! single "multi" handle.
7//!
8//! Since request executions are driven through futures, the agent also acts as
9//! a specialized task executor for tasks related to requests.
10
11use crate::{error::Error, handler::RequestHandler, task::WakerExt};
12use async_channel::{Receiver, Sender};
13use crossbeam_utils::{atomic::AtomicCell, sync::WaitGroup};
14use curl::multi::{Events, Multi, Socket, SocketEvents};
15use futures_lite::future::block_on;
16use slab::Slab;
17use std::{
18    io,
19    sync::{Arc, Mutex},
20    task::Waker,
21    thread,
22    time::{Duration, Instant},
23};
24
25use self::{selector::Selector, timer::Timer};
26
27mod selector;
28mod timer;
29
30static NEXT_AGENT_ID: AtomicCell<usize> = AtomicCell::new(0);
/// Longest time a single agent poll may block when curl has not asked for a
/// shorter timeout.
const WAIT_TIMEOUT: Duration = Duration::from_secs(1);
32
33type EasyHandle = curl::easy::Easy2<RequestHandler>;
34
/// Collects connection settings for an agent and spawns it.
///
/// All settings default to `0`. When spawning, a setting is only forwarded to
/// the curl multi handle if it is greater than zero.
#[derive(Debug, Default)]
pub(crate) struct AgentBuilder {
    /// Value passed to `Multi::set_max_total_connections` when non-zero.
    max_connections: usize,

    /// Value passed to `Multi::set_max_host_connections` when non-zero.
    max_connections_per_host: usize,

    /// Value passed to `Multi::set_max_connects` when non-zero.
    connection_cache_size: usize,
}
42
43impl AgentBuilder {
44    pub(crate) fn max_connections(mut self, max: usize) -> Self {
45        self.max_connections = max;
46        self
47    }
48
49    pub(crate) fn max_connections_per_host(mut self, max: usize) -> Self {
50        self.max_connections_per_host = max;
51        self
52    }
53
54    pub(crate) fn connection_cache_size(mut self, size: usize) -> Self {
55        self.connection_cache_size = size;
56        self
57    }
58
59    /// Spawn a new agent using the configuration in this builder and return a
60    /// handle for communicating with the agent.
61    pub(crate) fn spawn(&self) -> io::Result<Handle> {
62        let create_start = Instant::now();
63
64        // Initialize libcurl, if necessary, on the current thread.
65        //
66        // Note that as of 0.4.30, the curl crate will attempt to do this for us
67        // on the main thread automatically at program start on most targets,
68        // but on other targets must still be initialized on the main thread. We
69        // do this here in the hope that the user builds an `HttpClient` on the
70        // main thread (as opposed to waiting for `Multi::new()` to do it for
71        // us below, which we _know_ is not on the main thread).
72        //
73        // See #189.
74        curl::init();
75
76        let id = NEXT_AGENT_ID.fetch_add(1);
77
78        // Create an I/O selector for driving curl's sockets.
79        let selector = Selector::new()?;
80
81        let (message_tx, message_rx) = async_channel::unbounded();
82
83        let wait_group = WaitGroup::new();
84        let wait_group_thread = wait_group.clone();
85
86        let max_connections = self.max_connections;
87        let max_connections_per_host = self.max_connections_per_host;
88        let connection_cache_size = self.connection_cache_size;
89
90        // Create a span for the agent thread that outlives this method call,
91        // but rather was caused by it.
92        let agent_span = tracing::debug_span!("agent_thread", id);
93        agent_span.follows_from(tracing::Span::current());
94
95        let waker = selector.waker();
96        let message_tx_clone = message_tx.clone();
97
98        let thread_main = move || {
99            let _enter = agent_span.enter();
100            let mut multi = Multi::new();
101
102            if max_connections > 0 {
103                multi
104                    .set_max_total_connections(max_connections)
105                    .map_err(Error::from_any)?;
106            }
107
108            if max_connections_per_host > 0 {
109                multi
110                    .set_max_host_connections(max_connections_per_host)
111                    .map_err(Error::from_any)?;
112            }
113
114            // Only set maxconnects if greater than 0, because 0 actually means unlimited.
115            if connection_cache_size > 0 {
116                multi
117                    .set_max_connects(connection_cache_size)
118                    .map_err(Error::from_any)?;
119            }
120
121            let agent = AgentContext::new(multi, selector, message_tx_clone, message_rx)?;
122
123            drop(wait_group_thread);
124
125            tracing::debug!("agent took {:?} to start up", create_start.elapsed());
126
127            let result = agent.run();
128
129            if let Err(e) = &result {
130                tracing::error!("agent shut down with error: {:?}", e);
131            }
132
133            result
134        };
135
136        let handle = Handle {
137            message_tx,
138            waker,
139            join_handle: Mutex::new(Some(
140                thread::Builder::new()
141                    .name(format!("isahc-agent-{}", id))
142                    .spawn(thread_main)?,
143            )),
144        };
145
146        // Block until the agent thread responds.
147        wait_group.wait();
148
149        Ok(handle)
150    }
151}
152
153/// A handle to an active agent running in a background thread.
154///
155/// Dropping the handle will cause the agent thread to shut down and abort any
156/// pending transfers.
157#[derive(Debug)]
158pub(crate) struct Handle {
159    /// Used to send messages to the agent thread.
160    message_tx: Sender<Message>,
161
162    /// A waker that can wake up the agent thread while it is polling.
163    waker: Waker,
164
165    /// A join handle for the agent thread.
166    join_handle: Mutex<Option<thread::JoinHandle<Result<(), Error>>>>,
167}
168
169/// Internal state of an agent thread.
170///
171/// The agent thread runs the primary client event loop, which is essentially a
172/// traditional curl multi event loop with some extra bookkeeping and async
173/// features like wakers.
174struct AgentContext {
175    /// A curl multi handle, of course.
176    multi: curl::multi::Multi,
177
178    /// Used to send messages to the agent thread.
179    message_tx: Sender<Message>,
180
181    /// Incoming messages from the agent handle.
182    message_rx: Receiver<Message>,
183
184    /// Contains all of the active requests.
185    requests: Slab<curl::multi::Easy2Handle<RequestHandler>>,
186
187    /// Indicates if the thread has been requested to stop.
188    close_requested: bool,
189
190    /// A waker that can wake up the agent thread while it is polling.
191    waker: Waker,
192
193    /// This is the poller we use to poll for socket activity!
194    selector: Selector,
195
196    /// A timer we use to keep track of curl's timeouts.
197    timer: Arc<Timer>,
198
199    /// Queue of socket registration updates from the multi handle.
200    socket_updates: Receiver<(Socket, SocketEvents, usize)>,
201}
202
203/// A message sent from the main thread to the agent thread.
204#[derive(Debug)]
205enum Message {
206    /// Requests the agent to close.
207    Close,
208
209    /// Begin executing a new request.
210    Execute(EasyHandle),
211
212    /// Request to resume reading the request body for the request with the
213    /// given ID.
214    UnpauseRead(usize),
215
216    /// Request to resume writing the response body for the request with the
217    /// given ID.
218    UnpauseWrite(usize),
219}
220
221#[derive(Debug)]
222enum JoinResult {
223    AlreadyJoined,
224    Ok,
225    Err(Error),
226    Panic,
227}
228
229impl Handle {
230    /// Begin executing a request with this agent.
231    pub(crate) fn submit_request(&self, request: EasyHandle) -> Result<(), Error> {
232        self.send_message(Message::Execute(request))
233    }
234
235    /// Send a message to the agent thread.
236    ///
237    /// If the agent is not connected, an error is returned.
238    fn send_message(&self, message: Message) -> Result<(), Error> {
239        match self.message_tx.try_send(message) {
240            Ok(()) => {
241                // Wake the agent thread up so it will check its messages soon.
242                self.waker.wake_by_ref();
243                Ok(())
244            }
245            Err(_) => match self.try_join() {
246                JoinResult::Err(e) => panic!("agent thread terminated with error: {:?}", e),
247                JoinResult::Panic => panic!("agent thread panicked"),
248                _ => panic!("agent thread terminated prematurely"),
249            },
250        }
251    }
252
253    fn try_join(&self) -> JoinResult {
254        let mut option = self.join_handle.lock().unwrap();
255
256        if let Some(join_handle) = option.take() {
257            match join_handle.join() {
258                Ok(Ok(())) => JoinResult::Ok,
259                Ok(Err(e)) => JoinResult::Err(e),
260                Err(_) => JoinResult::Panic,
261            }
262        } else {
263            JoinResult::AlreadyJoined
264        }
265    }
266}
267
268impl Drop for Handle {
269    fn drop(&mut self) {
270        // Request the agent thread to shut down.
271        if self.send_message(Message::Close).is_err() {
272            tracing::error!("agent thread terminated prematurely");
273        }
274
275        // Wait for the agent thread to shut down before continuing.
276        match self.try_join() {
277            JoinResult::Ok => tracing::trace!("agent thread joined cleanly"),
278            JoinResult::Err(e) => tracing::error!("agent thread terminated with error: {}", e),
279            JoinResult::Panic => tracing::error!("agent thread panicked"),
280            _ => {}
281        }
282    }
283}
284
285impl AgentContext {
286    fn new(
287        mut multi: Multi,
288        selector: Selector,
289        message_tx: Sender<Message>,
290        message_rx: Receiver<Message>,
291    ) -> Result<Self, Error> {
292        let timer = Arc::new(Timer::new());
293        let (socket_updates_tx, socket_updates_rx) = async_channel::unbounded();
294
295        multi
296            .socket_function(move |socket, events, key| {
297                let _ = socket_updates_tx.try_send((socket, events, key));
298            })
299            .map_err(Error::from_any)?;
300
301        multi
302            .timer_function({
303                let timer = timer.clone();
304
305                move |timeout| match timeout {
306                    Some(timeout) => {
307                        timer.start(timeout);
308                        true
309                    }
310                    None => {
311                        timer.stop();
312                        true
313                    }
314                }
315            })
316            .map_err(Error::from_any)?;
317
318        Ok(Self {
319            multi,
320            message_tx,
321            message_rx,
322            requests: Slab::new(),
323            close_requested: false,
324            waker: selector.waker(),
325            selector,
326            timer,
327            socket_updates: socket_updates_rx,
328        })
329    }
330
331    #[tracing::instrument(level = "trace", skip(self))]
332    fn begin_request(&mut self, mut request: EasyHandle) -> Result<(), Error> {
333        // Prepare an entry for storing this request while it executes.
334        let entry = self.requests.vacant_entry();
335        let id = entry.key();
336        let handle = request.raw();
337
338        // Initialize the handler.
339        request.get_mut().init(
340            id,
341            handle,
342            {
343                let tx = self.message_tx.clone();
344
345                self.waker
346                    .chain(move |inner| match tx.try_send(Message::UnpauseRead(id)) {
347                        Ok(()) => inner.wake_by_ref(),
348                        Err(_) => {
349                            tracing::warn!(id, "agent went away while resuming read for request")
350                        }
351                    })
352            },
353            {
354                let tx = self.message_tx.clone();
355
356                self.waker
357                    .chain(move |inner| match tx.try_send(Message::UnpauseWrite(id)) {
358                        Ok(()) => inner.wake_by_ref(),
359                        Err(_) => {
360                            tracing::warn!(id, "agent went away while resuming write for request")
361                        }
362                    })
363            },
364        );
365
366        // Register the request with curl.
367        let mut handle = self.multi.add2(request).map_err(Error::from_any)?;
368        handle.set_token(id).map_err(Error::from_any)?;
369
370        // Add the handle to our bookkeeping structure.
371        entry.insert(handle);
372
373        Ok(())
374    }
375
376    #[tracing::instrument(level = "trace", skip(self))]
377    fn complete_request(
378        &mut self,
379        token: usize,
380        result: Result<(), curl::Error>,
381    ) -> Result<(), Error> {
382        let handle = self.requests.remove(token);
383        let mut handle = self.multi.remove2(handle).map_err(Error::from_any)?;
384
385        handle.get_mut().set_result(result.map_err(Error::from_any));
386
387        Ok(())
388    }
389
390    /// Polls the message channel for new messages from any agent handles.
391    ///
392    /// If there are no active requests right now, this function will block
393    /// until a message is received.
394    #[tracing::instrument(level = "trace", skip(self))]
395    fn poll_messages(&mut self) -> Result<(), Error> {
396        while !self.close_requested {
397            if self.requests.is_empty() {
398                match block_on(self.message_rx.recv()) {
399                    Ok(message) => self.handle_message(message)?,
400                    _ => {
401                        tracing::warn!("agent handle disconnected without close message");
402                        self.close_requested = true;
403                        break;
404                    }
405                }
406            } else {
407                match self.message_rx.try_recv() {
408                    Ok(message) => self.handle_message(message)?,
409                    Err(async_channel::TryRecvError::Empty) => break,
410                    Err(async_channel::TryRecvError::Closed) => {
411                        tracing::warn!("agent handle disconnected without close message");
412                        self.close_requested = true;
413                        break;
414                    }
415                }
416            }
417        }
418
419        Ok(())
420    }
421
422    #[tracing::instrument(level = "trace", skip(self))]
423    fn handle_message(&mut self, message: Message) -> Result<(), Error> {
424        tracing::trace!("received message from agent handle");
425
426        match message {
427            Message::Close => self.close_requested = true,
428            Message::Execute(request) => self.begin_request(request)?,
429            Message::UnpauseRead(token) => {
430                if let Some(request) = self.requests.get(token) {
431                    if let Err(e) = request.unpause_read() {
432                        // If unpausing returned an error, it is likely because
433                        // curl called our callback inline and the callback
434                        // returned an error. Unfortunately this does not affect
435                        // the normal state of the transfer, so we need to keep
436                        // the transfer alive until it errors through the normal
437                        // means, which is likely to happen this turn of the
438                        // event loop anyway.
439                        tracing::debug!(id = token, "error unpausing read for request: {:?}", e);
440                    }
441                } else {
442                    tracing::warn!(
443                        "received unpause request for unknown request token: {}",
444                        token
445                    );
446                }
447            }
448            Message::UnpauseWrite(token) => {
449                if let Some(request) = self.requests.get(token) {
450                    if let Err(e) = request.unpause_write() {
451                        // If unpausing returned an error, it is likely because
452                        // curl called our callback inline and the callback
453                        // returned an error. Unfortunately this does not affect
454                        // the normal state of the transfer, so we need to keep
455                        // the transfer alive until it errors through the normal
456                        // means, which is likely to happen this turn of the
457                        // event loop anyway.
458                        tracing::debug!(id = token, "error unpausing write for request: {:?}", e);
459                    }
460                } else {
461                    tracing::warn!(
462                        "received unpause request for unknown request token: {}",
463                        token
464                    );
465                }
466            }
467        }
468
469        Ok(())
470    }
471
472    /// Run the agent in the current thread until requested to stop.
473    fn run(mut self) -> Result<(), Error> {
474        let mut multi_messages = Vec::new();
475
476        // Agent main loop.
477        loop {
478            self.poll_messages()?;
479
480            if self.close_requested {
481                break;
482            }
483
484            // Block until activity is detected or the timeout passes.
485            self.poll()?;
486
487            // Collect messages from curl about requests that have completed,
488            // whether successfully or with an error.
489            self.multi.messages(|message| {
490                if let Some(result) = message.result() {
491                    if let Ok(token) = message.token() {
492                        multi_messages.push((token, result));
493                    }
494                }
495            });
496
497            for (token, result) in multi_messages.drain(..) {
498                self.complete_request(token, result)?;
499            }
500        }
501
502        tracing::debug!("agent shutting down");
503
504        self.requests.clear();
505
506        Ok(())
507    }
508
509    /// Block until activity is detected or a timeout passes.
510    fn poll(&mut self) -> Result<(), Error> {
511        let now = Instant::now();
512        let timeout = self.timer.get_remaining(now);
513
514        // Get the latest timeout value from curl that we should use, limited to
515        // a maximum we chose.
516        let poll_timeout = timeout.map(|t| t.min(WAIT_TIMEOUT)).unwrap_or(WAIT_TIMEOUT);
517
518        // Block until either an I/O event occurs on a socket, the timeout is
519        // reached, or the agent handle interrupts us.
520        if self.selector.poll(poll_timeout)? {
521            // At least one I/O event occurred, handle them.
522            for (socket, readable, writable) in self.selector.events() {
523                tracing::trace!(socket, readable, writable, "socket event");
524                let mut events = Events::new();
525                events.input(readable);
526                events.output(writable);
527                self.multi
528                    .action(socket, &events)
529                    .map_err(Error::from_any)?;
530            }
531        }
532
533        // If curl gave us a timeout, check if it has expired.
534        if self.timer.is_expired(now) {
535            self.timer.stop();
536            self.multi.timeout().map_err(Error::from_any)?;
537        }
538
539        // Apply any requested socket updates now.
540        while let Ok((socket, events, _)) = self.socket_updates.try_recv() {
541            // Curl is asking us to stop polling this socket.
542            if events.remove() {
543                self.selector.deregister(socket).unwrap();
544            } else {
545                let readable = events.input() || events.input_and_output();
546                let writable = events.output() || events.input_and_output();
547
548                self.selector.register(socket, readable, writable).unwrap();
549            }
550        }
551
552        Ok(())
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time checks only: the handle must be `Send + Sync`.
    static_assertions::assert_impl_all!(Handle: Send, Sync);

    // Messages cross over to the agent thread, so they must be `Send`.
    static_assertions::assert_impl_all!(Message: Send);
}