1use std::{
2 fmt::{self, Debug, Formatter},
3 future::Future,
4};
5
6use crate::{
7 body::Body,
8 decode,
9 request::{self, BoxRequest},
10 response::BoxResponse,
11 Response,
12};
13
14use self::{
15 boxed::BoxedTransport,
16 transport::{SocketChannels, SocketRequestMarker, TransportError},
17};
18
19use super::Request;
20use error::*;
21use futures_util::TryFutureExt;
22use socket::*;
23use tower::{Layer, Service};
24
25pub mod boxed;
27pub mod error;
29pub mod layer;
31pub mod socket;
33pub mod transport;
45
46#[doc(hidden)]
47pub mod prelude {
48 pub use super::{
49 error::{ClientError, ClientResult},
50 socket::Socket,
51 transport::TransportError,
52 Client,
53 };
54 pub use crate::{
55 request::{BoxRequest, IntoRequest, Request},
56 response::{BoxResponse, Response},
57 };
58 pub use std::{borrow::Cow, convert::TryInto, fmt::Debug, future::Future};
59 pub use tower::Service;
60}
61
/// A client that owns a transport of type `Inner`.
///
/// The struct itself places no bounds on `Inner`; the `impl` blocks in this
/// file add the bounds each group of methods needs.
pub struct Client<Inner> {
    /// The wrapped transport that requests are handed to.
    transport: Inner,
}
66
67impl<Inner: Debug> Debug for Client<Inner> {
68 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69 f.debug_struct("Client")
70 .field("inner", &self.transport)
71 .finish()
72 }
73}
74
75impl<Inner: Clone> Clone for Client<Inner> {
76 fn clone(&self) -> Self {
77 Self {
78 transport: self.transport.clone(),
79 }
80 }
81}
82
83impl<Inner> Client<Inner> {
84 pub fn new(transport: Inner) -> Client<Inner> {
86 Client { transport }
87 }
88}
89
90impl<Inner, InnerErr> Client<Inner>
91where
92 Inner: Service<BoxRequest, Response = BoxResponse, Error = TransportError<InnerErr>>
93 + Send
94 + Clone
95 + 'static,
96 Inner::Future: Send,
97 InnerErr: std::error::Error + Sync + Send + 'static,
98{
99 pub fn boxed(self) -> Client<BoxedTransport> {
102 Client {
103 transport: BoxedTransport::new(self.transport),
104 }
105 }
106}
107
108impl<Inner, InnerErr> Client<Inner>
109where
110 Inner: Service<BoxRequest, Response = BoxResponse, Error = TransportError<InnerErr>> + 'static,
111 InnerErr: 'static,
112{
113 pub fn layer<S, L>(self, l: L) -> Client<S>
115 where
116 L: Layer<Inner, Service = S>,
117 S: Service<BoxRequest>,
118 {
119 Client {
120 transport: l.layer(self.transport),
121 }
122 }
123
124 pub fn execute_request<Req, Resp>(
126 &mut self,
127 req: Request<Req>,
128 ) -> impl Future<Output = ClientResult<Response<Resp>, InnerErr>> + 'static
129 where
130 Req: prost::Message,
131 Resp: prost::Message + Default,
132 {
133 Service::call(&mut self.transport, req.map::<()>())
134 .map_ok(|resp| resp.map::<Resp>())
135 .map_err(ClientError::from)
136 }
137
138 pub fn connect_socket<Req, Resp>(
140 &mut self,
141 mut req: Request<()>,
142 ) -> impl Future<Output = ClientResult<Socket<Req, Resp>, InnerErr>> + 'static
143 where
144 Req: prost::Message,
145 Resp: prost::Message + Default,
146 {
147 req.extensions_mut().insert(SocketRequestMarker);
148 Service::call(&mut self.transport, req)
149 .map_ok(|mut resp| {
150 let chans = resp
151 .extensions_mut()
152 .remove::<SocketChannels>()
153 .expect("transport did not return socket channels - this is a bug");
154
155 Socket::new(
156 chans.rx,
157 chans.tx,
158 socket::encode_message,
159 socket::decode_message,
160 )
161 })
162 .map_err(ClientError::from)
163 }
164
165 pub fn connect_socket_req<Req, Resp>(
169 &mut self,
170 request: Request<Req>,
171 ) -> impl Future<Output = ClientResult<Socket<Req, Resp>, InnerErr>> + 'static
172 where
173 Req: prost::Message + Default + 'static,
174 Resp: prost::Message + Default + 'static,
175 {
176 let request::Parts {
177 body,
178 extensions,
179 endpoint,
180 ..
181 } = request.into();
182
183 let request: BoxRequest = Request::from(request::Parts {
184 body: Body::empty(),
185 endpoint: endpoint.clone(),
186 extensions,
187 });
188
189 let connect_fut = self.connect_socket(request);
190
191 async move {
192 let mut socket = connect_fut.await?;
193
194 let message = decode::decode_body(body).await?;
195 socket
196 .send_message(message)
197 .await
198 .map_err(|err| match err {
199 SocketError::MessageDecode(err) => ClientError::MessageDecode(err),
200 SocketError::Protocol(err) => ClientError::EndpointError {
201 hrpc_error: err,
202 endpoint,
203 },
204 SocketError::Transport(err) => ClientError::EndpointError {
207 hrpc_error: HrpcError::from(err).with_identifier("hrpcrs.socket-error"),
208 endpoint,
209 },
210 })?;
211
212 Ok(socket)
213 }
214 }
215}