1use std::{
2 cell::RefCell,
3 collections::HashMap,
4 error::Error as StdError,
5 io::{BufReader, BufWriter, Write},
6 net::{Shutdown, SocketAddr, TcpStream},
7 sync::{
8 atomic::{AtomicU64, Ordering},
9 mpsc::{self, Receiver, SendError, Sender},
10 Arc, Mutex,
11 },
12 thread,
13 time::Duration,
14};
15
16use rpcx_protocol::{call::*, *};
17use tokio::runtime::Runtime;
18
19#[derive(Debug, Copy, Clone)]
20pub struct Opt {
21 pub retry: u8,
22 pub compress_type: CompressType,
23 pub serialize_type: SerializeType,
24 pub connect_timeout: Duration,
25 pub read_timeout: Duration,
26 pub write_timeout: Duration,
27 pub nodelay: Option<bool>,
28 pub ttl: Option<u32>,
29}
30
31impl Default for Opt {
32 fn default() -> Self {
33 Opt {
34 retry: 3,
35 compress_type: CompressType::CompressNone,
36 serialize_type: SerializeType::JSON,
37 connect_timeout: Default::default(),
38 read_timeout: Default::default(),
39 write_timeout: Default::default(),
40 nodelay: None,
41 ttl: None,
42 }
43 }
44}
45
/// A request handed to the writer thread: its sequence number and encoded bytes.
#[derive(Default, Debug)]
struct RpcData {
    /// Sequence number of the request, used to locate its pending call.
    seq: u64,
    /// The encoded request message, written to the socket as-is.
    data: Vec<u8>,
}
51
52#[derive(Debug)]
54pub struct Client {
55 pub opt: Opt,
56 addr: String,
57 stream: Option<TcpStream>,
58 seq: Arc<AtomicU64>,
59 chan_sender: Sender<RpcData>,
60 chan_receiver: Arc<Mutex<Receiver<RpcData>>>,
61 calls: Arc<Mutex<HashMap<u64, ArcCall>>>,
62}
63
64impl Client {
65 pub fn new(addr: &str) -> Client {
66 let (sender, receiver) = mpsc::channel();
67
68 Client {
69 opt: Default::default(),
70 addr: String::from(addr),
71 stream: None,
72 seq: Arc::new(AtomicU64::new(0)),
73 chan_sender: sender,
74 chan_receiver: Arc::new(Mutex::new(receiver)),
75 calls: Arc::new(Mutex::new(HashMap::new())),
76 }
77 }
78 pub fn start(&mut self) -> Result<()> {
79 let stream = if self.opt.connect_timeout.as_millis() == 0 {
80 TcpStream::connect(self.addr.as_str())?
81 } else {
82 let socket_addr: SocketAddr = self
83 .addr
84 .parse()
85 .map_err(|err| Error::new(ErrorKind::Network, err))?;
86 TcpStream::connect_timeout(&socket_addr, self.opt.connect_timeout)?
87 };
88
89 if self.opt.read_timeout.as_millis() > 0 {
90 stream.set_read_timeout(Some(self.opt.read_timeout))?;
91 }
92 if self.opt.write_timeout.as_millis() > 0 {
93 stream.set_write_timeout(Some(self.opt.read_timeout))?;
94 }
95
96 if self.opt.nodelay.is_some() {
97 stream.set_nodelay(self.opt.nodelay.unwrap())?;
98 }
99 if self.opt.ttl.is_some() {
100 stream.set_ttl(self.opt.ttl.unwrap())?;
101 }
102 let read_stream = stream.try_clone()?;
103 let write_stream = stream.try_clone()?;
104 self.stream = Some(stream);
105
106 let calls = self.calls.clone();
107 thread::spawn(move || {
108 let mut reader = BufReader::new(read_stream.try_clone().unwrap());
109
110 loop {
111 let mut msg = Message::new();
112 match msg.decode(&mut reader) {
113 Ok(()) => {
114 if let Some(call) = calls.lock().unwrap().remove(&msg.get_seq()) {
115 let internal_call_cloned = call.clone();
116 let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
117 let internal_call = internal_call_mutex.get_mut();
118 internal_call.is_client_error = false;
119 if let Some(MessageStatusType::Error) = msg.get_message_status_type() {
120 internal_call.error =
121 msg.get_error().unwrap_or_else(|| "".to_owned());
122 } else {
123 internal_call.reply_data.extend_from_slice(&msg.payload);
124 }
125
126 let mut status = internal_call.state.lock().unwrap();
127 status.ready = true;
128 if let Some(ref task) = status.task {
129 task.clone().wake()
130 }
131 }
132 }
133 Err(err) => {
134 println!("failed to read: {}", err.to_string());
135 Self::drain_calls(calls, err);
136 match read_stream.shutdown(Shutdown::Both) {
137 Ok(_) => {}
138 Err(err) => eprintln!("failed to shutdown stream: {}", err),
139 }
140 return;
141 }
142 }
143 }
144 });
145
146 let chan_receiver = self.chan_receiver.clone();
147 let send_calls = self.calls.clone();
148 thread::spawn(move || {
149 let mut writer = BufWriter::new(write_stream.try_clone().unwrap());
150 loop {
151 match chan_receiver.lock().unwrap().recv() {
152 Err(_err) => {
153 write_stream.shutdown(Shutdown::Both).unwrap();
155 return;
156 }
157 Ok(rpcdata) => {
158 match writer.write_all(rpcdata.data.as_slice()) {
159 Ok(()) => {
160 }
162 Err(err) => {
163 Self::drain_calls(send_calls.clone(), err);
165 write_stream.shutdown(Shutdown::Both).unwrap();
166 return;
167 }
168 }
169
170 match writer.flush() {
171 Ok(()) => {
172 }
174 Err(err) => {
175 Self::drain_calls(send_calls.clone(), err);
177 write_stream.shutdown(Shutdown::Both).unwrap();
178 return;
179 }
180 }
181 }
182 }
183 }
184 });
185
186 Ok(())
187 }
188
189 pub fn send(
190 &self,
191 service_path: &str,
192 service_method: &str,
193 is_oneway: bool,
194 is_heartbeat: bool,
195 metadata: &Metadata,
196 args: &dyn RpcxParam,
197 ) -> CallFuture {
198 let seq = self.seq.clone().fetch_add(1, Ordering::SeqCst);
199
200 let mut req = Message::new();
201 req.set_version(0);
202 req.set_message_type(MessageType::Request);
203 req.set_serialize_type(self.opt.serialize_type);
204 req.set_compress_type(self.opt.compress_type);
205 req.set_seq(seq);
206 req.service_path = service_path.to_string();
207 req.service_method = service_method.to_string();
208
209 let mut new_metadata = HashMap::with_capacity(metadata.len());
210 for (k, v) in metadata {
211 new_metadata.insert(k.clone(), v.clone());
212 }
213 req.metadata.replace(new_metadata);
214 let payload = args.into_bytes(self.opt.serialize_type).unwrap();
215 req.payload = payload;
216
217 let data = req.encode();
218
219 let call_future = if !is_oneway && !is_heartbeat {
220 let callback = Call::new(seq);
221 let arc_call = Arc::new(Mutex::new(RefCell::from(callback)));
222 self.calls
223 .clone()
224 .lock()
225 .unwrap()
226 .insert(seq, arc_call.clone());
227
228 CallFuture::new(Some(arc_call))
229 } else {
230 CallFuture::new(None)
231 };
232
233 let send_data = RpcData { seq, data };
234 match self.chan_sender.clone().send(send_data) {
235 Ok(_) => {}
236 Err(err) => self.remove_call_with_senderr(err),
237 }
238
239 call_future
240 }
241
242 fn remove_call_with_senderr(&self, err: SendError<RpcData>) {
243 let seq = err.0.seq;
244 let calls = self.calls.clone();
245 let mut m = calls.lock().unwrap();
246 if let Some(call) = m.remove(&seq) {
247 let internal_call_cloned = call.clone();
248 let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
249 let internal_call = internal_call_mutex.get_mut();
250 internal_call.error = String::from(err.description());
251 let mut status = internal_call.state.lock().unwrap();
252 status.ready = true;
253 if let Some(ref task) = status.task {
254 task.clone().wake()
255 }
256 }
257 }
258
259 fn drain_calls<T: StdError>(calls: Arc<Mutex<HashMap<u64, ArcCall>>>, err: T) {
260 let mut m = calls.lock().unwrap();
261 for (_, call) in m.drain().take(1) {
262 let internal_call_cloned = call.clone();
263 let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
264 let internal_call = internal_call_mutex.get_mut();
265 internal_call.error = String::from(err.description());
266 let mut status = internal_call.state.lock().unwrap();
267 status.ready = true;
268 if let Some(ref task) = status.task {
269 task.clone().wake()
270 }
271 }
272 }
273
274 #[allow(dead_code)]
275 fn remove_call_with_err<T: StdError>(&mut self, seq: u64, err: T) {
276 let calls = self.calls.clone();
277 let m = calls.lock().unwrap();
278 if let Some(call) = m.get(&seq) {
279 let internal_call_cloned = call.clone();
280 let mut internal_call_mutex = internal_call_cloned.lock().unwrap();
281 let internal_call = internal_call_mutex.get_mut();
282 internal_call.error = String::from(err.description());
283 let mut status = internal_call.state.lock().unwrap();
284 status.ready = true;
285 if let Some(ref task) = status.task {
286 task.clone().wake()
287 }
288 }
289 }
290
291 pub fn call<T>(
292 &mut self,
293 service_path: &str,
294 service_method: &str,
295 is_oneway: bool,
296 metadata: &Metadata,
297 args: &dyn RpcxParam,
298 ) -> Option<Result<T>>
299 where
300 T: RpcxParam + Default,
301 {
302 let rt = Runtime::new().unwrap();
303 let callfuture = rt.block_on(async {
304 let f = self.send(
305 service_path,
306 service_method,
307 is_oneway,
308 false,
309 metadata,
310 args,
311 );
312 f.await
313 });
314
315 if is_oneway {
316 return None;
317 }
318
319 let arc_call_1 = callfuture.unwrap().clone();
320 let mut arc_call_2 = arc_call_1.lock().unwrap();
321 let arc_call_3 = arc_call_2.get_mut();
322 let reply_data = &arc_call_3.reply_data;
323
324 if !arc_call_3.error.is_empty() {
325 let err = &arc_call_3.error;
326 if arc_call_3.is_client_error {
327 return Some(Err(Error::new(ErrorKind::Client, String::from(err))));
328 } else {
329 return Some(Err(Error::from(String::from(err))));
330 }
331 }
332
333 let mut reply: T = Default::default();
334 match reply.from_slice(self.opt.serialize_type, &reply_data) {
335 Ok(()) => Some(Ok(reply)),
336 Err(err) => Some(Err(err)),
337 }
338 }
339}