1use crate::{error::Error, handler::RequestHandler, task::WakerExt};
12use async_channel::{Receiver, Sender};
13use crossbeam_utils::{atomic::AtomicCell, sync::WaitGroup};
14use curl::multi::{Events, Multi, Socket, SocketEvents};
15use futures_lite::future::block_on;
16use slab::Slab;
17use std::{
18 io,
19 sync::{Arc, Mutex},
20 task::Waker,
21 thread,
22 time::{Duration, Instant},
23};
24
25use self::{selector::Selector, timer::Timer};
26
27mod selector;
28mod timer;
29
30static NEXT_AGENT_ID: AtomicCell<usize> = AtomicCell::new(0);
31const WAIT_TIMEOUT: Duration = Duration::from_millis(1000);
32
33type EasyHandle = curl::easy::Easy2<RequestHandler>;
34
/// Configuration used to spawn a new agent thread.
///
/// Every limit defaults to `0`; `spawn` only applies a limit to the curl multi
/// handle when it is non-zero.
#[derive(Debug, Default)]
pub(crate) struct AgentBuilder {
    /// Passed to `Multi::set_max_total_connections` when non-zero.
    max_connections: usize,
    /// Passed to `Multi::set_max_host_connections` when non-zero.
    max_connections_per_host: usize,
    /// Passed to `Multi::set_max_connects` when non-zero.
    connection_cache_size: usize,
}
42
43impl AgentBuilder {
44 pub(crate) fn max_connections(mut self, max: usize) -> Self {
45 self.max_connections = max;
46 self
47 }
48
49 pub(crate) fn max_connections_per_host(mut self, max: usize) -> Self {
50 self.max_connections_per_host = max;
51 self
52 }
53
54 pub(crate) fn connection_cache_size(mut self, size: usize) -> Self {
55 self.connection_cache_size = size;
56 self
57 }
58
59 pub(crate) fn spawn(&self) -> io::Result<Handle> {
62 let create_start = Instant::now();
63
64 curl::init();
75
76 let id = NEXT_AGENT_ID.fetch_add(1);
77
78 let selector = Selector::new()?;
80
81 let (message_tx, message_rx) = async_channel::unbounded();
82
83 let wait_group = WaitGroup::new();
84 let wait_group_thread = wait_group.clone();
85
86 let max_connections = self.max_connections;
87 let max_connections_per_host = self.max_connections_per_host;
88 let connection_cache_size = self.connection_cache_size;
89
90 let agent_span = tracing::debug_span!("agent_thread", id);
93 agent_span.follows_from(tracing::Span::current());
94
95 let waker = selector.waker();
96 let message_tx_clone = message_tx.clone();
97
98 let thread_main = move || {
99 let _enter = agent_span.enter();
100 let mut multi = Multi::new();
101
102 if max_connections > 0 {
103 multi
104 .set_max_total_connections(max_connections)
105 .map_err(Error::from_any)?;
106 }
107
108 if max_connections_per_host > 0 {
109 multi
110 .set_max_host_connections(max_connections_per_host)
111 .map_err(Error::from_any)?;
112 }
113
114 if connection_cache_size > 0 {
116 multi
117 .set_max_connects(connection_cache_size)
118 .map_err(Error::from_any)?;
119 }
120
121 let agent = AgentContext::new(multi, selector, message_tx_clone, message_rx)?;
122
123 drop(wait_group_thread);
124
125 tracing::debug!("agent took {:?} to start up", create_start.elapsed());
126
127 let result = agent.run();
128
129 if let Err(e) = &result {
130 tracing::error!("agent shut down with error: {:?}", e);
131 }
132
133 result
134 };
135
136 let handle = Handle {
137 message_tx,
138 waker,
139 join_handle: Mutex::new(Some(
140 thread::Builder::new()
141 .name(format!("isahc-agent-{}", id))
142 .spawn(thread_main)?,
143 )),
144 };
145
146 wait_group.wait();
148
149 Ok(handle)
150 }
151}
152
/// Handle to a running agent thread, used to submit requests to it.
///
/// Dropping the handle asks the agent thread to close and then joins it.
#[derive(Debug)]
pub(crate) struct Handle {
    /// Sending half of the agent's message channel.
    message_tx: Sender<Message>,

    /// Waker obtained from the agent's selector; woken after each successful send so
    /// the agent thread checks its messages.
    waker: Waker,

    /// Join handle of the agent thread; becomes `None` once the thread has been joined.
    join_handle: Mutex<Option<thread::JoinHandle<Result<(), Error>>>>,
}
168
/// State owned by the agent thread: the curl multi handle and everything needed to
/// drive it.
struct AgentContext {
    /// Curl multi handle that every active request is attached to.
    multi: curl::multi::Multi,

    /// Sender for the agent's own message channel, cloned into per-request wakers
    /// so paused transfers can ask to be resumed.
    message_tx: Sender<Message>,

    /// Receiving half of the message channel (fed by `Handle` and request wakers).
    message_rx: Receiver<Message>,

