netidx_wsproxy/
lib.rs

1use crate::protocol::{Request, Response, Update};
2use anyhow::{bail, Result};
3use futures::{
4    channel::mpsc,
5    prelude::*,
6    select_biased,
7    stream::{FuturesUnordered, SplitSink},
8    StreamExt,
9};
10use fxhash::FxHashMap;
11use log::warn;
12use netidx::{
13    path::Path,
14    pool::{Pool, Pooled},
15    protocol::value::Value,
16    publisher::{Id as PubId, Publisher, UpdateBatch, Val as Pub},
17    subscriber::{Dval as Sub, Event, SubId, Subscriber, UpdatesFlags},
18    utils::{BatchItem, Batched},
19};
20use netidx_protocols::rpc::client::Proc;
21use once_cell::sync::Lazy;
22use std::{
23    collections::{hash_map::Entry, HashMap},
24    net::SocketAddr,
25    pin::Pin,
26    result,
27};
28use warp::{
29    filters::BoxedFilter,
30    ws::{Message, WebSocket, Ws},
31    Filter, Reply,
32};
33use std::time::Duration;
34
35pub mod config;
36mod protocol;
37
38struct SubEntry {
39    count: usize,
40    path: Path,
41    val: Sub,
42}
43
44struct PubEntry {
45    path: Path,
46    val: Pub,
47}
48
49type PendingCall =
50    Pin<Box<dyn Future<Output = (u64, Result<Value>)> + Send + Sync + 'static>>;
51
52async fn reply<'a>(tx: &mut SplitSink<WebSocket, Message>, m: &Response, timeout: Option<Duration>) -> Result<()> {
53    let s = serde_json::to_string(m)?;
54    // CR base1172 for estokes: Here we're only enforcing that the SplitSink write completes within
55    // [timeout], with no guarantee on how long it takes to actually flush the message to the client.
56    // In a perfect world we'd probably want a proper flush timeout (similar to what [WriteChannel] does).
57    // For now, just requiring that [tx.send(..)] completes within [timeout] is probably good enough.
58    // DUR
59    let fut = tx.send(Message::text(s));
60    match timeout {
61        None => Ok(fut.await?),
62        Some(timeout) => Ok(tokio::time::timeout(timeout, fut).await??)
63    }
64}
65async fn err(
66    tx: &mut SplitSink<WebSocket, Message>,
67    message: impl Into<String>,
68    timeout: Option<Duration>,
69) -> Result<()> {
70    reply(tx, &Response::Error { error: message.into() }, timeout).await
71}
72
73struct ClientCtx {
74    publisher: Publisher,
75    subscriber: Subscriber,
76    subs: FxHashMap<SubId, SubEntry>,
77    pubs: FxHashMap<PubId, PubEntry>,
78    subs_by_path: HashMap<Path, SubId>,
79    pubs_by_path: HashMap<Path, PubId>,
80    rpcs: HashMap<Path, Proc>,
81    tx_up: mpsc::Sender<Pooled<Vec<(SubId, Event)>>>,
82}
83
84impl ClientCtx {
85    fn new(
86        publisher: Publisher,
87        subscriber: Subscriber,
88        tx_up: mpsc::Sender<Pooled<Vec<(SubId, Event)>>>,
89    ) -> Self {
90        Self {
91            publisher,
92            subscriber,
93            tx_up,
94            subs: HashMap::default(),
95            pubs: HashMap::default(),
96            subs_by_path: HashMap::default(),
97            pubs_by_path: HashMap::default(),
98            rpcs: HashMap::default(),
99        }
100    }
101
102    fn subscribe(&mut self, path: Path) -> SubId {
103        match self.subs_by_path.entry(path) {
104            Entry::Occupied(e) => {
105                let se = self.subs.get_mut(e.get()).unwrap();
106                se.count += 1;
107                se.val.id()
108            }
109            Entry::Vacant(e) => {
110                let path = e.key().clone();
111                let val = self.subscriber.subscribe(path.clone());
112                let id = val.id();
113                val.updates(UpdatesFlags::BEGIN_WITH_LAST, self.tx_up.clone());
114                self.subs.insert(id, SubEntry { count: 1, path, val });
115                e.insert(id);
116                id
117            }
118        }
119    }
120
121    fn unsubscribe(&mut self, id: SubId) -> Result<()> {
122        match self.subs.get_mut(&id) {
123            None => bail!("not subscribed"),
124            Some(se) => {
125                se.count -= 1;
126                if se.count == 0 {
127                    let path = se.path.clone();
128                    self.subs.remove(&id);
129                    self.subs_by_path.remove(&path);
130                }
131                Ok(())
132            }
133        }
134    }
135
136    fn write(&mut self, id: SubId, val: Value) -> Result<()> {
137        match self.subs.get(&id) {
138            None => bail!("not subscribed"),
139            Some(se) => {
140                se.val.write(val);
141                Ok(())
142            }
143        }
144    }
145
146    fn publish(&mut self, path: Path, val: Value) -> Result<PubId> {
147        match self.pubs_by_path.entry(path) {
148            Entry::Occupied(_) => bail!("already published"),
149            Entry::Vacant(e) => {
150                let path = e.key().clone();
151                let val = self.publisher.publish(path.clone(), val)?;
152                let id = val.id();
153                e.insert(id);
154                self.pubs.insert(id, PubEntry { val, path });
155                Ok(id)
156            }
157        }
158    }
159
160    fn unpublish(&mut self, id: PubId) -> Result<()> {
161        match self.pubs.remove(&id) {
162            None => bail!("not published"),
163            Some(pe) => {
164                self.pubs_by_path.remove(&pe.path);
165                Ok(())
166            }
167        }
168    }
169
170    fn update(
171        &mut self,
172        batch: &mut UpdateBatch,
173        mut updates: Pooled<Vec<protocol::BatchItem>>,
174    ) -> Result<()> {
175        for up in updates.drain(..) {
176            match self.pubs.get(&up.id) {
177                None => bail!("not published"),
178                Some(pe) => pe.val.update(batch, up.data),
179            }
180        }
181        Ok(())
182    }
183
184    fn call(
185        &mut self,
186        id: u64,
187        path: Path,
188        mut args: Pooled<Vec<(Pooled<String>, Value)>>,
189    ) -> Result<PendingCall> {
190        let proc = match self.rpcs.entry(path) {
191            Entry::Occupied(e) => e.into_mut(),
192            Entry::Vacant(e) => {
193                let proc = Proc::new(&self.subscriber, e.key().clone())?;
194                e.insert(proc)
195            }
196        }
197        .clone();
198        Ok(Box::pin(async move { (id, proc.call(args.drain(..)).await) }) as PendingCall)
199    }
200
201    async fn process_from_client(
202        &mut self,
203        tx: &mut SplitSink<WebSocket, Message>,
204        queued: &mut Vec<result::Result<Message, warp::Error>>,
205        calls_pending: &mut FuturesUnordered<PendingCall>,
206        timeout: Option<Duration>
207    ) -> Result<()> {
208        let mut batch = self.publisher.start_batch();
209        for r in queued.drain(..) {
210            let m = r?;
211            if m.is_ping() {
212                continue;
213            }
214            match m.to_str() {
215                Err(_) => err(tx, "expected text", timeout).await?,
216                Ok(txt) => match serde_json::from_str::<Request>(txt) {
217                    Err(e) => err(tx, format!("could not parse message {}", e), timeout).await?,
218                    Ok(req) => match req {
219                        Request::Subscribe { path } => {
220                            let id = self.subscribe(path);
221                            reply(tx, &Response::Subscribed { id }, timeout).await?
222                        }
223                        Request::Unsubscribe { id } => match self.unsubscribe(id) {
224                            Err(e) => err(tx, e.to_string(), timeout).await?,
225                            Ok(()) => reply(tx, &Response::Unsubscribed, timeout).await?,
226                        },
227                        Request::Write { id, val } => match self.write(id, val) {
228                            Err(e) => err(tx, e.to_string(), timeout).await?,
229                            Ok(()) => reply(tx, &Response::Wrote, timeout).await?,
230                        },
231                        Request::Publish { path, init } => match self.publish(path, init)
232                        {
233                            Err(e) => err(tx, e.to_string(), timeout).await?,
234                            Ok(id) => reply(tx, &Response::Published { id }, timeout).await?,
235                        },
236                        Request::Unpublish { id } => match self.unpublish(id) {
237                            Err(e) => err(tx, e.to_string(), timeout).await?,
238                            Ok(()) => reply(tx, &Response::Unpublished, timeout).await?,
239                        },
240                        Request::Update { updates } => {
241                            match self.update(&mut batch, updates) {
242                                Err(e) => err(tx, e.to_string(), timeout).await?,
243                                Ok(()) => reply(tx, &Response::Updated, timeout).await?,
244                            }
245                        }
246                        Request::Call { id, path, args } => {
247                            match self.call(id, path, args) {
248                                Ok(pending) => calls_pending.push(pending),
249                                Err(e) => {
250                                    let error = format!("rpc call failed {}", e);
251                                    reply(tx, &Response::CallFailed { id, error }, timeout).await?
252                                }
253                            }
254                        }
255                        Request::Unknown => err(tx, "unknown request", timeout).await?,
256                    },
257                },
258            }
259        }
260        Ok(batch.commit(timeout).await)
261    }
262}
263
264async fn handle_client(
265    publisher: Publisher,
266    subscriber: Subscriber,
267    ws: WebSocket,
268    timeout: Option<Duration>
269) -> Result<()> {
270    static UPDATES: Lazy<Pool<Vec<Update>>> = Lazy::new(|| Pool::new(50, 10000));
271    let (tx_up, mut rx_up) = mpsc::channel::<Pooled<Vec<(SubId, Event)>>>(3);
272    let mut ctx = ClientCtx::new(publisher, subscriber, tx_up);
273    let (mut tx_ws, rx_ws) = ws.split();
274    let mut queued: Vec<result::Result<Message, warp::Error>> = Vec::new();
275    let mut rx_ws = Batched::new(rx_ws.fuse(), 10_000);
276    let mut calls_pending: FuturesUnordered<PendingCall> = FuturesUnordered::new();
277    calls_pending.push(Box::pin(async { future::pending().await }) as PendingCall);
278    loop {
279        select_biased! {
280            (id, res) = calls_pending.select_next_some() => match res {
281                Ok(result) => {
282                    reply(&mut tx_ws, &Response::CallSuccess { id, result }, timeout).await?
283                }
284                Err(e) => {
285                    let error = format!("rpc call failed {}", e);
286                    reply(&mut tx_ws, &Response::CallFailed { id, error }, timeout).await?
287                }
288            },
289            r = rx_ws.select_next_some() => match r {
290                BatchItem::InBatch(r) => queued.push(r),
291                BatchItem::EndBatch => {
292                    ctx.process_from_client(
293                        &mut tx_ws,
294                        &mut queued,
295                        &mut calls_pending,
296                        timeout
297                    ).await?
298                }
299            },
300            mut batch = rx_up.select_next_some() => {
301                let mut updates = UPDATES.take();
302                for (id, event) in batch.drain(..) {
303                    updates.push(Update {id, event});
304                }
305                reply(&mut tx_ws, &Response::Update { updates }, timeout).await?
306            },
307        }
308    }
309}
310
311/// If you want to integrate the netidx api server into your own warp project
312/// this will return the filter path will be the http path where the websocket
313/// lives
314pub fn filter(
315    publisher: Publisher,
316    subscriber: Subscriber,
317    path: &'static str,
318    timeout: Option<Duration>,
319) -> BoxedFilter<(impl Reply,)> {
320    warp::path(path)
321        .and(warp::ws())
322        .map(move |ws: Ws| {
323            let (publisher, subscriber) = (publisher.clone(), subscriber.clone());
324            ws.on_upgrade(move |ws| {
325                let (publisher, subscriber) = (publisher.clone(), subscriber.clone());
326                async move {
327                    if let Err(e) = handle_client(publisher, subscriber, ws, timeout).await {
328                        warn!("client handler exited: {}", e)
329                    }
330                }
331            })
332        })
333        .boxed()
334}
335
336/// If you want to embed the websocket api in your own process, but you don't
337/// want to serve any other warp filters then you can just call this in a task.
338/// This will not return unless the server crashes, you should
339/// probably run it in a task.
340pub async fn run(
341    config: config::Config,
342    publisher: Publisher,
343    subscriber: Subscriber,
344    timeout: Option<Duration>
345) -> Result<()> {
346    let routes = filter(publisher, subscriber, "ws", timeout);
347    match (&config.cert, &config.key) {
348        (_, None) | (None, _) => {
349            warp::serve(routes).run(config.listen.parse::<SocketAddr>()?).await
350        }
351        (Some(cert), Some(key)) => {
352            warp::serve(routes)
353                .tls()
354                .cert_path(cert)
355                .key_path(key)
356                .run(config.listen.parse::<SocketAddr>()?)
357                .await
358        }
359    }
360    Ok(())
361}