abcperf_generic_client/cs/typed.rs

1use std::{fmt::Debug, future::Future, marker::PhantomData, net::SocketAddr};
2
3use anyhow::Result;
4use derivative::Derivative;
5use serde::{Deserialize, Serialize};
6use tokio::sync::oneshot;
7
8use crate::cs::{CSConfig, CSTrait, CSTraitClient, CSTraitClientConnection, CSTraitServer};
9
10#[derive(Derivative)]
11#[derivative(
12    Debug(bound = "CS: Debug"),
13    Clone(bound = ""),
14    Copy(bound = ""),
15    Default(bound = "")
16)]
17pub struct TypedCSTrait<
18    CS: CSTrait,
19    Request: Serialize + Send + Sync,
20    Response: for<'a> Deserialize<'a>,
21> {
22    cs: CS,
23    phantom_data: PhantomData<(Request, Response)>,
24}
25
26impl<
27        CS: CSTrait,
28        Request: Serialize + for<'a> Deserialize<'a> + Send + Sync + 'static,
29        Response: Serialize + for<'a> Deserialize<'a> + Send + 'static,
30    > TypedCSTrait<CS, Request, Response>
31{
32    pub fn new(cs: CS) -> Self {
33        Self {
34            cs,
35            phantom_data: PhantomData::default(),
36        }
37    }
38
39    pub fn configure(
40        self,
41        ca: Vec<u8>,
42        priv_key: Vec<u8>,
43        cert: Vec<u8>,
44    ) -> TypedCSConfig<CS::Config, Request, Response> {
45        TypedCSConfig {
46            cs_config: self.cs.configure(ca, priv_key, cert),
47            phantom_data: self.phantom_data,
48        }
49    }
50
51    pub fn configure_debug(self) -> TypedCSConfig<CS::Config, Request, Response> {
52        self.configure(
53            abcperf::CERTIFICATE_AUTHORITY.to_vec(),
54            abcperf::PRIVATE_KEY.to_vec(),
55            abcperf::CERTIFICATE.to_vec(),
56        )
57    }
58}
59
60#[derive(Derivative)]
61#[derivative(Clone(bound = ""))]
62pub struct TypedCSConfig<
63    Config: CSConfig,
64    Request: Serialize + Send + Sync,
65    Response: for<'a> Deserialize<'a>,
66> {
67    cs_config: Config,
68    phantom_data: PhantomData<(Request, Response)>,
69}
70
71impl<
72        Config: CSConfig,
73        Request: Serialize + for<'a> Deserialize<'a> + Send + Sync,
74        Response: Serialize + for<'a> Deserialize<'a> + Send,
75    > TypedCSConfig<Config, Request, Response>
76{
77    pub fn client(&self) -> TypedCSTraitClient<Config::Client, Request, Response> {
78        TypedCSTraitClient {
79            cs_client: self.cs_config.client(),
80            phantom_data: self.phantom_data,
81        }
82    }
83
84    pub fn server(
85        &self,
86        local_socket: impl Into<SocketAddr>,
87    ) -> TypedCSTraitServer<Config::Server, Request, Response> {
88        TypedCSTraitServer {
89            cs_server: self.cs_config.server(local_socket.into()),
90            phantom_data: PhantomData,
91        }
92    }
93}
94
95#[derive(Derivative)]
96#[derivative(Debug(bound = "CSClient: Debug"))]
97pub struct TypedCSTraitClient<
98    CSClient: CSTraitClient,
99    Request: Serialize + Send + Sync,
100    Response: for<'a> Deserialize<'a>,
101> {
102    cs_client: CSClient,
103    phantom_data: PhantomData<(Request, Response)>,
104}
105
106impl<
107        CSClient: CSTraitClient,
108        Request: Serialize + Send + Sync,
109        Response: for<'a> Deserialize<'a>,
110    > TypedCSTraitClient<CSClient, Request, Response>
111{
112    pub async fn connect(
113        &self,
114        addr: impl Into<SocketAddr>,
115        server_name: &str,
116    ) -> Result<TypedCSTraitClientConnection<CSClient::Connection, Request, Response>> {
117        self.cs_client.connect(addr.into(), server_name).await.map(
118            |cs_client_connection: <CSClient as CSTraitClient>::Connection| {
119                TypedCSTraitClientConnection {
120                    cs_client_connection,
121                    phantom_data: PhantomData,
122                }
123            },
124        )
125    }
126
127    pub fn local_addr(&self) -> SocketAddr {
128        self.cs_client.local_addr()
129    }
130}
131
132#[derive(Derivative)]
133#[derivative(Debug(bound = "CSClientConnection: Debug"))]
134pub struct TypedCSTraitClientConnection<CSClientConnection, Request, Response> {
135    cs_client_connection: CSClientConnection,
136    phantom_data: PhantomData<(Request, Response)>,
137}
138
139impl<
140        CSClientConnection: CSTraitClientConnection,
141        Request: Serialize + Send + Sync,
142        Response: for<'a> Deserialize<'a>,
143    > TypedCSTraitClientConnection<CSClientConnection, Request, Response>
144{
145    pub async fn request(&mut self, request: impl Into<Request>) -> Result<Response> {
146        self.cs_client_connection.request(request.into()).await
147    }
148}
149
150#[derive(Derivative)]
151#[derivative(Debug(bound = "CSServer: Debug"))]
152pub struct TypedCSTraitServer<
153    CSServer: CSTraitServer,
154    Request: for<'a> Deserialize<'a> + Send,
155    Response: Serialize + Send,
156> {
157    cs_server: CSServer,
158    phantom_data: PhantomData<(Request, Response)>,
159}
160
161impl<
162        CSServer: CSTraitServer,
163        Request: for<'a> Deserialize<'a> + Send + 'static,
164        Response: Serialize + Send + 'static,
165    > TypedCSTraitServer<CSServer, Request, Response>
166{
167    pub fn start<
168        S: Clone + Send + Sync + 'static,
169        Fut: Future<Output = Option<Response>> + Send + 'static,
170        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
171    >(
172        self,
173        shared: S,
174        on_request: F,
175        exit: oneshot::Receiver<()>,
176    ) -> (SocketAddr, impl Future<Output = ()> + Send) {
177        self.cs_server.start(shared, on_request, exit)
178    }
179}