1use anyhow::Result;
2use arcstr::{literal, ArcStr};
3use futures::{
4 channel::{mpsc, oneshot},
5 future,
6 prelude::*,
7 select_biased, stream,
8};
9use fxhash::{FxHashMap, FxHashSet};
10use log::{error, info};
11use netidx::{
12 path::Path,
13 publisher::{
14 ClId, Id, PublishFlags, Publisher, SendResult, Val, Value, WriteRequest,
15 },
16 subscriber::{Dval, Subscriber},
17};
18use poolshark::global::{GPooled, Pool};
19use std::{
20 borrow::Borrow,
21 collections::HashMap,
22 ops::Drop,
23 sync::Arc,
24 time::{Duration, Instant},
25};
26use tokio::task;
27
28#[macro_use]
29pub mod server {
30 use std::{
31 collections::HashSet,
32 panic::{catch_unwind, AssertUnwindSafe},
33 sync::LazyLock,
34 };
35
36 use super::*;
37
38 atomic_id!(ProcId);
39
40 #[macro_export]
42 macro_rules! rpc_err {
43 ($reply:expr, $msg:expr) => {{
44 $reply.send(Value::error($msg));
45 return None;
46 }};
47 }
48
49 #[macro_export]
53 macro_rules! define_rpc {
54 (
55 $publisher:expr,
56 $path:expr,
57 $topdoc:expr,
58 $map:expr,
59 $tx:expr,
60 $($arg:ident: $typ:ty = $default:expr; $doc:expr),*
61 ) => {
62 define_rpc!(
63 $publisher,
64 netidx::publisher::PublishFlags::empty(),
65 $path,
66 $topdoc,
67 $map,
68 $tx,
69 $($arg: $typ = $default; $doc),*
70 )
71 };
72 (
73 $publisher:expr,
74 $flags:expr,
75 $path:expr,
76 $topdoc:expr,
77 $map:expr,
78 $tx:expr,
79 $($arg:ident: $typ:ty = $default:expr; $doc:expr),*
80 ) => {{
81 let map = move |mut c: RpcCall| {
82 $(
83 let d = Value::from($default);
84 let $arg = match c.args.remove(stringify!($arg)).unwrap_or(d).cast_to::<$typ>() {
85 Ok(t) => t,
86 Err(_) => rpc_err!(c.reply, format!("arg: {} invalid type conversion", stringify!($arg)))
87 };
88 )*
89 if c.args.len() != 0 {
90 rpc_err!(c.reply, format!("unknown argument specified: {:?}", c.args.keys().collect::<Vec<_>>()))
91 }
92 $map(c, $($arg),*)
93 };
94 let args = [
95 $(ArgSpec {name: ArcStr::from(stringify!($arg)), default_value: Value::from($default), doc: Value::from($doc)}),*
96 ];
97 Proc::new_with_flags($publisher, $flags, $path, Value::from($topdoc), args, map, $tx)
98 }}
99 }
100
101 static ARGS: LazyLock<Pool<HashMap<ArcStr, Value>>> =
102 LazyLock::new(|| Pool::new(1000, 50));
103
104 #[derive(Debug)]
105 pub struct RpcReply(Option<SendResult>);
106
107 impl Drop for RpcReply {
108 fn drop(&mut self) {
109 if let Some(reply) = self.0.take() {
110 let _ = reply.send(Value::error(literal!("rpc call failed")));
111 }
112 }
113 }
114
115 impl RpcReply {
116 pub fn send<T: Into<Value>>(&mut self, m: T) {
117 if let Some(res) = self.0.take() {
118 res.send(m.into());
119 }
120 }
121 }
122
123 #[derive(Debug, Clone)]
124 pub struct ArgSpec {
125 pub name: ArcStr,
126 pub doc: Value,
127 pub default_value: Value,
128 }
129
130 #[derive(Debug)]
131 pub struct RpcCall {
132 pub client: ClId,
133 pub id: ProcId,
134 pub args: GPooled<HashMap<ArcStr, Value>>,
135 pub reply: RpcReply,
136 }
137
138 struct Arg {
139 name: ArcStr,
140 _value: Val,
141 _doc: Val,
142 }
143
144 struct PendingCall {
145 args: GPooled<HashMap<ArcStr, Value>>,
146 initiated: Instant,
147 }
148
149 struct ProcInner<M: FnMut(RpcCall) -> Option<T> + Send + 'static, T: Send + 'static> {
150 id: ProcId,
151 call: Arc<Val>,
152 _doc: Val,
153 args: FxHashMap<Id, Arg>,
154 arg_names: FxHashSet<ArcStr>,
155 pending: FxHashMap<ClId, PendingCall>,
156 handler: Option<mpsc::Sender<T>>,
157 map: M,
158 events: stream::Fuse<mpsc::Receiver<GPooled<Vec<WriteRequest>>>>,
159 stop: future::Fuse<oneshot::Receiver<()>>,
160 last_gc: Instant,
161 }
162
163 impl<M, T> ProcInner<M, T>
164 where
165 M: FnMut(RpcCall) -> Option<T> + Send + 'static,
166 T: Send + 'static,
167 {
168 async fn run(mut self) {
169 static GC_FREQ: Duration = Duration::from_secs(1);
170 static GC_THRESHOLD: usize = 128;
171 fn gc_pending(pending: &mut FxHashMap<ClId, PendingCall>, now: Instant) {
172 static STALE: Duration = Duration::from_secs(60);
173 pending.retain(|_, pc| now - pc.initiated < STALE);
174 pending.shrink_to_fit();
175 }
176 let mut stop = self.stop;
177 loop {
178 #[rustfmt::skip]
179 select_biased! {
180 _ = stop => break,
181 mut batch = self.events.select_next_some() => for req in batch.drain(..) {
182 if req.id == self.call.id() {
183 let mut args = self.pending.remove(&req.client).map(|pc| pc.args)
184 .unwrap_or_else(|| ARGS.take());
185 match req.value {
186 Value::Null => (),
187 Value::Array(a) => for v in &*a {
188 match v.clone().cast_to::<(ArcStr, Value)>() {
189 Ok((name, val)) => {
190 if let Some(name) = self.arg_names.get(&*name) {
191 args.insert(name.clone(), val);
192 }
193 }
194 Err(_) => ()
195 }
196 }
197 _ => ()
198 };
199 let call = RpcCall {
200 client: req.client,
201 id: self.id,
202 args,
203 reply: RpcReply(req.send_result),
204 };
205 let t = match catch_unwind(AssertUnwindSafe(|| (self.map)(call))) {
206 Ok(t) => t,
207 Err(_) => {
208 error!("rpc map args panic");
209 continue
210 }
211 };
212 if let Some(t) = t {
213 if let Some(handler) = &mut self.handler {
214 let _: std::result::Result<_, _> = handler.send(t).await;
215 }
216 }
217 } else {
218 let mut gc = false;
219 let pending = self.pending.entry(req.client)
220 .or_insert_with(|| {
221 gc = true;
222 PendingCall {
223 args: ARGS.take(),
224 initiated: Instant::now()
225 }
226 });
227 if let Some(Arg {name, ..}) = self.args.get(&req.id) {
228 pending.args.insert(name.clone(), req.value);
229 }
230 if gc && self.pending.len() > GC_THRESHOLD {
231 let now = Instant::now();
232 if now - self.last_gc > GC_FREQ {
233 self.last_gc = now;
234 gc_pending(&mut self.pending, now);
235 }
236 }
237 }
238 }
239 }
240 }
241 }
242 }
243
244 #[derive(Debug)]
246 pub struct Proc {
247 _stop: oneshot::Sender<()>,
248 id: ProcId,
249 }
250
251 impl Proc {
252 pub fn new<T: Send + 'static, F: FnMut(RpcCall) -> Option<T> + Send + 'static>(
313 publisher: &Publisher,
314 name: Path,
315 doc: Value,
316 args: impl IntoIterator<Item = ArgSpec>,
317 map: F,
318 handler: Option<mpsc::Sender<T>>,
319 ) -> Result<Proc> {
320 Self::new_with_flags(
321 publisher,
322 PublishFlags::empty(),
323 name,
324 doc,
325 args,
326 map,
327 handler,
328 )
329 }
330
331 pub fn new_with_flags<
332 T: Send + 'static,
333 F: FnMut(RpcCall) -> Option<T> + Send + 'static,
334 >(
335 publisher: &Publisher,
336 flags: PublishFlags,
337 name: Path,
338 doc: Value,
339 args: impl IntoIterator<Item = ArgSpec>,
340 map: F,
341 handler: Option<mpsc::Sender<T>>,
342 ) -> Result<Proc> {
343 let id = ProcId::new();
344 let (tx_ev, rx_ev) = mpsc::channel(3);
345 let (tx_stop, rx_stop) = oneshot::channel();
346 let _doc = publisher.publish_with_flags(
347 flags | PublishFlags::USE_EXISTING,
348 name.append("doc"),
349 doc,
350 )?;
351 let mut arg_names = HashSet::default();
352 let args = args
353 .into_iter()
354 .map(|ArgSpec { name: arg, doc, default_value }| {
355 arg_names.insert(arg.clone());
356 let base = name.append(&*arg);
357 let _value = publisher
358 .publish_with_flags(
359 flags | PublishFlags::USE_EXISTING,
360 base.append("val"),
361 default_value,
362 )
363 .map(|val| {
364 publisher.writes(val.id(), tx_ev.clone());
365 val
366 })?;
367 let _doc = publisher.publish_with_flags(
368 flags | PublishFlags::USE_EXISTING,
369 base.append("doc"),
370 doc,
371 )?;
372 Ok((_value.id(), Arg { name: arg, _value, _doc }))
373 })
374 .collect::<Result<FxHashMap<Id, Arg>>>()?;
375 let call = Arc::new(publisher.publish_with_flags(
376 flags | PublishFlags::USE_EXISTING,
377 name.clone(),
378 arg_names.clone(),
379 )?);
380 publisher.writes(call.id(), tx_ev.clone());
381 let inner = ProcInner {
382 id,
383 call,
384 _doc,
385 args,
386 arg_names,
387 pending: HashMap::default(),
388 map,
389 handler,
390 events: rx_ev.fuse(),
391 stop: rx_stop.fuse(),
392 last_gc: Instant::now(),
393 };
394 task::spawn(async move {
395 inner.run().await;
396 info!("rpc proc {} shutdown", name);
397 });
398 Ok(Proc { id, _stop: tx_stop })
399 }
400
401 pub fn id(&self) -> ProcId {
403 self.id
404 }
405 }
406}
407
408#[macro_use]
409pub mod client {
410 use super::*;
411 use fxhash::FxHashSet;
412 use log::{debug, trace};
413 use netidx::subscriber::Event;
414 use once_cell::sync::OnceCell;
415 use std::collections::HashSet;
416 use tokio::time;
417
418 #[macro_export]
421 macro_rules! call_rpc {
422 ($proc:expr, $($name:ident: $arg:expr),*) => {
423 $proc.call([
424 $(
425 (stringify!($name), $arg.try_into()?)
426 ),*
427 ])
428 }
429 }
430
431 #[derive(Debug)]
432 struct ProcInner {
433 call: Dval,
434 args: OnceCell<FxHashSet<ArcStr>>,
435 subscribe_timeout: Duration,
436 }
437
438 #[derive(Debug, Clone)]
439 pub struct Proc(Arc<ProcInner>);
440
441 impl Proc {
442 pub fn new(subscriber: &Subscriber, name: Path) -> Result<Proc> {
448 let call = subscriber.subscribe(name.clone());
449 Ok(Proc(Arc::new(ProcInner {
450 call,
451 args: OnceCell::new(),
452 subscribe_timeout: Duration::from_secs(10),
453 })))
454 }
455
456 pub fn new_with_timeout(
459 subscriber: &Subscriber,
460 name: Path,
461 subscribe_timeout: Duration,
462 ) -> Result<Proc> {
463 let call = subscriber.subscribe(name.clone());
464 Ok(Proc(Arc::new(ProcInner {
465 call,
466 args: OnceCell::new(),
467 subscribe_timeout,
468 })))
469 }
470
471 pub async fn call<I, K>(&self, args: I) -> Result<Value>
495 where
496 I: IntoIterator<Item = (K, Value)>,
497 K: Borrow<str>,
498 {
499 if self.0.args.get().is_none() {
500 loop {
501 debug!("waiting for subscription to procedure");
502 time::timeout(
503 self.0.subscribe_timeout,
504 self.0.call.wait_subscribed(),
505 )
506 .await
507 .map_err(|_| anyhow!("timeout subscribing to procedure"))??;
508 debug!("fetching args");
509 match self.0.call.last() {
510 Event::Unsubscribed => (),
511 Event::Update(v) => {
512 debug!("args are {:?}", v);
513 let args = v
514 .clone()
515 .cast_to::<FxHashSet<ArcStr>>()
516 .ok()
517 .unwrap_or(HashSet::default());
518 let _: Result<(), FxHashSet<ArcStr>> = self.0.args.set(args);
521 break;
522 }
523 }
524 }
525 }
526 let args = {
527 let mut set: FxHashMap<ArcStr, Value> = HashMap::default();
528 let names = match self.0.args.get() {
529 Some(names) => names,
530 None => bail!("no args set"),
531 };
532 for (name, val) in args {
533 match names.get(name.borrow()) {
534 None => bail!("no such argument {}", name.borrow()),
535 Some(name) => {
536 set.insert(name.clone(), val);
537 }
538 }
539 }
540 set
541 };
542 trace!("calling procedure");
543 let res = self
544 .0
545 .call
546 .write_with_recipt(args.into())
547 .await
548 .map_err(|_| anyhow!("call cancelled before a reply was received"))?;
549 trace!("procedure called");
550 Ok(res)
551 }
552 }
553}
554
#[cfg(test)]
mod test {
    use crate::channel::test::Ctx;

