1use super::msg::*;
2use std::collections::HashMap;
3use std::future::Future;
4pub use std::ops::Deref;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicU32, Ordering};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::TcpStream;
9#[cfg(unix)]
10use tokio::net::UnixStream;
11use tokio::sync::{mpsc, oneshot};
12
/// Boxed, type-erased future produced by `Call::call`: it resolves to `R`, is `Send`,
/// and may borrow data for the lifetime `'fut`.
type Return<'fut, R> = Pin<Box<dyn Future<Output = R> + Send + 'fut>>;
15pub trait Call {
17 fn call<'a>(&'a self, id: u32, data: Bytes) -> Return<'a, Result<Msg, Error>>;
18}
19
/// Receiver for messages published on a subscribed topic.
///
/// `Client::subcribe_with_trait` calls `callback` once per decoded message and keeps
/// subscribing only while it returns `true`.
pub trait SubcribeCallback<T> {
    /// Handles one published message; return `false` to end the subscription.
    fn callback(&mut self, message: T) -> bool;
}
25
26pub struct MyStream {
27 tx: mpsc::Sender<(u32, oneshot::Sender<Result<Msg, Error>>, Bytes)>,
28}
29
30impl MyStream {
31 fn new<Stream: AsyncReadExt + AsyncWriteExt + std::marker::Unpin + Send + 'static>(mut stream: Stream) -> Self {
33 let (tx, mut rx) = mpsc::channel::<(u32, oneshot::Sender<Result<Msg, Error>>, Bytes)>(1);
34 tokio::spawn(async move {
35 let mut header = [0u8; RPC_HEADER_LEN];
36 let mut callers = HashMap::<u32, oneshot::Sender<Result<Msg, Error>>>::new();
37 let error = loop {
38 tokio::select! {
39 Some((id, otx, data)) = rx.recv() => {
40 callers.insert(id, otx);
41 if let Err(err) = stream.write_all(&data[..]).await {
42 break Error::from(err);
43 }
44 }
45 ret = stream.read_exact(&mut header[..]) => {
46 match ret {
47 Ok(0) => break Error::new("对端已关闭读取数据长度为0"),
48 Ok(_) => {
49 if let Ok(mut msg) = Msg::decode(&header[..]) {
50 match msg.mode() {
51 Mode::Respond | Mode::Publish => {
52 if let Some(buf) = msg.body() {
53 if let Err(err) = stream.read_exact(buf).await {
55 break Error::from(err);
56 }
57 }
58 }
59 _ => (),
60 }
61 if let Some(otx) = callers.remove(&msg.id()) {
62 let _ = otx.send(Ok(msg));
63 }
64 }
65 }
66 Err(err) => break Error::from(err),
67 }
68 }
69 }
70 };
71 callers.into_iter().for_each(|(_, otx)| {
73 let _ = otx.send(Err(error.clone()));
74 });
75 });
76 Self { tx }
77 }
78}
79impl Call for MyStream {
81 fn call<'a>(&'a self, id: u32, data: Bytes) -> Return<'a, Result<Msg, Error>> {
82 Box::pin(async move {
83 let (tx, rx) = oneshot::channel();
84 let _ = self.tx.send((id, tx, data)).await;
85 rx.await.unwrap()
86 })
87 }
88}
89
/// RPC client that encodes requests, sends them through a `Call` transport and
/// decodes the replies.
///
/// The `Call + Send` requirement lives on the `impl` block rather than on the struct,
/// so the type can be named without repeating the bounds.
pub struct Client<Stream> {
    /// Transport used for every request.
    stream: Stream,
    /// Source of request ids; incremented once per request.
    id: AtomicU32,
}
94
/// Remote function name carried by heartbeat requests; heartbeat replies must echo it.
const HEARTBEAT: &str = "heartbeat";
96
97impl<Stream: Call + Send> Client<Stream> {
98 pub fn new(stream: Stream) -> Self {
100 Self { stream, id: AtomicU32::new(0) }
101 }
102 fn encode_heartbeat(&self) -> (u32, Bytes) {
104 let id = self.id.fetch_add(1, Ordering::SeqCst);
105 let msg = Msg::new(id, HEARTBEAT);
106 let data = msg.encode_without_body(Mode::HeartBeat);
107 (id, data)
108 }
109 fn encode_without_arg(&self, name: &str) -> (u32, Bytes) {
111 let id = self.id.fetch_add(1, Ordering::SeqCst);
112 let msg = Msg::new(id, name);
113 let data = msg.encode_without_body(Mode::Request);
114 (id, data)
115 }
116 fn encode_with_arg<Args>(&self, name: &str, args: Args) -> (u32, Bytes)
118 where
119 Args: serde::ser::Serialize,
120 {
121 let id = self.id.fetch_add(1, Ordering::SeqCst);
122 let msg = Msg::new(id, name);
123 let data = msg.encode(Mode::Request, &args);
124 (id, data)
125 }
126 fn decode_heartbeat(msg: Msg) -> Result<(), Error> {
128 if msg.name() != HEARTBEAT.as_bytes() {
129 return Err(Error::new("心跳回复函数名称不匹配"));
130 }
131 if !msg.headeronly() {
132 return Err(Error::new("心跳不应当返回消息体"));
133 }
134 match msg.mode() {
135 Mode::HeartBeat => Ok(()),
136 _ => Err(Error::new("返回消息模式不正确")),
137 }
138 }
139 fn decode_without_ret(msg: Msg, name: &str) -> Result<(), Error> {
141 if msg.name() != name.as_bytes() {
142 return Err(Error::new("回复函数名称不匹配"));
143 }
144 if !msg.headeronly() {
145 return Err(Error::new("不应当返回消息体"));
146 }
147 match msg.mode() {
148 Mode::Respond => Ok(()),
149 Mode::NotFound => Err(Error::new("没有找到相应的函数")),
150 Mode::NotMatch => Err(Error::new("函数参数不匹配")),
151 Mode::NoAccess => Err(Error::new("没有权限")),
152 _ => Err(Error::new("返回消息模式不正确")),
153 }
154 }
155 fn decode_with_ret<Ret>(msg: Msg, name: &str) -> Result<Ret, Error>
157 where
158 Ret: for<'a> serde::de::Deserialize<'a>,
159 {
160 if msg.name() != name.as_bytes() {
161 return Err(Error::new("回复函数名称不匹配"));
162 }
163 match msg.mode() {
164 Mode::Respond => {
165 if msg.headeronly() {
166 Err(Error::new("没有消息体"))
167 } else {
168 msg.parse()
169 }
170 }
171 Mode::NotFound => Err(Error::new("没有找到相应的函数")),
172 Mode::NotMatch => Err(Error::new("函数参数不匹配")),
173 Mode::NoAccess => Err(Error::new("没有权限")),
174 _ => Err(Error::new("返回消息模式不正确")),
175 }
176 }
177
178 #[inline]
180 pub async fn heartbeat(&self) -> Result<(), Error> {
181 let (id, data) = self.encode_heartbeat();
182 let msg = self.stream.call(id, data).await?;
183 Self::decode_heartbeat(msg)
184 }
185 #[inline]
187 pub async fn call_without_arg_ret(&self, name: &str) -> Result<(), Error> {
188 let (id, data) = self.encode_without_arg(name);
189 let msg = self.stream.call(id, data).await?;
190 Self::decode_without_ret(msg, name)
191 }
192 #[inline]
194 pub async fn call_with_ret<Ret>(&self, name: &str) -> Result<Ret, Error>
195 where
196 Ret: for<'a> serde::de::Deserialize<'a>,
197 {
198 let (id, data) = self.encode_without_arg(name);
199 let msg = self.stream.call(id, data).await?;
200 Self::decode_with_ret(msg, name)
201 }
202 #[inline]
204 pub async fn call_with_arg<Args>(&self, name: &str, args: Args) -> Result<(), Error>
205 where
206 Args: serde::ser::Serialize,
207 {
208 let (id, data) = self.encode_with_arg(name, args);
209 let msg = self.stream.call(id, data).await?;
210 Self::decode_without_ret(msg, name)
211 }
212 #[inline]
214 pub async fn call_with_arg_ret<Args, Ret>(&self, name: &str, args: Args) -> Result<Ret, Error>
215 where
216 Args: serde::ser::Serialize,
217 Ret: for<'a> serde::de::Deserialize<'a>,
218 {
219 let (id, data) = self.encode_with_arg(name, args);
220 let msg = self.stream.call(id, data).await?;
221 Self::decode_with_ret(msg, name)
222 }
223 pub async fn subcribe_with_lambda<Ret, F>(&self, topic: &str, mut f: F) -> Result<(), Error>
225 where
226 Ret: for<'a> serde::de::Deserialize<'a>,
227 F: FnMut(Ret),
228 {
229 loop {
230 let id = self.id.fetch_add(1, Ordering::SeqCst);
231 let msg = Msg::new(id, topic);
232 let data = msg.encode_without_body(Mode::Subcribe);
233 let msg = self.stream.call(id, data).await?;
234 if msg.name() != topic.as_bytes() {
235 break Err(Error::new("订阅主题名称不匹配"));
236 }
237 match msg.mode() {
238 Mode::Publish => {
239 if msg.headeronly() {
240 break Err(Error::new("没有订阅到消息体"));
241 } else {
242 f(msg.parse()?);
243 continue;
244 }
245 }
246 Mode::NotFound => break Err(Error::new("没有找到相应的函数")),
247 Mode::NotMatch => break Err(Error::new("函数参数不匹配")),
248 Mode::NoAccess => break Err(Error::new("没有权限")),
249 _ => break Err(Error::new("返回消息模式不正确")),
250 }
251 }
252 }
253 pub async fn subcribe_with_trait<Ret, T>(&self, topic: &str, t: &mut T) -> Result<(), Error>
255 where
256 Ret: for<'a> serde::de::Deserialize<'a>,
257 T: SubcribeCallback<Ret>,
258 {
259 loop {
260 let id = self.id.fetch_add(1, Ordering::SeqCst);
261 let msg = Msg::new(id, topic);
262 let data = msg.encode_without_body(Mode::Subcribe);
263 let msg = self.stream.call(id, data).await?;
264 if msg.name() != topic.as_bytes() {
265 break Err(Error::new("订阅主题名称不匹配"));
266 }
267 match msg.mode() {
268 Mode::Publish => {
269 if msg.headeronly() {
270 break Err(Error::new("没有订阅到消息体"));
271 } else {
272 if t.callback(msg.parse()?) {
273 continue; } else {
275 break Ok(()); }
277 }
278 }
279 Mode::NotFound => break Err(Error::new("没有找到相应的函数")),
280 Mode::NotMatch => break Err(Error::new("函数参数不匹配")),
281 Mode::NoAccess => break Err(Error::new("没有权限")),
282 _ => break Err(Error::new("返回消息模式不正确")),
283 }
284 }
285 }
286}
287
288pub type TCPClient = Client<MyStream>;
290#[cfg(unix)]
292pub type UnixClient = Client<MyStream>;
293
294#[inline]
296pub async fn new_tcp_client(addr: &str) -> std::io::Result<TCPClient> {
297 Ok(Client::new(MyStream::new(TcpStream::connect(addr).await?)))
298}
299#[inline]
301#[cfg(unix)]
302pub async fn new_unix_client(path: &str) -> std::io::Result<UnixClient> {
303 Ok(Client::new(MyStream::new(UnixStream::connect(path).await?)))
304}
305
/// Connects to an RPC server, performs one call, and evaluates to a future of the result.
///
/// The first token selects the transport: `tcp` uses `new_tcp_client` and `unix`
/// uses `new_unix_client`. It is followed by the address and an optional remote
/// function signature:
///
/// * `call!(tcp, addr)` sends a heartbeat and resolves to `Result<(), Error>`.
/// * `call!(tcp, addr, func())` calls `func` with no arguments and no return value.
/// * `call!(tcp, addr, func() -> T)` calls `func` and decodes a `T`.
/// * `call!(tcp, addr, func(a, b))` sends `(a, b)` as a tuple and expects no return value.
/// * `call!(tcp, addr, func(a, b) -> T)` sends `(a, b)` and decodes a `T`.
///
/// The remote function name is the identifier itself (via `stringify!`). A new
/// connection is opened for every invocation. Connection errors are converted with
/// `msg::Error::from`.
#[macro_export]
macro_rules! call {
    (@call $connect:ident, $addr:expr) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => client.heartbeat().await,
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (@call $connect:ident, $addr:expr, $func:ident()) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => client.call_without_arg_ret(stringify!($func)).await,
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (@call $connect:ident, $addr:expr, $func:ident() -> $ret:ty) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => {
                    let result: $ret = client.call_with_ret(stringify!($func)).await?;
                    Ok(result)
                }
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (@call $connect:ident, $addr:expr, $func:ident($($arg:expr),+)) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => client.call_with_arg(stringify!($func), ($($arg,)+)).await,
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (@call $connect:ident, $addr:expr, $func:ident($($arg:expr),+) -> $ret:ty) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => {
                    let result: $ret = client.call_with_arg_ret(stringify!($func), ($($arg,)+)).await?;
                    Ok(result)
                }
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (tcp, $($var:tt)+) => {
        async move {
            // Path-qualified so the recursion resolves even when the caller has not
            // imported `call!` by name (e.g. invoked as `some_crate::call!`).
            $crate::call!(@call new_tcp_client, $($var)+).await
        }
    };
    (unix, $($var:tt)+) => {
        async move {
            $crate::call!(@call new_unix_client, $($var)+).await
        }
    };
}
437
/// Connects to an RPC server and subscribes to a topic, evaluating to a future of the
/// subscription's outcome.
///
/// The first token selects the transport (`tcp` or `unix`), followed by the address,
/// the topic, and a handler:
///
/// * `subcribe!(tcp, addr, topic, |a: A, b: B| { ... })` decodes each published
///   payload as the tuple `(A, B)` and runs the block. This form only stops on error
///   (see `Client::subcribe_with_lambda`).
/// * `subcribe!(tcp, addr, topic, &mut handler)` uses a `SubcribeCallback`
///   implementation, which ends the subscription by returning `false`.
///
/// A new connection is opened for every invocation. Connection errors are converted
/// with `msg::Error::from`.
#[macro_export]
macro_rules! subcribe {
    (@sub $connect:ident, $addr:expr, $topic:expr, |$($arg:ident:$argType:ty),+|$body:block) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => client.subcribe_with_lambda($topic, |($($arg,)+):($($argType,)+)|$body).await,
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (@sub $connect:ident, $addr:expr, $topic:expr, $var:expr) => {
        async move {
            match $crate::$connect($addr).await {
                Ok(client) => client.subcribe_with_trait($topic, $var).await,
                Err(e) => Err($crate::msg::Error::from(e)),
            }
        }
    };
    (tcp, $($var:tt)+) => {
        async move {
            // Path-qualified so the recursion resolves even when the caller has not
            // imported `subcribe!` by name.
            $crate::subcribe!(@sub new_tcp_client, $($var)+).await
        }
    };
    (unix, $($var:tt)+) => {
        async move {
            $crate::subcribe!(@sub new_unix_client, $($var)+).await
        }
    };
}
495
/// Generates a private tuple struct wrapping a client, with one async method per
/// declared remote function or topic.
///
/// Usually invoked through `define!`. The first three arguments are:
/// * the connect function (`new_tcp_client` / `new_unix_client`);
/// * the client type it returns;
/// * the name of the struct to generate.
///
/// They are followed by a comma-separated list of declarations:
///
/// * `fn name(arg: Type, ...) -> Ret` generates a method that calls the remote
///   function `name`. The arguments are sent as a tuple, and `-> Ret` is optional.
/// * `sub name(topic: TopicType, f: impl FnMut(Arg))` generates a method that
///   subscribes with a closure.
/// * `sub name(topic: TopicType, cb: &mut Handler)` generates a method that
///   subscribes with a `SubcribeCallback` implementation. The generated method takes
///   the topic as `&str`.
///
/// All `fn` declarations must come before the `sub` declarations. The struct also
/// gets `new(path)` and `heartbeat()`. The connect function, the client type and
/// `SubcribeCallback` are resolved where the macro is invoked. The `@method` arms are
/// internal helpers.
#[macro_export]
macro_rules! define_new_type {
    (@method fn $name:ident$(<$generic:tt>)?()) => {
        pub async fn $name$(<$generic>)?(&self) -> Result<(), $crate::msg::Error> {
            self.0.call_without_arg_ret(stringify!($name)).await
        }
    };
    (@method fn $name:ident$(<$generic:tt>)?($($arg:ident:$argType:ty,)+)) => {
        pub async fn $name$(<$generic>)?(&self, $($arg:$argType,)+) -> Result<(), $crate::msg::Error>
        where
            $($argType: serde::ser::Serialize,)+
        {
            self.0.call_with_arg(stringify!($name), ($($arg,)+)).await
        }
    };
    (@method fn $name:ident$(<$generic:tt>)?()->$ret:ty) => {
        pub async fn $name$(<$generic>)?(&self) -> Result<$ret, $crate::msg::Error>
        where
            $ret: for<'a> serde::de::Deserialize<'a>,
        {
            self.0.call_with_ret(stringify!($name)).await
        }
    };
    (@method fn $name:ident$(<$generic:tt>)?($($arg:ident:$argType:ty,)+)->$ret:ty) => {
        pub async fn $name$(<$generic>)?(&self, $($arg:$argType,)+) -> Result<$ret, $crate::msg::Error>
        where
            $($argType: serde::ser::Serialize,)+
            $ret: for<'a> serde::de::Deserialize<'a>,
        {
            self.0.call_with_arg_ret(stringify!($name), ($($arg,)+)).await
        }
    };
    (@method sub $name:ident($topic:ident:$topicType:ty, $f:ident:$(impl)? FnMut($ArgType:ty))) => {
        // `Deref` is fully qualified so callers need not import it.
        pub async fn $name<F>(&self, $topic:$topicType, $f:F) -> Result<(), $crate::msg::Error>
        where
            F: FnMut($ArgType),
            $topicType: ::std::ops::Deref<Target = str>,
            $ArgType: for<'a> serde::de::Deserialize<'a>,
        {
            self.0.subcribe_with_lambda(&$topic, $f).await
        }
    };
    (@method sub $name:ident($topic:ident:$topicType:ty, $var:ident:&mut $ArgType:ty)) => {
        pub async fn $name<Ret>(&self, $topic:&str, $var:&mut $ArgType) -> Result<(), $crate::msg::Error>
        where
            Ret: for<'a> serde::de::Deserialize<'a>,
            $topicType: ::std::ops::Deref<Target = str>,
            $ArgType: SubcribeCallback<Ret>,
        {
            self.0.subcribe_with_trait(&$topic, $var).await
        }
    };
    ($f:ident, $t:ident, $StructName:ident, $(fn $name:ident$(<$generic:tt>)?($($arg:ident:$argType:ty),*)$(->$ret:ty)?),+) => {
        struct $StructName($t);
        impl $StructName {
            pub async fn new(path:&str) -> ::std::io::Result<Self> {
                Ok(Self($f(path).await?))
            }
            pub async fn heartbeat(&self) -> Result<(), $crate::msg::Error> {
                self.0.heartbeat().await
            }
            // Path-qualified so the recursion works without `define_new_type!` in scope.
            $($crate::define_new_type!(@method fn $name$(<$generic>)?($($arg:$argType,)*)$(->$ret)?);)+
        }
    };
    ($f:ident, $t:ident, $StructName:ident, $(sub $name:ident($topic:ident:$topicType:ty, $arg:ident:$($argType:tt)+)),+) => {
        struct $StructName($t);
        impl $StructName {
            pub async fn new(path:&str) -> ::std::io::Result<Self> {
                Ok(Self($f(path).await?))
            }
            pub async fn heartbeat(&self) -> Result<(), $crate::msg::Error> {
                self.0.heartbeat().await
            }
            $($crate::define_new_type!(@method sub $name($topic:$topicType,$arg:$($argType)+));)+
        }
    };
    ($f:ident, $t:ident, $StructName:ident, $(fn $name:ident($($arg:ident:$argType:ty),*)$(->$ret:ty)?),+, $(sub $name2:ident($topic:ident:$topicType:ty, $arg2:ident:$($argType2:tt)+)),+) => {
        struct $StructName($t);
        impl $StructName {
            pub async fn new(path:&str) -> ::std::io::Result<Self> {
                Ok(Self($f(path).await?))
            }
            pub async fn heartbeat(&self) -> Result<(), $crate::msg::Error> {
                self.0.heartbeat().await
            }
            $($crate::define_new_type!(@method fn $name($($arg:$argType,)*)$(->$ret)?);)+
            $($crate::define_new_type!(@method sub $name2($topic:$topicType,$arg2:$($argType2)+));)+
        }
    };
}
596
/// Declares a client wrapper struct for the chosen transport.
///
/// `define!(tcp, Name, ...)` forwards to `define_new_type!` with `new_tcp_client` and
/// `TCPClient`; `define!(unix, Name, ...)` uses `new_unix_client` and `UnixClient`.
/// Those names are passed as plain identifiers and are therefore resolved where
/// `define!` is invoked, so they must be in scope there.
#[macro_export]
macro_rules! define {
    (tcp, $($var:tt)+) => {
        $crate::define_new_type!(new_tcp_client, TCPClient, $($var)+);
    };
    (unix, $($var:tt)+) => {
        $crate::define_new_type!(new_unix_client, UnixClient, $($var)+);
    };
}