// netidx_protocols/rpc.rs

1use anyhow::Result;
2use arcstr::{literal, ArcStr};
3use futures::{
4    channel::{mpsc, oneshot},
5    future,
6    prelude::*,
7    select_biased, stream,
8};
9use fxhash::{FxHashMap, FxHashSet};
10use log::{error, info};
11use netidx::{
12    path::Path,
13    publisher::{
14        ClId, Id, PublishFlags, Publisher, SendResult, Val, Value, WriteRequest,
15    },
16    subscriber::{Dval, Subscriber},
17};
18use poolshark::global::{GPooled, Pool};
19use std::{
20    borrow::Borrow,
21    collections::HashMap,
22    ops::Drop,
23    sync::Arc,
24    time::{Duration, Instant},
25};
26use tokio::task;
27
28#[macro_use]
29pub mod server {
30    use std::{
31        collections::HashSet,
32        panic::{catch_unwind, AssertUnwindSafe},
33        sync::LazyLock,
34    };
35
36    use super::*;
37
38    atomic_id!(ProcId);
39
40    /// for use in map functions, will reply to the client with an error and return None
41    #[macro_export]
42    macro_rules! rpc_err {
43        ($reply:expr, $msg:expr) => {{
44            $reply.send(Value::error($msg));
45            return None;
46        }};
47    }
48
49    /// defines a new rpc.
50    /// `define_rpc!(publisher, path, doc, mapfn, tx, arg: typ = default; doc, ...)`
51    /// see `Proc` for an example
52    #[macro_export]
53    macro_rules! define_rpc {
54        (
55            $publisher:expr,
56            $path:expr,
57            $topdoc:expr,
58            $map:expr,
59            $tx:expr,
60            $($arg:ident: $typ:ty = $default:expr; $doc:expr),*
61        ) => {
62            define_rpc!(
63                $publisher,
64                netidx::publisher::PublishFlags::empty(),
65                $path,
66                $topdoc,
67                $map,
68                $tx,
69                $($arg: $typ = $default; $doc),*
70            )
71        };
72        (
73            $publisher:expr,
74            $flags:expr,
75            $path:expr,
76            $topdoc:expr,
77            $map:expr,
78            $tx:expr,
79            $($arg:ident: $typ:ty = $default:expr; $doc:expr),*
80        ) => {{
81            let map = move |mut c: RpcCall| {
82                $(
83                    let d = Value::from($default);
84                    let $arg = match c.args.remove(stringify!($arg)).unwrap_or(d).cast_to::<$typ>() {
85                        Ok(t) => t,
86                        Err(_) => rpc_err!(c.reply, format!("arg: {} invalid type conversion", stringify!($arg)))
87                    };
88                )*
89                if c.args.len() != 0 {
90                    rpc_err!(c.reply, format!("unknown argument specified: {:?}", c.args.keys().collect::<Vec<_>>()))
91                }
92                $map(c, $($arg),*)
93            };
94            let args = [
95                $(ArgSpec {name: ArcStr::from(stringify!($arg)), default_value: Value::from($default), doc: Value::from($doc)}),*
96            ];
97            Proc::new_with_flags($publisher, $flags, $path, Value::from($topdoc), args, map, $tx)
98        }}
99    }
100
101    static ARGS: LazyLock<Pool<HashMap<ArcStr, Value>>> =
102        LazyLock::new(|| Pool::new(1000, 50));
103
104    #[derive(Debug)]
105    pub struct RpcReply(Option<SendResult>);
106
107    impl Drop for RpcReply {
108        fn drop(&mut self) {
109            if let Some(reply) = self.0.take() {
110                let _ = reply.send(Value::error(literal!("rpc call failed")));
111            }
112        }
113    }
114
115    impl RpcReply {
116        pub fn send<T: Into<Value>>(&mut self, m: T) {
117            if let Some(res) = self.0.take() {
118                res.send(m.into());
119            }
120        }
121    }
122
123    #[derive(Debug, Clone)]
124    pub struct ArgSpec {
125        pub name: ArcStr,
126        pub doc: Value,
127        pub default_value: Value,
128    }
129
130    #[derive(Debug)]
131    pub struct RpcCall {
132        pub client: ClId,
133        pub id: ProcId,
134        pub args: GPooled<HashMap<ArcStr, Value>>,
135        pub reply: RpcReply,
136    }
137
138    struct Arg {
139        name: ArcStr,
140        _value: Val,
141        _doc: Val,
142    }
143
144    struct PendingCall {
145        args: GPooled<HashMap<ArcStr, Value>>,
146        initiated: Instant,
147    }
148
149    struct ProcInner<M: FnMut(RpcCall) -> Option<T> + Send + 'static, T: Send + 'static> {
150        id: ProcId,
151        call: Arc<Val>,
152        _doc: Val,
153        args: FxHashMap<Id, Arg>,
154        arg_names: FxHashSet<ArcStr>,
155        pending: FxHashMap<ClId, PendingCall>,
156        handler: Option<mpsc::Sender<T>>,
157        map: M,
158        events: stream::Fuse<mpsc::Receiver<GPooled<Vec<WriteRequest>>>>,
159        stop: future::Fuse<oneshot::Receiver<()>>,
160        last_gc: Instant,
161    }
162
163    impl<M, T> ProcInner<M, T>
164    where
165        M: FnMut(RpcCall) -> Option<T> + Send + 'static,
166        T: Send + 'static,
167    {
168        async fn run(mut self) {
169            static GC_FREQ: Duration = Duration::from_secs(1);
170            static GC_THRESHOLD: usize = 128;
171            fn gc_pending(pending: &mut FxHashMap<ClId, PendingCall>, now: Instant) {
172                static STALE: Duration = Duration::from_secs(60);
173                pending.retain(|_, pc| now - pc.initiated < STALE);
174                pending.shrink_to_fit();
175            }
176            let mut stop = self.stop;
177            loop {
178                #[rustfmt::skip]
179                select_biased! {
180                    _ = stop => break,
181                    mut batch = self.events.select_next_some() => for req in batch.drain(..) {
182                        if req.id == self.call.id() {
183                            let mut args = self.pending.remove(&req.client).map(|pc| pc.args)
184                                .unwrap_or_else(|| ARGS.take());
185			    match req.value {
186				Value::Null => (),
187				Value::Array(a) => for v in &*a {
188				    match v.clone().cast_to::<(ArcStr, Value)>() {
189					Ok((name, val)) => {
190					    if let Some(name) = self.arg_names.get(&*name) {
191						args.insert(name.clone(), val);
192					    }
193					}
194					Err(_) => ()
195				    }
196				}
197				_ => ()
198			    };
199                            let call = RpcCall {
200                                client: req.client,
201                                id: self.id,
202                                args,
203                                reply: RpcReply(req.send_result),
204                            };
205                            let t = match catch_unwind(AssertUnwindSafe(|| (self.map)(call))) {
206                                Ok(t) => t,
207                                Err(_) => {
208                                    error!("rpc map args panic");
209                                    continue
210                                }
211                            };
212                            if let Some(t) = t {
213                                if let Some(handler) = &mut self.handler {
214                                    let _: std::result::Result<_, _> = handler.send(t).await;
215                                }
216                            }
217                        } else {
218                            let mut gc = false;
219                            let pending = self.pending.entry(req.client)
220                                .or_insert_with(|| {
221                                    gc = true;
222                                    PendingCall {
223                                        args: ARGS.take(),
224                                        initiated: Instant::now()
225                                    }
226                                });
227                            if let Some(Arg {name, ..}) = self.args.get(&req.id) {
228                                pending.args.insert(name.clone(), req.value);
229                            }
230                            if gc && self.pending.len() > GC_THRESHOLD {
231                                let now = Instant::now();
232                                if now - self.last_gc > GC_FREQ {
233                                    self.last_gc = now;
234                                    gc_pending(&mut self.pending, now);
235                                }
236                            }
237                        }
238                    }
239                }
240            }
241        }
242    }
243
244    /// A remote procedure published in netidx
245    #[derive(Debug)]
246    pub struct Proc {
247        _stop: oneshot::Sender<()>,
248        id: ProcId,
249    }
250
251    impl Proc {
252        /**
253        Publish a new remote procedure. If successful this will return
254        a `Proc` which, if dropped, will cause the removal of the
255        procedure from netidx.
256
257        # Arguments
258
259        * `publisher` - A reference to the publisher that will publish the procedure.
260        * `name` - The path of the procedure in netidx.
261        * `doc` - The procedure level doc string to be published along with the procedure
262        * `args` - An iterator containing the procedure arguments
263        * `map` - A function that will map the raw parameters into the type of the channel.
264          if it returns None then nothing will be pushed into the channel.
265        * `handler` - The channel that will receive the rpc call invocations (if any)
266
267        If you can handle the procedure entirely without async (or blocking) then you only
268        need to define map, you don't need to pass a handler channel. Your map function should
269        handle the call, reply to the client, and return None.
270
271        If you need to do something async in order to handle the call, then you must pass
272        an mpsc channel that will receive the output of your map function. You can define
273        as little or as much slack as you desire, however be aware that if the channel fills up
274        then clients attempting to call your procedure will wait.
275
276        # Example
277        ```no_run
278        #[macro_use] extern crate netidx_protocols;
279        use netidx::{path::Path, subscriber::Value};
280        use netidx_protocols::rpc::server::{Proc, ArgSpec, RpcCall};
281        use arcstr::ArcStr;
282        # use anyhow::Result;
283        # async fn z() -> Result<()> {
284        #   let publisher = unimplemented!();
285            let echo = define_rpc!(
286                &publisher,
287                Path::from("/examples/api/echo"),
288                "echos it's argument",
289                |mut c: RpcCall, arg: Value| -> Option<()> {
290                    c.reply.send(arg);
291                    None
292                },
293                None,
294                arg: Value = Value::Null; "argument to echo"
295            );
296        #   drop(echo);
297        #   Ok(())
298        # }
299        ```
300
301        # Notes
302
303        If more than one publisher is publishing the same compatible
304        RPC (same arguments, same name, hopefully the same
305        semantics!), then clients will randomly pick one procedure
306        from the set at client creation time.
307
308        Arguments with the same key that are specified multiple times
309        will overwrite previous versions; the procedure will receive
310        only the last version set.
311         **/
312        pub fn new<T: Send + 'static, F: FnMut(RpcCall) -> Option<T> + Send + 'static>(
313            publisher: &Publisher,
314            name: Path,
315            doc: Value,
316            args: impl IntoIterator<Item = ArgSpec>,
317            map: F,
318            handler: Option<mpsc::Sender<T>>,
319        ) -> Result<Proc> {
320            Self::new_with_flags(
321                publisher,
322                PublishFlags::empty(),
323                name,
324                doc,
325                args,
326                map,
327                handler,
328            )
329        }
330
331        pub fn new_with_flags<
332            T: Send + 'static,
333            F: FnMut(RpcCall) -> Option<T> + Send + 'static,
334        >(
335            publisher: &Publisher,
336            flags: PublishFlags,
337            name: Path,
338            doc: Value,
339            args: impl IntoIterator<Item = ArgSpec>,
340            map: F,
341            handler: Option<mpsc::Sender<T>>,
342        ) -> Result<Proc> {
343            let id = ProcId::new();
344            let (tx_ev, rx_ev) = mpsc::channel(3);
345            let (tx_stop, rx_stop) = oneshot::channel();
346            let _doc = publisher.publish_with_flags(
347                flags | PublishFlags::USE_EXISTING,
348                name.append("doc"),
349                doc,
350            )?;
351            let mut arg_names = HashSet::default();
352            let args = args
353                .into_iter()
354                .map(|ArgSpec { name: arg, doc, default_value }| {
355                    arg_names.insert(arg.clone());
356                    let base = name.append(&*arg);
357                    let _value = publisher
358                        .publish_with_flags(
359                            flags | PublishFlags::USE_EXISTING,
360                            base.append("val"),
361                            default_value,
362                        )
363                        .map(|val| {
364                            publisher.writes(val.id(), tx_ev.clone());
365                            val
366                        })?;
367                    let _doc = publisher.publish_with_flags(
368                        flags | PublishFlags::USE_EXISTING,
369                        base.append("doc"),
370                        doc,
371                    )?;
372                    Ok((_value.id(), Arg { name: arg, _value, _doc }))
373                })
374                .collect::<Result<FxHashMap<Id, Arg>>>()?;
375            let call = Arc::new(publisher.publish_with_flags(
376                flags | PublishFlags::USE_EXISTING,
377                name.clone(),
378                arg_names.clone(),
379            )?);
380            publisher.writes(call.id(), tx_ev.clone());
381            let inner = ProcInner {
382                id,
383                call,
384                _doc,
385                args,
386                arg_names,
387                pending: HashMap::default(),
388                map,
389                handler,
390                events: rx_ev.fuse(),
391                stop: rx_stop.fuse(),
392                last_gc: Instant::now(),
393            };
394            task::spawn(async move {
395                inner.run().await;
396                info!("rpc proc {} shutdown", name);
397            });
398            Ok(Proc { id, _stop: tx_stop })
399        }
400
401        /// Get the rpc procedure id
402        pub fn id(&self) -> ProcId {
403            self.id
404        }
405    }
406}
407
408#[macro_use]
409pub mod client {
410    use super::*;
411    use fxhash::FxHashSet;
412    use log::{debug, trace};
413    use netidx::subscriber::Event;
414    use once_cell::sync::OnceCell;
415    use std::collections::HashSet;
416    use tokio::time;
417
418    /// Convenience macro for calling rpcs.
419    /// `call_rpc!(proc, arg0: 3, arg1: "foo", arg2: vec!["foo", "bar", "baz"])`
420    #[macro_export]
421    macro_rules! call_rpc {
422        ($proc:expr, $($name:ident: $arg:expr),*) => {
423            $proc.call([
424                $(
425                    (stringify!($name), $arg.try_into()?)
426                ),*
427            ])
428        }
429    }
430
431    #[derive(Debug)]
432    struct ProcInner {
433        call: Dval,
434        args: OnceCell<FxHashSet<ArcStr>>,
435        subscribe_timeout: Duration,
436    }
437
438    #[derive(Debug, Clone)]
439    pub struct Proc(Arc<ProcInner>);
440
441    impl Proc {
442        /// Subscribe to the procedure specified by `name`, if
443        /// successful return a `Proc` structure that may be used to
444        /// call the procedure. Dropping the `Proc` structure will
445        /// unsubscribe from the procedure and free all associated
446        /// resources.
447        pub fn new(subscriber: &Subscriber, name: Path) -> Result<Proc> {
448            let call = subscriber.subscribe(name.clone());
449            Ok(Proc(Arc::new(ProcInner {
450                call,
451                args: OnceCell::new(),
452                subscribe_timeout: Duration::from_secs(10),
453            })))
454        }
455
456        /// Exactly the same as subscribe, except allows setting the
457        /// subscribe timeout, which is 10 seconds by default.
458        pub fn new_with_timeout(
459            subscriber: &Subscriber,
460            name: Path,
461            subscribe_timeout: Duration,
462        ) -> Result<Proc> {
463            let call = subscriber.subscribe(name.clone());
464            Ok(Proc(Arc::new(ProcInner {
465                call,
466                args: OnceCell::new(),
467                subscribe_timeout,
468            })))
469        }
470
471        /**
472        Call the procedure. `call` may be reused to call the procedure again.
473
474        # Example
475        ```no_run
476        #[macro_use] extern crate netidx_protocols;
477        use netidx::{path::Path, subscriber::Value};
478        use netidx_protocols::rpc::client::Proc;
479        # use anyhow::Result;
480        # async fn z() -> Result<()> {
481        #   let subscriber = unimplemented!();
482            let echo = Proc::new(subscriber, Path::from("/examples/api/echo"))?;
483            let v = call_rpc!(echo, arg1: "hello echo").await?;
484        #   drop(echo);
485        #   Ok(())
486        # }
487        ```
488
489        # Notes
490
491        `call` may safely be called concurrently on multiple
492        instances of `Proc` that call the same procedure
493        **/
494        pub async fn call<I, K>(&self, args: I) -> Result<Value>
495        where
496            I: IntoIterator<Item = (K, Value)>,
497            K: Borrow<str>,
498        {
499            if self.0.args.get().is_none() {
500                loop {
501                    debug!("waiting for subscription to procedure");
502                    time::timeout(
503                        self.0.subscribe_timeout,
504                        self.0.call.wait_subscribed(),
505                    )
506                    .await
507                    .map_err(|_| anyhow!("timeout subscribing to procedure"))??;
508                    debug!("fetching args");
509                    match self.0.call.last() {
510                        Event::Unsubscribed => (),
511                        Event::Update(v) => {
512                            debug!("args are {:?}", v);
513                            let args = v
514                                .clone()
515                                .cast_to::<FxHashSet<ArcStr>>()
516                                .ok()
517                                .unwrap_or(HashSet::default());
518                            // Another thread may have set these args already,
519                            // so ignore if `set` returns Err.
520                            let _: Result<(), FxHashSet<ArcStr>> = self.0.args.set(args);
521                            break;
522                        }
523                    }
524                }
525            }
526            let args = {
527                let mut set: FxHashMap<ArcStr, Value> = HashMap::default();
528                let names = match self.0.args.get() {
529                    Some(names) => names,
530                    None => bail!("no args set"),
531                };
532                for (name, val) in args {
533                    match names.get(name.borrow()) {
534                        None => bail!("no such argument {}", name.borrow()),
535                        Some(name) => {
536                            set.insert(name.clone(), val);
537                        }
538                    }
539                }
540                set
541            };
542            trace!("calling procedure");
543            let res = self
544                .0
545                .call
546                .write_with_recipt(args.into())
547                .await
548                .map_err(|_| anyhow!("call cancelled before a reply was received"))?;
549            trace!("procedure called");
550            Ok(res)
551        }
552    }
553}
#[cfg(test)]
mod test {
    use crate::channel::test::Ctx;

