1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
use crate::{
    Codec, FramedTransport, IntoSplit, RawTransport, RawTransportRead, RawTransportWrite, Request,
    Response, TypedAsyncRead, TypedAsyncWrite,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{
    ops::{Deref, DerefMut},
    sync::Arc,
};
use tokio::{
    io,
    sync::mpsc,
    task::{JoinError, JoinHandle},
};

mod channel;
pub use channel::*;

mod ext;
pub use ext::*;

/// Represents a client that can be used to send requests & receive responses from a server
/// Represents a client that can be used to send requests & receive responses from a server.
///
/// `T` parameterizes the `Request`s written to the server and `U` parameterizes the
/// `Response`s read back from it. The client owns two spawned tasks (one writing requests,
/// one reading responses) alongside the [`Channel`] used to submit requests.
pub struct Client<T, U>
where
    T: Send + Sync + Serialize + 'static,
    U: Send + Sync + DeserializeOwned + 'static,
{
    /// Used to send requests to a server; also exposed through `Deref`/`DerefMut` and
    /// retrievable via `into_channel`/`clone_channel`
    channel: Channel<T, U>,

    /// Contains the task that is running to send requests to a server
    request_task: JoinHandle<()>,

    /// Contains the task that is running to receive responses from a server
    response_task: JoinHandle<()>,
}

impl<T, U> Client<T, U>
where
    T: Send + Sync + Serialize,
    U: Send + Sync + DeserializeOwned,
{
    /// Initializes a client using the provided reader and writer
    pub fn new<R, W>(mut writer: W, mut reader: R) -> io::Result<Self>
    where
        R: TypedAsyncRead<Response<U>> + Send + 'static,
        W: TypedAsyncWrite<Request<T>> + Send + 'static,
    {
        let post_office = Arc::new(PostOffice::default());
        let weak_post_office = Arc::downgrade(&post_office);

        // Start a task that continually checks for responses and delivers them using the
        // post office
        let response_task = tokio::spawn(async move {
            loop {
                match reader.read().await {
                    Ok(Some(res)) => {
                        // Try to send response to appropriate mailbox
                        // TODO: How should we handle false response? Did logging in past
                        post_office.deliver_response(res).await;
                    }
                    Ok(None) => {
                        break;
                    }
                    Err(_) => {
                        break;
                    }
                }
            }
        });

        let (tx, mut rx) = mpsc::channel::<Request<T>>(1);
        let request_task = tokio::spawn(async move {
            while let Some(req) = rx.recv().await {
                if writer.write(req).await.is_err() {
                    break;
                }
            }
        });

        let channel = Channel {
            tx,
            post_office: weak_post_office,
        };

        Ok(Self {
            channel,
            request_task,
            response_task,
        })
    }

    /// Initializes a client using the provided framed transport
    pub fn from_framed_transport<TR, C>(transport: FramedTransport<TR, C>) -> io::Result<Self>
    where
        TR: RawTransport + IntoSplit + 'static,
        <TR as IntoSplit>::Read: RawTransportRead,
        <TR as IntoSplit>::Write: RawTransportWrite,
        C: Codec + Send + 'static,
    {
        let (writer, reader) = transport.into_split();
        Self::new(writer, reader)
    }

    /// Convert into underlying channel
    pub fn into_channel(self) -> Channel<T, U> {
        self.channel
    }

    /// Clones the underlying channel for requests and returns the cloned instance
    pub fn clone_channel(&self) -> Channel<T, U> {
        self.channel.clone()
    }

    /// Waits for the client to terminate, which results when the receiving end of the network
    /// connection is closed (or the client is shutdown)
    pub async fn wait(self) -> Result<(), JoinError> {
        tokio::try_join!(self.request_task, self.response_task).map(|_| ())
    }

    /// Abort the client's current connection by forcing its tasks to abort
    pub fn abort(&self) {
        self.request_task.abort();
        self.response_task.abort();
    }

    /// Returns true if client's underlying event processing has finished/terminated
    pub fn is_finished(&self) -> bool {
        self.request_task.is_finished() && self.response_task.is_finished()
    }
}

/// Lets a shared reference to a [`Client`] be used wherever a shared reference to its
/// [`Channel`] is expected
impl<T, U> Deref for Client<T, U>
where
    T: Send + Sync + Serialize + 'static,
    U: Send + Sync + DeserializeOwned + 'static,
{
    type Target = Channel<T, U>;

    fn deref(&self) -> &Self::Target {
        // Borrow just the channel field out of the client
        let Self { channel, .. } = self;
        channel
    }
}

/// Lets a mutable reference to a [`Client`] be used wherever a mutable reference to its
/// [`Channel`] is expected
impl<T, U> DerefMut for Client<T, U>
where
    T: Send + Sync + Serialize + 'static,
    U: Send + Sync + DeserializeOwned + 'static,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Mutably borrow just the channel field out of the client
        let Self { channel, .. } = self;
        channel
    }
}

/// Converts a [`Client`] into its underlying [`Channel`], consuming the client
impl<T, U> From<Client<T, U>> for Channel<T, U>
where
    T: Send + Sync + Serialize + 'static,
    U: Send + Sync + DeserializeOwned + 'static,
{
    fn from(client: Client<T, U>) -> Self {
        // Same field move as the inherent conversion method
        client.into_channel()
    }
}