use std::sync::{Arc, Mutex};
use crate::easy::Easy;
use crate::error::Error;
use crate::protocol::http::response::Response;
use crate::share::Share;
/// Multiplexing strategy requested for a `Multi` through its `pipelining` setter.
///
/// The default is [`PipeliningMode::Nothing`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PipeliningMode {
    /// No pipelining or multiplexing requested (the default).
    #[default]
    Nothing,
    /// Multiplexing requested.
    Multiplex,
}
/// A completion record that `Multi::perform` queues for every transfer it ran,
/// read back one at a time with `Multi::info_read`.
#[derive(Debug)]
pub struct TransferMessage {
/// Position of the transfer among the handles run by that `perform` call,
/// counted from 0.
pub index: usize,
/// Copy of the outcome that `perform` also returned for this transfer. Errors
/// are copied as `Error::Http` carrying the original error's display text
/// (NOTE(review): presumably because `Error` is not `Clone` — confirm).
pub result: Result<Response, Error>,
}
/// A queue of [`Easy`] transfers that `Multi::perform` runs concurrently as
/// Tokio tasks, together with a queue of per-transfer completion messages.
#[derive(Debug)]
pub struct Multi {
/// Transfers waiting for the next `perform`; `perform` drains this list.
handles: Vec<Easy>,
/// Upper bound on transfers `perform` runs at once; `None` means unbounded.
max_total_connections: Option<usize>,
/// Per-host limit recorded by `max_host_connections`.
/// NOTE(review): not read by `perform` in this file — stored only.
max_host_connections: Option<usize>,
/// Mode recorded by `pipelining`.
/// NOTE(review): not read by `perform` in this file — stored only.
pipelining: PipeliningMode,
/// Share handle that `perform` clones onto every queued transfer.
share: Option<Share>,
/// Completion messages, oldest first, consumed by `info_read`.
/// NOTE(review): this `Arc` is never cloned in this file, and `info_read`
/// pops with `remove(0)` (O(n)); a `Mutex<VecDeque<_>>` may be enough.
messages: Arc<Mutex<Vec<TransferMessage>>>,
}
/// `Multi::default()` is identical to `Multi::new()`: an empty handle with no
/// limits, no share and an empty message queue.
impl Default for Multi {
fn default() -> Self {
// Delegate so the two constructors cannot drift apart.
Self::new()
}
}
impl Multi {
#[must_use]
pub fn new() -> Self {
Self {
handles: Vec::new(),
max_total_connections: None,
max_host_connections: None,
pipelining: PipeliningMode::Nothing,
share: None,
messages: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn add(&mut self, easy: Easy) {
self.handles.push(easy);
}
pub fn remove(&mut self, index: usize) -> Option<Easy> {
if index < self.handles.len() {
Some(self.handles.remove(index))
} else {
None
}
}
#[must_use]
pub fn len(&self) -> usize {
self.handles.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.handles.is_empty()
}
pub const fn max_total_connections(&mut self, max: usize) {
self.max_total_connections = Some(max);
}
pub const fn max_host_connections(&mut self, max: usize) {
self.max_host_connections = Some(max);
}
pub const fn pipelining(&mut self, mode: PipeliningMode) {
self.pipelining = mode;
}
#[must_use]
pub const fn pipelining_mode(&self) -> PipeliningMode {
self.pipelining
}
pub fn set_share(&mut self, share: Share) {
self.share = Some(share);
}
#[allow(clippy::option_if_let_else)]
pub fn info_read(&mut self) -> Option<TransferMessage> {
if let Ok(mut msgs) = self.messages.lock() {
if msgs.is_empty() {
None
} else {
Some(msgs.remove(0))
}
} else {
None
}
}
#[must_use]
pub fn messages_in_queue(&self) -> usize {
self.messages.lock().map_or(0, |m| m.len())
}
pub async fn perform(&mut self) -> Vec<Result<Response, Error>> {
let mut handles: Vec<Easy> = self.handles.drain(..).collect();
if handles.is_empty() {
return Vec::new();
}
if let Some(ref share) = self.share {
for handle in &mut handles {
handle.set_share(share.clone());
}
}
let results = if let Some(max_conns) = self.max_total_connections {
perform_with_limit(handles, max_conns).await
} else {
perform_unlimited(handles).await
};
if let Ok(mut msgs) = self.messages.lock() {
for (idx, result) in results.iter().enumerate() {
msgs.push(TransferMessage {
index: idx,
result: match result {
Ok(r) => Ok(r.clone()),
Err(e) => Err(Error::Http(e.to_string())),
},
});
}
}
results
}
pub fn perform_blocking(&mut self) -> Result<Vec<Result<Response, Error>>, Error> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| Error::Http(format!("failed to create runtime: {e}")))?;
Ok(rt.block_on(self.perform()))
}
}
/// Spawns one Tokio task per handle with no concurrency cap, then gathers the
/// outcomes ordered by each handle's position in `handles`.
async fn perform_unlimited(handles: Vec<Easy>) -> Vec<Result<Response, Error>> {
    let mut tasks = tokio::task::JoinSet::new();
    handles
        .into_iter()
        .enumerate()
        .for_each(|(position, mut transfer)| {
            // The abort handle is not needed: every task is awaited below.
            let _abort = tasks.spawn(async move {
                let outcome = transfer.perform_async().await;
                (position, outcome)
            });
        });
    collect_results(&mut tasks).await
}
async fn perform_with_limit(handles: Vec<Easy>, max_conns: usize) -> Vec<Result<Response, Error>> {
let semaphore = Arc::new(tokio::sync::Semaphore::new(max_conns));
let mut join_set = tokio::task::JoinSet::new();
for (idx, mut easy) in handles.into_iter().enumerate() {
let sem = semaphore.clone();
let _handle = join_set.spawn(async move {
let _permit =
sem.acquire().await.map_err(|e| Error::Http(format!("semaphore error: {e}")));
(idx, easy.perform_async().await)
});
}
collect_results(&mut join_set).await
}
async fn collect_results(
join_set: &mut tokio::task::JoinSet<(usize, Result<Response, Error>)>,
) -> Vec<Result<Response, Error>> {
let mut results: Vec<(usize, Result<Response, Error>)> = Vec::with_capacity(join_set.len());
while let Some(join_result) = join_set.join_next().await {
match join_result {
Ok(indexed_result) => results.push(indexed_result),
Err(join_err) => {
results.push((
usize::MAX,
Err(Error::Http(format!("transfer task failed: {join_err}"))),
));
}
}
}
results.sort_by_key(|(idx, _)| *idx);
results.into_iter().map(|(_, result)| result).collect()
}
#[cfg(test)]
// Unwrapping is acceptable in tests: a panic is exactly the failure signal wanted.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Loopback URL on port 1, where nothing is expected to listen, so every
    /// transfer against it fails.
    const REFUSED_URL: &str = "http://127.0.0.1:1";

    /// Builds a handle aimed at [`REFUSED_URL`] with a short connect timeout,
    /// so an environment that drops (rather than refuses) the connection
    /// attempt cannot stall a test.
    fn refused_easy(timeout_ms: u64) -> Easy {
        let mut easy = Easy::new();
        let _ = easy.url(REFUSED_URL);
        easy.connect_timeout(Duration::from_millis(timeout_ms));
        easy
    }

    /// Builds `count` independent handles via [`refused_easy`].
    fn refused_batch(count: usize, timeout_ms: u64) -> Vec<Easy> {
        (0..count).map(|_| refused_easy(timeout_ms)).collect()
    }

    #[test]
    fn multi_new_is_empty() {
        let multi = Multi::new();
        assert!(multi.is_empty());
        assert_eq!(multi.len(), 0);
    }

    #[test]
    fn multi_add_increases_count() {
        let mut multi = Multi::new();
        multi.add(Easy::new());
        assert!(!multi.is_empty());
        assert_eq!(multi.len(), 1);
    }

    #[test]
    fn multi_default() {
        let multi = Multi::default();
        assert!(multi.is_empty());
    }

    #[tokio::test]
    async fn multi_perform_empty() {
        let mut multi = Multi::new();
        let results = multi.perform().await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn multi_perform_drains_handles() {
        let mut multi = Multi::new();
        multi.add(refused_easy(50));
        assert_eq!(multi.len(), 1);
        let _results = multi.perform().await;
        assert!(multi.is_empty(), "handles should be drained after perform");
    }

    #[test]
    fn multi_remove_valid_index() {
        let mut multi = Multi::new();
        multi.add(Easy::new());
        multi.add(Easy::new());
        assert_eq!(multi.len(), 2);
        let removed = multi.remove(0);
        assert!(removed.is_some());
        assert_eq!(multi.len(), 1);
    }

    #[test]
    fn multi_remove_invalid_index() {
        let mut multi = Multi::new();
        assert!(multi.remove(0).is_none());
    }

    #[test]
    fn multi_max_total_connections() {
        let mut multi = Multi::new();
        multi.max_total_connections(4);
        assert_eq!(multi.max_total_connections, Some(4));
    }

    #[test]
    fn multi_max_host_connections() {
        let mut multi = Multi::new();
        multi.max_host_connections(2);
        assert_eq!(multi.max_host_connections, Some(2));
    }

    #[test]
    fn multi_messages_initially_empty() {
        let multi = Multi::new();
        assert_eq!(multi.messages_in_queue(), 0);
    }

    #[test]
    fn multi_info_read_empty() {
        let mut multi = Multi::new();
        assert!(multi.info_read().is_none());
    }

    #[tokio::test]
    async fn multi_perform_stores_messages() {
        let mut multi = Multi::new();
        multi.add(refused_easy(50));
        let _results = multi.perform().await;
        assert_eq!(multi.messages_in_queue(), 1);
        let msg = multi.info_read().unwrap();
        assert_eq!(msg.index, 0);
        assert!(msg.result.is_err());
        assert_eq!(multi.messages_in_queue(), 0);
    }

    #[tokio::test]
    async fn multi_messages_follow_handle_order() {
        let mut multi = Multi::new();
        for easy in refused_batch(3, 10) {
            multi.add(easy);
        }
        let results = multi.perform().await;
        assert_eq!(results.len(), 3);
        let indices: Vec<usize> = std::iter::from_fn(|| multi.info_read())
            .map(|msg| msg.index)
            .collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn multi_perform_unlimited_ordering() {
        let results = perform_unlimited(refused_batch(5, 10)).await;
        assert_eq!(results.len(), 5);
        for r in &results {
            assert!(r.is_err());
        }
    }

    #[tokio::test]
    async fn multi_perform_with_limit() {
        let results = perform_with_limit(refused_batch(5, 10), 2).await;
        assert_eq!(results.len(), 5);
        assert!(
            results.iter().all(Result::is_err),
            "every transfer against a refused port should fail"
        );
    }

    #[test]
    fn multi_pipelining_default() {
        let multi = Multi::new();
        assert_eq!(multi.pipelining_mode(), PipeliningMode::Nothing);
    }

    #[test]
    fn multi_pipelining_set() {
        let mut multi = Multi::new();
        multi.pipelining(PipeliningMode::Multiplex);
        assert_eq!(multi.pipelining_mode(), PipeliningMode::Multiplex);
    }

    #[test]
    fn multi_set_share() {
        let mut multi = Multi::new();
        let mut share = crate::share::Share::new();
        share.add(crate::share::ShareType::Dns);
        multi.set_share(share);
        assert!(multi.share.is_some());
    }

    #[tokio::test]
    async fn multi_perform_attaches_share() {
        let mut share = crate::share::Share::new();
        share.add(crate::share::ShareType::Dns);
        let mut multi = Multi::new();
        multi.set_share(share);
        multi.add(refused_easy(10));
        let results = multi.perform().await;
        assert_eq!(results.len(), 1);
        assert!(results[0].is_err());
    }
}