    use super::server::*;
    use super::*;
    use tokio::{runtime::Runtime, time};

    /// Publishes a one-argument procedure and calls it through the client:
    /// a good call, a call the procedure answers with an error, a call the
    /// client rejects locally, and a final good call proving the handler
    /// is still alive.
    #[test]
    fn call_proc() {
        let _ = env_logger::try_init();
        Runtime::new()
            .unwrap()
            .block_on(async move {
                let ctx = Ctx::new().await;
                let proc_name = Path::from("/rpc/procedure");
                let (tx, mut rx) = mpsc::channel(10);
                let _server_proc = define_rpc!(
                    &ctx.publisher,
                    proc_name.clone(),
                    "test rpc procedure",
                    |c, a| Some((c, a)),
                    Some(tx),
                    arg1: Value = Value::Null; "arg1 doc"
                )
                .unwrap();
                // Report unexpected arguments through the reply, not assert!.
                // A panic inside a spawned task would not fail the test; it
                // would only kill the handler.
                task::spawn(async move {
                    while let Some((mut c, a)) = rx.next().await {
                        if a == Value::from("hello rpc") {
                            c.reply.send(Value::U32(42))
                        } else {
                            c.reply.send(Value::error(format!("unexpected arg1: {:?}", a)))
                        }
                    }
                });
                time::sleep(Duration::from_millis(100)).await;
                let proc: client::Proc =
                    client::Proc::new(&ctx.subscriber, proc_name.clone()).unwrap();
                let res = call_rpc!(proc, arg1: "hello rpc").await.unwrap();
                assert_eq!(res, Value::U32(42));
                // no arguments: arg1 falls back to its Null default, which
                // the handler answers with an error
                let args: Vec<(Arc<str>, Value)> = vec![];
                let res = proc.call(args).await.unwrap();
                assert!(matches!(res, Value::Error(_)));
                // arg2 is not declared by the procedure, so the client
                // refuses to send the call
                let args = vec![("arg2", Value::from("hello rpc"))];
                assert!(proc.call(args).await.is_err());
                // the handler survived the bad calls and still answers
                let res = call_rpc!(proc, arg1: "hello rpc").await.unwrap();
                assert_eq!(res, Value::U32(42));
                Ok::<(), anyhow::Error>(())
            })
            .unwrap()
    }
}