1use std::collections::HashMap;
2use std::io;
3
4use std::marker::Unpin;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use futures::channel::{mpsc, oneshot};
9use futures::io::{AsyncRead, AsyncWrite};
10use futures::{ready, Future, FutureExt, Sink, Stream, TryFutureExt};
11use rmpv::Value;
12use tokio_util::codec::{Decoder, Framed};
13use tokio_util::compat::{Compat, FuturesAsyncWriteCompatExt};
14
15use crate::codec::Codec;
16use crate::message::Response as MsgPackResponse;
17use crate::message::{Message, Notification, Request};
18use crate::cakeservice::{ Service, ServiceWithClient };
19
20
/// Server half of an endpoint.
///
/// Dispatches incoming requests/notifications to `service` and collects the
/// results produced by spawned request workers so they can be written back.
struct Server<S> {
    /// User-supplied service that handles requests and notifications.
    service: S,
    /// Receives `(request id, result)` pairs from finished request workers.
    pending_responses: mpsc::UnboundedReceiver<(u32, Result<Value, Value>)>,
    /// Cloned into every request worker. Stored here so the response channel
    /// never reports closure while the server is alive.
    response_sender: mpsc::UnboundedSender<(u32, Result<Value, Value>)>,
}
26
27impl<S: ServiceWithClient> Server<S> {
28 fn new(service: S) -> Self {
29 let (send, recv) = mpsc::unbounded();
30
31 Server {
32 service,
33 pending_responses: recv,
34 response_sender: send,
35 }
36 }
37
38 fn send_responses<T: AsyncRead + AsyncWrite>(
39 &mut self,
40 cx: &mut Context,
41 mut sink: Pin<&mut Transport<T>>,
42 ) -> Poll<io::Result<()>> {
43 trace!("Server: flushing responses");
44 loop {
45 ready!(sink.as_mut().poll_ready(cx)?);
46 match Pin::new(&mut self.pending_responses).poll_next(cx) {
47 Poll::Ready(Some((id, result))) => {
48 let msg = Message::Response(MsgPackResponse { id, result });
49 sink.as_mut().start_send(msg).unwrap();
50 }
51 Poll::Ready(None) => panic!("we store the sender, it can't be dropped"),
52 Poll::Pending => return sink.as_mut().poll_flush(cx),
53 }
54 }
55 }
56
57 fn spawn_request_worker<F: Future<Output = Result<Value, Value>> + 'static + Send>(
58 &self,
59 id: u32,
60 f: F,
61 ) {
62 trace!("spawning a new task");
63 trace!("spawning the task on the event loop");
64 let send = self.response_sender.clone();
65 tokio::task::spawn(f.map(move |result| send.unbounded_send((id, result))));
66 }
67}
68
69trait MessageHandler {
70 fn handle_incoming(&mut self, msg: Message);
71
72 fn send_outgoing<T: AsyncRead + AsyncWrite>(
73 &mut self,
74 cx: &mut Context,
75 sink: Pin<&mut Transport<T>>,
76 ) -> Poll<io::Result<()>>;
77
78 fn is_finished(&self) -> bool {
79 false
80 }
81}
82
83type ResponseTx = oneshot::Sender<Result<Value, Value>>;
84
85pub struct Response(oneshot::Receiver<Result<Value, Value>>);
86
87type AckTx = oneshot::Sender<()>;
88
89pub struct Ack(oneshot::Receiver<()>);
90
91type RequestTx = mpsc::UnboundedSender<(Request, ResponseTx)>;
93type RequestRx = mpsc::UnboundedReceiver<(Request, ResponseTx)>;
94
95type NotificationTx = mpsc::UnboundedSender<(Notification, AckTx)>;
96type NotificationRx = mpsc::UnboundedReceiver<(Notification, AckTx)>;
97
98impl Future for Response {
99 type Output = Result<Value, Value>;
100
101 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
103 trace!("=== Response: polling, 触发了Response的Future poll fn ===");
104 Poll::Ready(match ready!(Pin::new(&mut self.0).poll(cx)) {
105 Ok(Ok(v)) => Ok(v),
106 Ok(Err(v)) => Err(v),
107 Err(_) => Err(Value::Nil),
108 })
109 }
110}
111
112impl Future for Ack {
113 type Output = Result<(), ()>;
114
115 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
116 trace!("Ack: polling");
117 Pin::new(&mut self.0).poll(cx).map_err(|_| ())
118 }
119}
120
121struct InnerClient {
122 client_closed: bool,
123 request_id: u32,
124 requests_rx: RequestRx,
125 notifications_rx: NotificationRx,
126 pending_requests: HashMap<u32, ResponseTx>,
127 pending_notifications: Vec<AckTx>,
128}
129
130impl InnerClient {
131 fn new() -> (Self, Client) {
132 let (requests_tx, requests_rx) = mpsc::unbounded();
133 let (notifications_tx, notifications_rx) = mpsc::unbounded();
134
135 let client = Client {
136 requests_tx,
137 notifications_tx
138 };
139
140 let inner_client = InnerClient {
141 client_closed: false,
142 request_id: 0,
143 requests_rx,
144 notifications_rx,
145 pending_requests: HashMap::new(),
146 pending_notifications: Vec::new(),
147 };
148
149 (inner_client, client)
150 }
151
152 fn process_notifications<T: AsyncRead + AsyncWrite>(
153 &mut self,
154 cx: &mut Context,
155 mut stream: Pin<&mut Transport<T>>,
156 ) -> io::Result<()> {
157 if self.client_closed {
158 return Ok(());
159 }
160
161 trace!("Polling client notifications channel");
162
163 while let Poll::Ready(()) = stream.as_mut().poll_ready(cx)? {
164 match Pin::new(&mut self.notifications_rx).poll_next(cx) {
165 Poll::Ready(Some((notification, ack_sender))) => {
166 trace!("Got notification from client.");
167 stream
168 .as_mut()
169 .start_send(Message::Notification(notification))?;
170 self.pending_notifications.push(ack_sender);
171 }
172 Poll::Ready(None) => {
173 trace!("Client closed the notifications channel.");
174 self.client_closed = true;
175 break;
176 }
177 Poll::Pending => {
178 trace!("No new notification from client");
179 break;
180 }
181 }
182 }
183 Ok(())
184 }
185
186 fn send_messages<T: AsyncRead + AsyncWrite>(
187 &mut self,
188 cx: &mut Context,
189 mut stream: Pin<&mut Transport<T>>,
190 ) -> Poll<io::Result<()>> {
191 self.process_requests(cx, stream.as_mut())?;
192 self.process_notifications(cx, stream.as_mut())?;
193
194 match stream.poll_flush(cx)? {
195 Poll::Ready(()) => {
196 self.acknowledge_notifications();
197 Poll::Ready(Ok(()))
198 }
199 Poll::Pending => Poll::Pending,
200 }
201 }
202
203 fn process_requests<T: AsyncRead + AsyncWrite>(
204 &mut self,
205 cx: &mut Context,
206 mut stream: Pin<&mut Transport<T>>,
207 ) -> io::Result<()> {
208 if self.client_closed {
211 return Ok(());
212 }
213 trace!("Polling client requests channel");
214 while let Poll::Ready(()) = stream.as_mut().poll_ready(cx)? {
215 match Pin::new(&mut self.requests_rx).poll_next(cx) {
216 Poll::Ready(Some((mut request, response_sender))) => {
217 self.request_id += 1;
218 request.id = self.request_id;
220 trace!("=== Send Message to Service-serv: {:?}", request);
221 stream.as_mut().start_send(Message::Request(request))?;
222 self.pending_requests
223 .insert(self.request_id, response_sender);
224 }
225 Poll::Ready(None) => {
226 trace!("Client closed the requests channel.");
227 self.client_closed = true;
228 break;
229 }
230 Poll::Pending => {
231 trace!("No new request from client");
232 break;
233 }
234 }
235 }
236 Ok(())
237 }
238
239 fn process_response(&mut self, response: MsgPackResponse) {
240 trace!("一个客户端的请求处理完成,response.id为{},\
241 在pennding_requests中去掉这个id的key", &response.id);
242 if let Some(response_tx) = self.pending_requests.remove(&response.id) {
243 trace!("协程转发数据给客户端主线程 == Forwarding response to the client.");
244 if let Err(e) = response_tx.send(response.result) {
245 warn!("Failed to send response to client: {:?}", e);
246 }
247 } else {
248 warn!("no pending request found for response {}", &response.id);
249 }
250 }
251
252 fn acknowledge_notifications(&mut self) {
253 for chan in self.pending_notifications.drain(..) {
254 trace!("Acknowledging notification.");
255 if let Err(e) = chan.send(()) {
256 warn!("Failed to send ack to client: {:?}", e);
257 }
258 }
259 }
260}
261
262struct Transport<T>(Framed<Compat<T>, Codec>);
263impl<T> Transport<T>
266 where
267 T: AsyncRead + AsyncWrite,
268{
269 fn inner(self: Pin<&mut Self>) -> Pin<&mut Framed<Compat<T>, Codec>> {
270 trace!("=== Transport inner 返回Transport Framed ===");
271 unsafe { self.map_unchecked_mut(|this| &mut this.0) }
272 }
273}
274
275impl<T> Stream for Transport<T>
276 where
277 T: AsyncRead + AsyncWrite,
278{
279 type Item = io::Result<Message>;
280
281 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
282 trace!("=== Transport polling动作,一旦有Transport,就会触发下面的逻辑 ===");
284 self.inner().poll_next(cx)
285 }
286}
287
288impl<T> Sink<Message> for Transport<T>
289 where
290 T: AsyncRead + AsyncWrite,
291{
292 type Error = io::Error;
293
294 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
295 self.inner().poll_ready(cx)
296 }
297
298 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
299 self.inner().start_send(item)
300 }
301
302 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
303 self.inner().poll_flush(cx)
304 }
305
306 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
307 self.inner().poll_close(cx)
308 }
309}
310
311impl<S: Service> MessageHandler for Server<S> {
312 fn handle_incoming(&mut self, msg: Message) {
314 trace!("=== impl MessageHandler for Server handle_incoming ===");
315 match msg {
316 Message::Request(req) => {
317 let f = self.service.handle_request(&req.method, &req.params);
319 self.spawn_request_worker(req.id, f);
320 }
321 Message::Notification(note) => {
322 self.service.handle_notification(¬e.method, ¬e.params);
323 }
324 Message::Response(_) => {
325 trace!("This endpoint doesn't handle responses, ignoring the msg.");
326 }
327 };
328 }
329
330 fn send_outgoing<T: AsyncRead + AsyncWrite>(
331 &mut self,
332 cx: &mut Context,
333 sink: Pin<&mut Transport<T>>,
334 ) -> Poll<io::Result<()>> {
335 self.send_responses(cx, sink)
336 }
337}
338
339impl MessageHandler for InnerClient {
340 fn handle_incoming(&mut self, msg: Message) {
342 trace!("=== 接收到服务端数据, impl MessageHandler for InnerClient handle_incoming ===");
343 trace!("handle_incoming Received {:?}", msg);
344 if let Message::Response(response) = msg {
345 self.process_response(response);
346 } else {
347 trace!("This endpoint only handles reponses, ignoring the msg.");
348 }
349 }
350
351 fn send_outgoing<T: AsyncRead + AsyncWrite>(
352 &mut self,
353 cx: &mut Context,
354 sink: Pin<&mut Transport<T>>,
355 ) -> Poll<io::Result<()>> {
356 trace!("=== impl MessageHandler for InnerClient invoke send_outgoing ===");
357 self.send_messages(cx, sink)
358 }
359
360 fn is_finished(&self) -> bool {
361 self.client_closed
362 && self.pending_requests.is_empty()
363 && self.pending_notifications.is_empty()
364 }
365}
366
/// Handler used by [`Endpoint`]: acts as both server and client over a single
/// connection.
struct ClientAndServer<S> {
    /// Sends local requests/notifications and matches incoming responses.
    inner_client: InnerClient,
    /// Dispatches incoming requests/notifications to the service.
    server: Server<S>,
    /// Handle passed to the service so it can issue calls of its own. Holding
    /// it here also keeps `inner_client`'s channels from closing.
    client: Client,
}
372
373impl<S: ServiceWithClient> MessageHandler for ClientAndServer<S> {
374 fn handle_incoming(&mut self, msg: Message) {
375 trace!("=== impl MessageHandler for ClientAndServer<S> handle_incoming ===");
376 match msg {
377 Message::Request(req) => {
378 let f =
379 self.server
380 .service
381 .handle_request(&mut self.client, &req.method, &req.params);
382 self.server.spawn_request_worker(req.id, f);
383 }
384 Message::Notification(note) => {
385 self.server.service.handle_notification(
386 &mut self.client,
387 ¬e.method,
388 ¬e.params,
389 );
390 }
391 Message::Response(response) => self.inner_client.process_response(response),
392 };
393 }
394
395 fn send_outgoing<T: AsyncRead + AsyncWrite>(
396 &mut self,
397 cx: &mut Context,
398 mut sink: Pin<&mut Transport<T>>,
399 ) -> Poll<io::Result<()>> {
400 if let Poll::Ready(()) = self.server.send_responses(cx, sink.as_mut())? {
401 self.inner_client.send_messages(cx, sink)
402 } else {
403 Poll::Pending
404 }
405 }
406}
407
/// Future that drives a [`MessageHandler`] over a [`Transport`].
///
/// Each poll lets the handler write and flush outgoing messages, then feeds it
/// every incoming message that is ready.
struct InnerEndpoint<MH, T> {
    /// Role-specific logic (server, client, or both).
    handler: MH,
    /// Framed connection to the peer.
    stream: Transport<T>,
}
412
413impl<MH: MessageHandler + Unpin, T: AsyncRead + AsyncWrite> Future for InnerEndpoint<MH, T> {
415 type Output = io::Result<()>;
416
417 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
418 trace!("=== 触发Future for InnerEndpoint 的 poll, InnerEndpoint: polling ===");
419 let (handler, mut stream) = unsafe {
420 let this = self.get_unchecked_mut();
421 (&mut this.handler, Pin::new_unchecked(&mut this.stream))
422 };
423 trace!("=== InnerEndpoint handler.send_outgoing, 客户端在这里发送数据给服务端! ===");
424 if let Poll::Pending = handler.send_outgoing(cx, stream.as_mut())? {
425 trace!("Sink not yet flushed, waiting...");
426 return Poll::Pending;
427 }
428
429 trace!("=== 客户端Polling stream, 轮询stream, 也就是轮询socket事件, 接收服务端的返回! ===");
430 while let Poll::Ready(msg) = stream.as_mut().poll_next(cx)? {
432 trace!("---check msg struct---");
433 if let Some(msg) = msg {
434 trace!("---handle_incoming msg---.");
435 handler.handle_incoming(msg);
436 } else {
437 trace!("Stream closed by remote peer.");
438 return Poll::Ready(Ok(()));
442 }
443 }
444
445 if handler.is_finished() {
446 trace!("inner client finished, exiting...");
447 Poll::Ready(Ok(()))
448 } else {
449 trace!("notifying the reactor that we're not done yet");
450 trace!("=== 这里执行 Poll:Pending, 如果客户端已经没有发送数据给服务端的话,那就是不会触发 InnerEndpoint: polling(通信入口的轮询), 客户端就不会polling socket事件, 客户端程序会退出 ===");
451 Poll::Pending
452 }
453 }
454}
455
456pub fn serve<'a, S: Service + Unpin + 'a, T: AsyncRead + AsyncWrite + 'a + Send>(
461 stream: T,
462 service: S,
463) -> impl Future<Output = io::Result<()>> + 'a + Send {
464 ServerEndpoint::new(stream, service)
465}
466
/// Server-only endpoint future returned (as `impl Future`) by [`serve`].
struct ServerEndpoint<S, T> {
    /// Endpoint driver whose handler only serves incoming calls.
    inner: InnerEndpoint<Server<S>, T>,
}
470
471impl<S: Service + Unpin, T: AsyncRead + AsyncWrite> ServerEndpoint<S, T> {
472 pub fn new(stream: T, service: S) -> Self {
473 let stream = FuturesAsyncWriteCompatExt::compat_write(stream);
474 ServerEndpoint {
475 inner: InnerEndpoint {
476 stream: Transport(Codec.framed(stream)),
477 handler: Server::new(service),
478 },
479 }
480 }
481}
482
483impl<S: Service + Unpin, T: AsyncRead + AsyncWrite> Future for ServerEndpoint<S, T> {
484 type Output = io::Result<()>;
485
486 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
487 trace!("ServerEndpoint: polling");
488 unsafe { self.map_unchecked_mut(|this| &mut this.inner) }.poll(cx)
489 }
490}
491
/// A bidirectional msgpack-rpc endpoint.
///
/// It serves incoming calls with a service and also lets local code call the
/// peer through [`Endpoint::client`]. It is a future and only makes progress
/// while polled.
pub struct Endpoint<S, T> {
    inner: InnerEndpoint<ClientAndServer<S>, T>,
}
495
496impl<S: ServiceWithClient + Unpin, T: AsyncRead + AsyncWrite> Endpoint<S, T> {
497 pub fn new(stream: T, service: S) -> Self {
499 let (inner_client, client) = InnerClient::new();
500 let stream = FuturesAsyncWriteCompatExt::compat_write(stream);
501 Endpoint {
502 inner: InnerEndpoint {
503 stream: Transport(Codec.framed(stream)),
504 handler: ClientAndServer {
505 inner_client,
506 client,
507 server: Server::new(service),
508 },
509 },
510 }
511 }
512
513 pub fn client(&self) -> Client {
516 self.inner.handler.client.clone()
517 }
518}
519
520impl<S: ServiceWithClient + Unpin, T: AsyncRead + AsyncWrite> Future for Endpoint<S, T> {
521 type Output = io::Result<()>;
522
523 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
524 trace!("Endpoint: polling");
525 unsafe { self.map_unchecked_mut(|this| &mut this.inner) }.poll(cx)
526 }
527}
528
/// Cloneable handle for sending requests and notifications to the remote peer.
///
/// Messages are queued on unbounded channels and written by the endpoint that
/// owns the connection.
#[derive(Clone)]
pub struct Client {
    requests_tx: RequestTx,
    notifications_tx: NotificationTx,
}
535
536impl Client {
537
538 pub fn new<T: AsyncRead + AsyncWrite + 'static + Send>(stream: T) -> Self {
547 let (inner_client, client) = InnerClient::new();
548 let stream = FuturesAsyncWriteCompatExt::compat_write(stream);
549 let endpoint = InnerEndpoint {
551 stream: Transport(Codec.framed(stream)),
552 handler: inner_client,
553 };
554 tokio::task::spawn(
559 endpoint.map_err(|e| error!("Client endpoint closed because of an error: {}", e)),
560 );
561
562 client
563 }
564
565 pub fn call(&self, method: &str, params: &[Value]) -> Response {
575 trace!("New call (method={}, params={:?})", method, params);
576 let request = Request {
577 id: 0,
578 method: method.to_owned(),
579 params: Vec::from(params),
580 };
581 let (tx, rx) = oneshot::channel();
582 let _ = mpsc::UnboundedSender::unbounded_send(&self.requests_tx, (request, tx));
588 Response(rx)
589 }
590
591 pub fn call_notify(&self, method: &str, params: &[Value]) -> Ack {
593 trace!("New notification (method={}, params={:?})", method, params);
594 let notification = Notification {
595 method: method.to_owned(),
596 params: Vec::from(params),
597 };
598 let (tx, rx) = oneshot::channel();
599 let _ = mpsc::UnboundedSender::unbounded_send(&self.notifications_tx, (notification, tx));
600 Ack(rx)
601 }
602}
603
604impl Future for Client {
605 type Output = io::Result<()>;
606
607 fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
608 trace!("Client: polling");
609 Poll::Ready(Ok(()))
610 }
611}
612
613