trillium_testing/
runtimeless.rs

1use crate::{spawn, TestTransport};
2use async_channel::{Receiver, Sender};
3use dashmap::DashMap;
4use once_cell::sync::Lazy;
5use std::{
6    future::Future,
7    io::{Error, ErrorKind, Result},
8    pin::Pin,
9};
10use trillium::Info;
11use trillium_server_common::{Acceptor, Config, ConfigExt, Connector, Server};
12use url::Url;
13
14type Servers = Lazy<DashMap<(String, u16), (Sender<TestTransport>, Receiver<TestTransport>)>>;
15static SERVERS: Servers = Lazy::new(Default::default);
16
17/// A [`Server`] for testing that does not depend on any runtime
18#[derive(Debug)]
19pub struct RuntimelessServer {
20    host: String,
21    port: u16,
22    channel: Receiver<TestTransport>,
23}
24
25impl Server for RuntimelessServer {
26    type Transport = TestTransport;
27    const DESCRIPTION: &'static str = "test server";
28    fn accept(&mut self) -> Pin<Box<dyn Future<Output = Result<Self::Transport>> + Send + '_>> {
29        Box::pin(async move {
30            self.channel
31                .recv()
32                .await
33                .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
34        })
35    }
36
37    fn build_listener<A>(config: &Config<Self, A>) -> Self
38    where
39        A: Acceptor<Self::Transport>,
40    {
41        let mut port = config.port();
42        let host = config.host();
43        if port == 0 {
44            loop {
45                port = fastrand::u16(..);
46                if !SERVERS.contains_key(&(host.clone(), port)) {
47                    break;
48                }
49            }
50        }
51
52        let entry = SERVERS
53            .entry((host.clone(), port))
54            .or_insert_with(async_channel::unbounded);
55
56        let (_, channel) = entry.value();
57
58        Self {
59            host,
60            channel: channel.clone(),
61            port,
62        }
63    }
64
65    fn info(&self) -> Info {
66        Info::from(&*format!("{}:{}", &self.host, &self.port))
67    }
68
69    fn block_on(fut: impl Future<Output = ()> + 'static) {
70        crate::block_on(fut)
71    }
72
73    fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
74        spawn(fut);
75    }
76}
77
78impl Drop for RuntimelessServer {
79    fn drop(&mut self) {
80        SERVERS.remove(&(self.host.clone(), self.port));
81    }
82}
83
/// An in-memory `Connector` to use with [`RuntimelessServer`].
///
/// Connections never touch the network: each one is handed directly to the server
/// registered for the URL's host and port.
#[derive(Default, Debug, Clone, Copy)]
pub struct RuntimelessClientConfig(());

impl RuntimelessClientConfig {
    /// Constructs a `RuntimelessClientConfig`. Equivalent to `Default::default()`.
    pub const fn new() -> Self {
        Self(())
    }
}
94
95#[trillium::async_trait]
96impl Connector for RuntimelessClientConfig {
97    type Transport = TestTransport;
98    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
99        let (tx, _) = &*SERVERS
100            .get(&(
101                url.host_str().unwrap().to_string(),
102                url.port_or_known_default().unwrap(),
103            ))
104            .ok_or(Error::new(ErrorKind::AddrNotAvailable, "not available"))?;
105        let (client_transport, server_transport) = TestTransport::new();
106        tx.send(server_transport).await.unwrap();
107        Ok(client_transport)
108    }
109
110    fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
111        spawn(fut);
112    }
113}
114
#[cfg(test)]
mod test {
    use super::*;
    use crate::{harness, TestResult};
    use test_harness::test;

    /// Two servers on distinct hosts each receive only their own traffic. Stopping one
    /// deregisters it without affecting the other.
    #[test(harness)]
    async fn round_trip() -> TestResult {
        let handle1 = Config::<RuntimelessServer, ()>::new()
            .with_host("host.com")
            .with_port(80)
            .spawn("server 1");
        handle1.info().await;

        let handle2 = Config::<RuntimelessServer, ()>::new()
            .with_host("other_host.com")
            .with_port(80)
            .spawn("server 2");
        handle2.info().await;

        let client = trillium_client::Client::new(RuntimelessClientConfig::default());
        let mut conn = client.get("http://host.com").await?;
        assert_eq!(conn.response_body().await?, "server 1");

        let mut conn = client.get("http://other_host.com").await?;
        assert_eq!(conn.response_body().await?, "server 2");

        handle1.stop().await;
        assert!(client.get("http://host.com").await.is_err());
        assert!(client.get("http://other_host.com").await.is_ok());

        handle2.stop().await;
        assert!(client.get("http://other_host.com").await.is_err());

        // Check only the entries this test registered. SERVERS is process-global, and
        // other tests running in parallel may have live servers of their own.
        assert!(!SERVERS.contains_key(&("host.com".to_string(), 80)));
        assert!(!SERVERS.contains_key(&("other_host.com".to_string(), 80)));

        Ok(())
    }
}
152}