    use super::server::*;
    use super::*;
    use tokio::{runtime::Runtime, time};

    /// End-to-end call through a published procedure.
    ///
    /// Covers a successful call, a call the handler rejects, a call rejected
    /// client side for an undeclared argument, and a final call proving the
    /// handler is still serving.
    #[test]
    fn call_proc() {
        let _ = env_logger::try_init();
        Runtime::new()
            .unwrap()
            .block_on(async move {
                let ctx = Ctx::new().await;
                let proc_name = Path::from("/rpc/procedure");
                let (tx, mut rx) = mpsc::channel::<(RpcCall, Value)>(10);
                let _server_proc = define_rpc!(
                    &ctx.publisher,
                    proc_name.clone(),
                    "test rpc procedure",
                    |c, a| Some((c, a)),
                    Some(tx),
                    arg1: Value = Value::Null; "arg1 doc"
                )
                .unwrap();
                task::spawn(async move {
                    while let Some((mut c, a)) = rx.next().await {
                        // Reply explicitly rather than asserting here. A panic in
                        // this task would only show up as a dropped reply, and it
                        // would stop the handler for every later call.
                        if a == Value::from("hello rpc") {
                            c.reply.send(Value::U32(42))
                        } else {
                            c.reply.send(Value::error(literal!("unexpected arg1")))
                        }
                    }
                });
                time::sleep(Duration::from_millis(100)).await;
                let proc: client::Proc =
                    client::Proc::new(&ctx.subscriber, proc_name).unwrap();
                // Correct argument: the handler replies 42.
                let res = call_rpc!(proc, arg1: "hello rpc").await.unwrap();
                assert_eq!(res, Value::U32(42));
                // No arguments: arg1 falls back to its null default, and the
                // handler replies with an error value.
                let args: Vec<(Arc<str>, Value)> = vec![];
                let res = proc.call(args).await.unwrap();
                assert!(matches!(res, Value::Error(_)));
                // Undeclared argument: rejected client side.
                let args = vec![("arg2", Value::from("hello rpc"))];
                assert!(proc.call(args).await.is_err());
                // The handler is still alive after the rejected calls.
                let res = call_rpc!(proc, arg1: "hello rpc").await.unwrap();
                assert_eq!(res, Value::U32(42));
                Ok::<(), anyhow::Error>(())
            })
            .unwrap()
    }
}