    /// Active requests, keyed by the token that is also registered with curl.
    requests: Slab<curl::multi::Easy2Handle<RequestHandler>>,

    /// Set once a `Close` message arrives or the sending side disconnects.
    close_requested: bool,

    /// Waker for this agent's selector; chained into per-request unpause wakers.
    waker: Waker,

    /// Polls the sockets that curl's socket callback asked to watch.
    selector: Selector,

    /// Timeout most recently requested through curl's timer callback.
    timer: Arc<Timer>,

    /// Socket interest changes queued by curl's socket callback; applied to the
    /// selector at the end of each poll.
    socket_updates: Receiver<(Socket, SocketEvents, usize)>,
}
202
/// Messages delivered to the agent thread over its channel.
#[derive(Debug)]
enum Message {
    /// Asks the agent thread to stop its event loop and shut down.
    Close,

    /// Starts executing a new request.
    Execute(EasyHandle),

    /// Resumes reading for the request with the given token.
    UnpauseRead(usize),

    /// Resumes writing for the request with the given token.
    UnpauseWrite(usize),
}
220
/// Outcome of attempting to join the agent thread.
#[derive(Debug)]
enum JoinResult {
    /// The thread was already joined by an earlier call.
    AlreadyJoined,
    /// The thread exited and returned `Ok`.
    Ok,
    /// The thread exited and returned this error.
    Err(Error),
    /// The thread panicked.
    Panic,
}
228
229impl Handle {
230 pub(crate) fn submit_request(&self, request: EasyHandle) -> Result<(), Error> {
232 self.send_message(Message::Execute(request))
233 }
234
235 fn send_message(&self, message: Message) -> Result<(), Error> {
239 match self.message_tx.try_send(message) {
240 Ok(()) => {
241 self.waker.wake_by_ref();
243 Ok(())
244 }
245 Err(_) => match self.try_join() {
246 JoinResult::Err(e) => panic!("agent thread terminated with error: {:?}", e),
247 JoinResult::Panic => panic!("agent thread panicked"),
248 _ => panic!("agent thread terminated prematurely"),
249 },
250 }
251 }
252
253 fn try_join(&self) -> JoinResult {
254 let mut option = self.join_handle.lock().unwrap();
255
256 if let Some(join_handle) = option.take() {
257 match join_handle.join() {
258 Ok(Ok(())) => JoinResult::Ok,
259 Ok(Err(e)) => JoinResult::Err(e),
260 Err(_) => JoinResult::Panic,
261 }
262 } else {
263 JoinResult::AlreadyJoined
264 }
265 }
266}
267
268impl Drop for Handle {
269 fn drop(&mut self) {
270 if self.send_message(Message::Close).is_err() {
272 tracing::error!("agent thread terminated prematurely");
273 }
274
275 match self.try_join() {
277 JoinResult::Ok => tracing::trace!("agent thread joined cleanly"),
278 JoinResult::Err(e) => tracing::error!("agent thread terminated with error: {}", e),
279 JoinResult::Panic => tracing::error!("agent thread panicked"),
280 _ => {}
281 }
282 }
283}
284
285impl AgentContext {
286 fn new(
287 mut multi: Multi,
288 selector: Selector,
289 message_tx: Sender<Message>,
290 message_rx: Receiver<Message>,
291 ) -> Result<Self, Error> {
292 let timer = Arc::new(Timer::new());
293 let (socket_updates_tx, socket_updates_rx) = async_channel::unbounded();
294
295 multi
296 .socket_function(move |socket, events, key| {
297 let _ = socket_updates_tx.try_send((socket, events, key));
298 })
299 .map_err(Error::from_any)?;
300
301 multi
302 .timer_function({
303 let timer = timer.clone();
304
305 move |timeout| match timeout {
306 Some(timeout) => {
307 timer.start(timeout);
308 true
309 }
310 None => {
311 timer.stop();
312 true
313 }
314 }
315 })
316 .map_err(Error::from_any)?;
317
318 Ok(Self {
319 multi,
320 message_tx,
321 message_rx,
322 requests: Slab::new(),
323 close_requested: false,
324 waker: selector.waker(),
325 selector,
326 timer,
327 socket_updates: socket_updates_rx,
328 })
329 }
330
331 #[tracing::instrument(level = "trace", skip(self))]
332 fn begin_request(&mut self, mut request: EasyHandle) -> Result<(), Error> {
333 let entry = self.requests.vacant_entry();
335 let id = entry.key();
336 let handle = request.raw();
337
338 request.get_mut().init(
340 id,
341 handle,
342 {
343 let tx = self.message_tx.clone();
344
345 self.waker
346 .chain(move |inner| match tx.try_send(Message::UnpauseRead(id)) {
347 Ok(()) => inner.wake_by_ref(),
348 Err(_) => {
349 tracing::warn!(id, "agent went away while resuming read for request")
350 }
351 })
352 },
353 {
354 let tx = self.message_tx.clone();
355
356 self.waker
357 .chain(move |inner| match tx.try_send(Message::UnpauseWrite(id)) {
358 Ok(()) => inner.wake_by_ref(),
359 Err(_) => {
360 tracing::warn!(id, "agent went away while resuming write for request")
361 }
362 })
363 },
364 );
365
366 let mut handle = self.multi.add2(request).map_err(Error::from_any)?;
368 handle.set_token(id).map_err(Error::from_any)?;
369
370 entry.insert(handle);
372
373 Ok(())
374 }
375
376 #[tracing::instrument(level = "trace", skip(self))]
377 fn complete_request(
378 &mut self,
379 token: usize,
380 result: Result<(), curl::Error>,
381 ) -> Result<(), Error> {
382 let handle = self.requests.remove(token);
383 let mut handle = self.multi.remove2(handle).map_err(Error::from_any)?;
384
385 handle.get_mut().set_result(result.map_err(Error::from_any));
386
387 Ok(())
388 }
389
390 #[tracing::instrument(level = "trace", skip(self))]
395 fn poll_messages(&mut self) -> Result<(), Error> {
396 while !self.close_requested {
397 if self.requests.is_empty() {
398 match block_on(self.message_rx.recv()) {
399 Ok(message) => self.handle_message(message)?,
400 _ => {
401 tracing::warn!("agent handle disconnected without close message");
402 self.close_requested = true;
403 break;
404 }
405 }
406 } else {
407 match self.message_rx.try_recv() {
408 Ok(message) => self.handle_message(message)?,
409 Err(async_channel::TryRecvError::Empty) => break,
410 Err(async_channel::TryRecvError::Closed) => {
411 tracing::warn!("agent handle disconnected without close message");
412 self.close_requested = true;
413 break;
414 }
415 }
416 }
417 }
418
419 Ok(())
420 }
421
422 #[tracing::instrument(level = "trace", skip(self))]
423 fn handle_message(&mut self, message: Message) -> Result<(), Error> {
424 tracing::trace!("received message from agent handle");
425
426 match message {
427 Message::Close => self.close_requested = true,
428 Message::Execute(request) => self.begin_request(request)?,
429 Message::UnpauseRead(token) => {
430 if let Some(request) = self.requests.get(token) {
431 if let Err(e) = request.unpause_read() {
432 tracing::debug!(id = token, "error unpausing read for request: {:?}", e);
440 }
441 } else {
442 tracing::warn!(
443 "received unpause request for unknown request token: {}",
444 token
445 );
446 }
447 }
448 Message::UnpauseWrite(token) => {
449 if let Some(request) = self.requests.get(token) {
450 if let Err(e) = request.unpause_write() {
451 tracing::debug!(id = token, "error unpausing write for request: {:?}", e);
459 }
460 } else {
461 tracing::warn!(
462 "received unpause request for unknown request token: {}",
463 token
464 );
465 }
466 }
467 }
468
469 Ok(())
470 }
471
472 fn run(mut self) -> Result<(), Error> {
474 let mut multi_messages = Vec::new();
475
476 loop {
478 self.poll_messages()?;
479
480 if self.close_requested {
481 break;
482 }
483
484 self.poll()?;
486
487 self.multi.messages(|message| {
490 if let Some(result) = message.result() {
491 if let Ok(token) = message.token() {
492 multi_messages.push((token, result));
493 }
494 }
495 });
496
497 for (token, result) in multi_messages.drain(..) {
498 self.complete_request(token, result)?;
499 }
500 }
501
502 tracing::debug!("agent shutting down");
503
504 self.requests.clear();
505
506 Ok(())
507 }
508
509 fn poll(&mut self) -> Result<(), Error> {
511 let now = Instant::now();
512 let timeout = self.timer.get_remaining(now);
513
514 let poll_timeout = timeout.map(|t| t.min(WAIT_TIMEOUT)).unwrap_or(WAIT_TIMEOUT);
517
518 if self.selector.poll(poll_timeout)? {
521 for (socket, readable, writable) in self.selector.events() {
523 tracing::trace!(socket, readable, writable, "socket event");
524 let mut events = Events::new();
525 events.input(readable);
526 events.output(writable);
527 self.multi
528 .action(socket, &events)
529 .map_err(Error::from_any)?;
530 }
531 }
532
533 if self.timer.is_expired(now) {
535 self.timer.stop();
536 self.multi.timeout().map_err(Error::from_any)?;
537 }
538
539 while let Ok((socket, events, _)) = self.socket_updates.try_recv() {
541 if events.remove() {
543 self.selector.deregister(socket).unwrap();
544 } else {
545 let readable = events.input() || events.input_and_output();
546 let writable = events.output() || events.input_and_output();
547
548 self.selector.register(socket, readable, writable).unwrap();
549 }
550 }
551
552 Ok(())
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::{Handle, Message};

    // Compile-time checks: the handle must be usable from any thread, and messages
    // must be sendable to the agent thread.
    static_assertions::assert_impl_all!(Handle: Send, Sync);
    static_assertions::assert_impl_all!(Message: Send);
}