// capybara_core/upstream/weighted.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use capybara_util::WeightedResource;
6
7use crate::{CapybaraError, Result};
8
9use super::pools::{Pool, Pools};
10
11pub struct WeightedPools(WeightedResource<Arc<Pool>>);
12
13#[async_trait]
14impl Pools for WeightedPools {
15    async fn next(&self, _: u64) -> Result<Arc<Pool>> {
16        match self.0.next() {
17            None => Err(CapybaraError::InvalidUpstreamPool),
18            Some(next) => Ok(Clone::clone(next)),
19        }
20    }
21}
22
23impl From<WeightedResource<Arc<Pool>>> for WeightedPools {
24    fn from(value: WeightedResource<Arc<Pool>>) -> Self {
25        Self(value)
26    }
27}
28
#[cfg(test)]
mod tests {
    use tokio::sync::Notify;

    use crate::transport::tcp::TcpStreamPoolBuilder;

    use super::*;

    /// Sets up the timed logger; an "already initialised" error from a
    /// previous test is deliberately discarded.
    fn init() {
        let _ = pretty_env_logger::try_init_timed();
    }

    /// Builds two equally weighted TCP pools (httpbin.org and httpbingo.org,
    /// port 80) and checks that `next` successfully hands one out.
    ///
    /// NOTE(review): pool construction goes through
    /// `TcpStreamPoolBuilder::build`, which presumably needs network access
    /// to these hosts — confirm before relying on this test in CI.
    #[tokio::test]
    async fn test_weighted_pools() -> anyhow::Result<()> {
        init();

        let shutdown = Arc::new(Notify::new());

        // Each call gets its own handle to the shared `Notify` and its own builder.
        let build_pool = |host: &str| {
            let shutdown = Arc::clone(&shutdown);
            let builder = TcpStreamPoolBuilder::with_domain(host, 80);
            async move {
                builder
                    .build(shutdown)
                    .await
                    .map(|tcp| Arc::new(Pool::Tcp(tcp)))
            }
        };

        let weighted = WeightedResource::<Arc<Pool>>::builder()
            .push(50, build_pool("httpbin.org").await?)
            .push(50, build_pool("httpbingo.org").await?)
            .build();
        let pools = WeightedPools::from(weighted);

        let picked = pools.next(0).await;
        assert!(picked.is_ok());

        Ok(())
    }
}