rpcx_client/
client.rs

1use std::{
2    cell::RefCell,
3    collections::HashMap,
4    error::Error as StdError,
5    io::{BufReader, BufWriter, Write},
6    net::{Shutdown, SocketAddr, TcpStream},
7    sync::{
8        atomic::{AtomicU64, Ordering},
9        mpsc::{self, Receiver, SendError, Sender},
10        Arc, Mutex,
11    },
12    thread,
13    time::Duration,
14};
15
16use rpcx_protocol::{call::*, *};
17use tokio::runtime::Runtime;
18
19#[derive(Debug, Copy, Clone)]
20pub struct Opt {
21    pub retry: u8,
22    pub compress_type: CompressType,
23    pub serialize_type: SerializeType,
24    pub connect_timeout: Duration,
25    pub read_timeout: Duration,
26    pub write_timeout: Duration,
27    pub nodelay: Option<bool>,
28    pub ttl: Option<u32>,
29}
30
31impl Default for Opt {
32    fn default() -> Self {
33        Opt {
34            retry: 3,
35            compress_type: CompressType::CompressNone,
36            serialize_type: SerializeType::JSON,
37            connect_timeout: Default::default(),
38            read_timeout: Default::default(),
39            write_timeout: Default::default(),
40            nodelay: None,
41            ttl: None,
42        }
43    }
44}
45
/// One encoded request waiting in the queue for the writer thread.
///
/// `seq` is the request's sequence number, used to find the pending call again if the
/// request cannot be queued; `data` holds the already-encoded message bytes.
#[derive(Default, Debug)]
struct RpcData {
    seq: u64,
    data: Vec<u8>,
}
51
52/// a direct client to connect rpcx services.
53#[derive(Debug)]
54pub struct Client {
55    pub opt: Opt,
56    addr: String,
57    stream: Option<TcpStream>,
58    seq: Arc<AtomicU64>,
59    chan_sender: Sender<RpcData>,
60    chan_receiver: Arc<Mutex<Receiver<RpcData>>>,
61    calls: Arc<Mutex<HashMap<u64, ArcCall>>>,
62}
63
64impl Client {
65    pub fn new(addr: &str) -> Client {
66        let (sender, receiver) = mpsc::channel();
67
68        Client {
69            opt: Default::default(),
70            addr: String::from(addr),
71            stream: None,
72            seq: Arc::new(AtomicU64::new(0)),
73            chan_sender: sender,
74            chan_receiver: Arc::new(Mutex::new(receiver)),
75            calls: Arc::new(Mutex::new(HashMap::new())),
76        }
77    }
78    pub fn start(&mut self) -> Result<()> {
79        let stream = if self.opt.connect_timeout.as_millis() == 0 {
80            TcpStream::connect(self.addr.as_str())?
81        } else {
82            let socket_addr: SocketAddr = self
83                .addr
84                .parse()
85                .map_err(|err| Error::new(ErrorKind::Network, err))?;
86            TcpStream::connect_timeout(&socket_addr, self.opt.connect_timeout)?
87        };
88
89        if self.opt.read_timeout.as_millis() > 0 {
90            stream.set_read_timeout(Some(self.opt.read_timeout))?;
91        }
92        if self.opt.write_timeout.as_millis() > 0 {
93            stream.set_write_timeout(Some(self.opt.read_timeout))?;
94        }
95
96        if self.opt.nodelay.is_some() {
97            stream.set_nodelay(self.opt.nodelay.unwrap())?;
98        }
99        if self.opt.ttl.is_some() {
100            stream.set_ttl(self.opt.ttl.unwrap())?;
101        }
102        let read_stream = stream.try_clone()?;
103        let write_stream = stream.try_clone()?;
104        self.stream = Some(stream);
105
106        let calls = self.calls.clone();
107        thread::spawn(move || {
108            let mut reader = BufReader::new(read_stream.try_clone().unwrap());
109
110            loop {
111                let mut msg = Message::new();
112                match msg.decode(&mut reader) {
113                    Ok(()) => {
114                        if let Some(call) = calls.lock().unwrap().remove(&msg.get_seq()) {
115                            let internal_call_cloned = call.clone();
116                            let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
117                            let internal_call = internal_call_mutex.get_mut();
118                            internal_call.is_client_error = false;
119                            if let Some(MessageStatusType::Error) = msg.get_message_status_type() {
120                                internal_call.error =
121                                    msg.get_error().unwrap_or_else(|| "".to_owned());
122                            } else {
123                                internal_call.reply_data.extend_from_slice(&msg.payload);
124                            }
125
126                            let mut status = internal_call.state.lock().unwrap();
127                            status.ready = true;
128                            if let Some(ref task) = status.task {
129                                task.clone().wake()
130                            }
131                        }
132                    }
133                    Err(err) => {
134                        println!("failed to read: {}", err.to_string());
135                        Self::drain_calls(calls, err);
136                        match read_stream.shutdown(Shutdown::Both) {
137                            Ok(_) => {}
138                            Err(err) => eprintln!("failed to shutdown stream: {}", err),
139                        }
140                        return;
141                    }
142                }
143            }
144        });
145
146        let chan_receiver = self.chan_receiver.clone();
147        let send_calls = self.calls.clone();
148        thread::spawn(move || {
149            let mut writer = BufWriter::new(write_stream.try_clone().unwrap());
150            loop {
151                match chan_receiver.lock().unwrap().recv() {
152                    Err(_err) => {
153                        //eprintln!("failed to fetch RpcData: {}", err.to_string());
154                        write_stream.shutdown(Shutdown::Both).unwrap();
155                        return;
156                    }
157                    Ok(rpcdata) => {
158                        match writer.write_all(rpcdata.data.as_slice()) {
159                            Ok(()) => {
160                                //println!("wrote");
161                            }
162                            Err(err) => {
163                                //println!("failed to write: {}", err.to_string());
164                                Self::drain_calls(send_calls.clone(), err);
165                                write_stream.shutdown(Shutdown::Both).unwrap();
166                                return;
167                            }
168                        }
169
170                        match writer.flush() {
171                            Ok(()) => {
172                                //println!("flushed");
173                            }
174                            Err(err) => {
175                                //println!("failed to flush: {}", err.to_string());
176                                Self::drain_calls(send_calls.clone(), err);
177                                write_stream.shutdown(Shutdown::Both).unwrap();
178                                return;
179                            }
180                        }
181                    }
182                }
183            }
184        });
185
186        Ok(())
187    }
188
189    pub fn send(
190        &self,
191        service_path: &str,
192        service_method: &str,
193        is_oneway: bool,
194        is_heartbeat: bool,
195        metadata: &Metadata,
196        args: &dyn RpcxParam,
197    ) -> CallFuture {
198        let seq = self.seq.clone().fetch_add(1, Ordering::SeqCst);
199
200        let mut req = Message::new();
201        req.set_version(0);
202        req.set_message_type(MessageType::Request);
203        req.set_serialize_type(self.opt.serialize_type);
204        req.set_compress_type(self.opt.compress_type);
205        req.set_seq(seq);
206        req.service_path = service_path.to_string();
207        req.service_method = service_method.to_string();
208
209        let mut new_metadata = HashMap::with_capacity(metadata.len());
210        for (k, v) in metadata {
211            new_metadata.insert(k.clone(), v.clone());
212        }
213        req.metadata.replace(new_metadata);
214        let payload = args.into_bytes(self.opt.serialize_type).unwrap();
215        req.payload = payload;
216
217        let data = req.encode();
218
219        let call_future = if !is_oneway && !is_heartbeat {
220            let callback = Call::new(seq);
221            let arc_call = Arc::new(Mutex::new(RefCell::from(callback)));
222            self.calls
223                .clone()
224                .lock()
225                .unwrap()
226                .insert(seq, arc_call.clone());
227
228            CallFuture::new(Some(arc_call))
229        } else {
230            CallFuture::new(None)
231        };
232
233        let send_data = RpcData { seq, data };
234        match self.chan_sender.clone().send(send_data) {
235            Ok(_) => {}
236            Err(err) => self.remove_call_with_senderr(err),
237        }
238
239        call_future
240    }
241
242    fn remove_call_with_senderr(&self, err: SendError<RpcData>) {
243        let seq = err.0.seq;
244        let calls = self.calls.clone();
245        let mut m = calls.lock().unwrap();
246        if let Some(call) = m.remove(&seq) {
247            let internal_call_cloned = call.clone();
248            let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
249            let internal_call = internal_call_mutex.get_mut();
250            internal_call.error = String::from(err.description());
251            let mut status = internal_call.state.lock().unwrap();
252            status.ready = true;
253            if let Some(ref task) = status.task {
254                task.clone().wake()
255            }
256        }
257    }
258
259    fn drain_calls<T: StdError>(calls: Arc<Mutex<HashMap<u64, ArcCall>>>, err: T) {
260        let mut m = calls.lock().unwrap();
261        for (_, call) in m.drain().take(1) {
262            let internal_call_cloned = call.clone();
263            let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
264            let internal_call = internal_call_mutex.get_mut();
265            internal_call.error = String::from(err.description());
266            let mut status = internal_call.state.lock().unwrap();
267            status.ready = true;
268            if let Some(ref task) = status.task {
269                task.clone().wake()
270            }
271        }
272    }
273
274    #[allow(dead_code)]
275    fn remove_call_with_err<T: StdError>(&mut self, seq: u64, err: T) {
276        let calls = self.calls.clone();
277        let m = calls.lock().unwrap();
278        if let Some(call) = m.get(&seq) {
279            let internal_call_cloned = call.clone();
280            let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
281            let internal_call = internal_call_mutex.get_mut();
282            internal_call.error = String::from(err.description());
283            let mut status = internal_call.state.lock().unwrap();
284            status.ready = true;
285            if let Some(ref task) = status.task {
286                task.clone().wake()
287            }
288        }
289    }
290
291    pub fn call<T>(
292        &mut self,
293        service_path: &str,
294        service_method: &str,
295        is_oneway: bool,
296        metadata: &Metadata,
297        args: &dyn RpcxParam,
298    ) -> Option<Result<T>>
299    where
300        T: RpcxParam + Default,
301    {
302        let rt = Runtime::new().unwrap();
303        let callfuture = rt.block_on(async {
304            let f = self.send(
305                service_path,
306                service_method,
307                is_oneway,
308                false,
309                metadata,
310                args,
311            );
312            f.await
313        });
314
315        if is_oneway {
316            return None;
317        }
318
319        let arc_call_1 = callfuture.unwrap().clone();
320        let mut arc_call_2 = arc_call_1.lock().unwrap();
321        let arc_call_3 = arc_call_2.get_mut();
322        let reply_data = &arc_call_3.reply_data;
323
324        if !arc_call_3.error.is_empty() {
325            let err = &arc_call_3.error;
326            if arc_call_3.is_client_error {
327                return Some(Err(Error::new(ErrorKind::Client, String::from(err))));
328            } else {
329                return Some(Err(Error::from(String::from(err))));
330            }
331        }
332
333        let mut reply: T = Default::default();
334        match reply.from_slice(self.opt.serialize_type, &reply_data) {
335            Ok(()) => Some(Ok(reply)),
336            Err(err) => Some(Err(err)),
337        }
338    }
339}