abcperf_generic_client/cs/http_warp.rs

1use std::future::Future;
2use std::net::SocketAddr;
3use std::pin::Pin;
4
5use crate::cs::{json, CSConfig};
6use anyhow::{anyhow, Result};
7use async_trait::async_trait;
8use bytes::Bytes;
9use futures::FutureExt;
10use reqwest::{
11    header::{HeaderValue, CONTENT_TYPE},
12    Client, StatusCode, Url,
13};
14use serde::{Deserialize, Serialize};
15use tokio::sync::oneshot;
16use warp::Filter;
17
18use crate::cs::{CSTrait, CSTraitClient, CSTraitClientConnection, CSTraitServer};
19
/// Client/server transport that exchanges JSON messages over plain HTTP,
/// using `reqwest` on the client side and `warp` on the server side.
#[derive(Clone, Copy, Debug, Default)]
pub struct HttpWarp;
22
23impl CSTrait for HttpWarp {
24    type Config = HttpWarpConfig;
25
26    fn configure(self, _ca: Vec<u8>, _priv_key: Vec<u8>, _cert: Vec<u8>) -> Self::Config {
27        HttpWarpConfig
28    }
29}
30
/// Stateless configuration produced by `HttpWarp::configure`.
///
/// It carries no data because the transport uses no TLS material. It derives
/// the same traits as `HttpWarp` so it can be debug-printed, copied and
/// default-constructed like the transport marker itself.
#[derive(Debug, Clone, Copy, Default)]
pub struct HttpWarpConfig;
33
34impl CSConfig for HttpWarpConfig {
35    type Client = HttpWarpClient;
36
37    type Server = HttpWarpServer;
38
39    fn client(&self) -> Self::Client {
40        let client = Client::builder()
41            .no_gzip() // we want no compression since dummy payloads contain only zeros
42            .no_brotli() // we want no compression since dummy payloads contain only zeros
43            .no_deflate() // we want no compression since dummy payloads contain only zeros
44            .no_proxy() // always use direct connection
45            .http2_prior_knowledge()
46            .build()
47            .expect("static config always valid");
48
49        HttpWarpClient { client }
50    }
51
52    fn server(&self, local_socket: SocketAddr) -> Self::Server {
53        HttpWarpServer { local_socket }
54    }
55}
56
57#[derive(Debug)]
58pub struct HttpWarpClient {
59    client: Client,
60}
61
62#[async_trait]
63impl CSTraitClient for HttpWarpClient {
64    type Connection = HttpWarpClientConnection;
65
66    async fn connect(&self, addr: SocketAddr, _server_name: &str) -> Result<Self::Connection> {
67        let mut url = Url::parse("http://somehost/").expect("valid url");
68        url.set_ip_host(addr.ip()).expect("valid host");
69        url.set_port(Some(addr.port())).expect("valid port");
70
71        Ok(HttpWarpClientConnection {
72            client: self.client.clone(),
73            url,
74        })
75    }
76
77    fn local_addr(&self) -> SocketAddr {
78        unimplemented!()
79    }
80}
81
82#[derive(Debug)]
83pub struct HttpWarpClientConnection {
84    url: Url,
85    client: Client,
86}
87
88#[async_trait]
89impl CSTraitClientConnection for HttpWarpClientConnection {
90    async fn request<Request: Serialize + Send + Sync, Response: for<'a> Deserialize<'a>>(
91        &mut self,
92        request: Request,
93    ) -> Result<Response> {
94        let body = json::to_vec::<Request>(&request);
95
96        let response = self
97            .client
98            .post(self.url.clone())
99            .header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
100            .body(body)
101            .send()
102            .await?;
103
104        let status_code = response.status();
105        if status_code != StatusCode::OK {
106            let body = response.text().await?;
107            return Err(anyhow!(
108                "wrong status code it was {:?} {}",
109                status_code,
110                body
111            ));
112        }
113
114        let content_type = response.headers().get(CONTENT_TYPE);
115        if content_type != Some(&HeaderValue::from_static("application/json")) {
116            return Err(anyhow!("content type not json it was {:?}", content_type));
117        }
118
119        let response = response.bytes().await?;
120
121        let response = json::from_slice::<Response>(&response)?;
122
123        Ok(response)
124    }
125}
126
/// Server side of the transport; nothing is bound until `start` is called.
#[derive(Debug)]
pub struct HttpWarpServer {
    /// Address handed to warp when the server is started.
    local_socket: SocketAddr,
}
131
132impl CSTraitServer for HttpWarpServer {
133    fn start<
134        S: Clone + Send + Sync + 'static,
135        Request: for<'a> Deserialize<'a> + Send + 'static,
136        Response: Serialize + Send + 'static,
137        Fut: Future<Output = Option<Response>> + Send + 'static,
138        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
139    >(
140        self,
141        shared: S,
142        on_request: F,
143        exit: oneshot::Receiver<()>,
144    ) -> (SocketAddr, Pin<Box<dyn Future<Output = ()> + Send>>) {
145        let send = warp::post()
146            .and(warp::header::exact("content-type", "application/json"))
147            .and(warp::body::bytes())
148            .and_then(move |body: Bytes| {
149                let shared = shared.clone();
150                let on_request = on_request.clone();
151                async move {
152                    let input = match json::from_slice::<Request>(&body) {
153                        Ok(input) => input,
154                        Err(err) => {
155                            tracing::warn!("request json body error: {}", err);
156                            return Err(warp::reject::reject());
157                        }
158                    };
159                    match on_request(shared, input).await {
160                        Some(resp) => Ok(resp),
161                        None => Err(warp::reject::reject()),
162                    }
163                }
164            })
165            .map(|o| {
166                let mut res = warp::reply::Response::new(json::to_vec::<Response>(&o).into());
167                res.headers_mut()
168                    .insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
169                res
170            });
171
172        let (addr, server) =
173            warp::serve(send).bind_with_graceful_shutdown(self.local_socket, async {
174                exit.await.ok();
175            });
176
177        let join_handle = tokio::spawn(server);
178
179        (addr, Box::pin(join_handle.map(|_| ())))
180    }
181}