netidx_wsproxy/
lib.rs

1use crate::protocol::{Request, Response, Update};
2use anyhow::{bail, Result};
3use futures::{
4    channel::mpsc,
5    prelude::*,
6    select_biased,
7    stream::{FuturesUnordered, SplitSink},
8    StreamExt,
9};
10use fxhash::FxHashMap;
11use log::warn;
12use netidx::{
13    path::Path,
14    protocol::value::Value,
15    publisher::{Id as PubId, Publisher, UpdateBatch, Val as Pub},
16    subscriber::{Dval as Sub, Event, SubId, Subscriber, UpdatesFlags},
17    utils::{BatchItem, Batched},
18};
19use netidx_protocols::rpc::client::Proc;
20use poolshark::{Pool, Pooled};
21use std::time::Duration;
22use std::{
23    collections::{hash_map::Entry, HashMap},
24    net::SocketAddr,
25    pin::Pin,
26    result,
27    sync::LazyLock,
28};
29use warp::{
30    filters::BoxedFilter,
31    ws::{Message, WebSocket, Ws},
32    Filter, Reply,
33};
34
35pub mod config;
36mod protocol;
37
/// A subscription held on behalf of one websocket client.
///
/// Subscriptions are shared between all subscribe requests for the same path;
/// the entry lives until every one of those requests has been unsubscribed.
struct SubEntry {
    // number of outstanding Subscribe requests for `path`; the entry is removed when it hits 0
    count: usize,
    // the subscribed path, kept so the path -> id index can be cleaned up on removal
    path: Path,
    // the live subscription
    val: Sub,
}
43
/// A value published on behalf of one websocket client.
struct PubEntry {
    // the published path, kept so the path -> id index can be cleaned up on unpublish
    path: Path,
    // handle to the published value; client updates are queued through it
    val: Pub,
}
48
49type PendingCall =
50    Pin<Box<dyn Future<Output = (u64, Result<Value>)> + Send + Sync + 'static>>;
51
52async fn reply<'a>(
53    tx: &mut SplitSink<WebSocket, Message>,
54    m: &Response,
55    timeout: Option<Duration>,
56) -> Result<()> {
57    let s = serde_json::to_string(m)?;
58    // CR base1172 for estokes: Here we're only enforcing that the SplitSink write completes within
59    // [timeout], with no guarantee on how long it takes to actually flush the message to the client.
60    // In a perfect world we'd probably want a proper flush timeout (similar to what [WriteChannel] does).
61    // For now, just requiring that [tx.send(..)] completes within [timeout] is probably good enough.
62    // DUR
63    let fut = tx.send(Message::text(s));
64    match timeout {
65        None => Ok(fut.await?),
66        Some(timeout) => Ok(tokio::time::timeout(timeout, fut).await??),
67    }
68}
69async fn err(
70    tx: &mut SplitSink<WebSocket, Message>,
71    message: impl Into<String>,
72    timeout: Option<Duration>,
73) -> Result<()> {
74    reply(tx, &Response::Error { error: message.into() }, timeout).await
75}
76
/// Per-websocket-client state: everything this client has subscribed to,
/// published, or called over RPC, plus the handles used to do so.
struct ClientCtx {
    publisher: Publisher,
    subscriber: Subscriber,
    // live subscriptions keyed by id; each entry is reference counted
    subs: FxHashMap<SubId, SubEntry>,
    // live publications keyed by id
    pubs: FxHashMap<PubId, PubEntry>,
    // path -> id index used to share one subscription between requests for the same path
    subs_by_path: HashMap<Path, SubId>,
    // path -> id index used to reject publishing the same path twice
    pubs_by_path: HashMap<Path, PubId>,
    // rpc clients, created on the first call to a path and reused afterwards
    rpcs: HashMap<Path, Proc>,
    // cloned into every subscription so all of their updates arrive on one channel
    tx_up: mpsc::Sender<Pooled<Vec<(SubId, Event)>>>,
}
87
88impl ClientCtx {
89    fn new(
90        publisher: Publisher,
91        subscriber: Subscriber,
92        tx_up: mpsc::Sender<Pooled<Vec<(SubId, Event)>>>,
93    ) -> Self {
94        Self {
95            publisher,
96            subscriber,
97            tx_up,
98            subs: HashMap::default(),
99            pubs: HashMap::default(),
100            subs_by_path: HashMap::default(),
101            pubs_by_path: HashMap::default(),
102            rpcs: HashMap::default(),
103        }
104    }
105
106    fn subscribe(&mut self, path: Path) -> SubId {
107        match self.subs_by_path.entry(path) {
108            Entry::Occupied(e) => {
109                let se = self.subs.get_mut(e.get()).unwrap();
110                se.count += 1;
111                se.val.id()
112            }
113            Entry::Vacant(e) => {
114                let path = e.key().clone();
115                let val = self.subscriber.subscribe(path.clone());
116                let id = val.id();
117                val.updates(UpdatesFlags::BEGIN_WITH_LAST, self.tx_up.clone());
118                self.subs.insert(id, SubEntry { count: 1, path, val });
119                e.insert(id);
120                id
121            }
122        }
123    }
124
125    fn unsubscribe(&mut self, id: SubId) -> Result<()> {
126        match self.subs.get_mut(&id) {
127            None => bail!("not subscribed"),
128            Some(se) => {
129                se.count -= 1;
130                if se.count == 0 {
131                    let path = se.path.clone();
132                    self.subs.remove(&id);
133                    self.subs_by_path.remove(&path);
134                }
135                Ok(())
136            }
137        }
138    }
139
140    fn write(&mut self, id: SubId, val: Value) -> Result<()> {
141        match self.subs.get(&id) {
142            None => bail!("not subscribed"),
143            Some(se) => {
144                se.val.write(val);
145                Ok(())
146            }
147        }
148    }
149
150    fn publish(&mut self, path: Path, val: Value) -> Result<PubId> {
151        match self.pubs_by_path.entry(path) {
152            Entry::Occupied(_) => bail!("already published"),
153            Entry::Vacant(e) => {
154                let path = e.key().clone();
155                let val = self.publisher.publish(path.clone(), val)?;
156                let id = val.id();
157                e.insert(id);
158                self.pubs.insert(id, PubEntry { val, path });
159                Ok(id)
160            }
161        }
162    }
163
164    fn unpublish(&mut self, id: PubId) -> Result<()> {
165        match self.pubs.remove(&id) {
166            None => bail!("not published"),
167            Some(pe) => {
168                self.pubs_by_path.remove(&pe.path);
169                Ok(())
170            }
171        }
172    }
173
174    fn update(
175        &mut self,
176        batch: &mut UpdateBatch,
177        mut updates: Pooled<Vec<protocol::BatchItem>>,
178    ) -> Result<()> {
179        for up in updates.drain(..) {
180            match self.pubs.get(&up.id) {
181                None => bail!("not published"),
182                Some(pe) => pe.val.update(batch, up.data),
183            }
184        }
185        Ok(())
186    }
187
188    fn call(
189        &mut self,
190        id: u64,
191        path: Path,
192        mut args: Pooled<Vec<(Pooled<String>, Value)>>,
193    ) -> Result<PendingCall> {
194        let proc = match self.rpcs.entry(path) {
195            Entry::Occupied(e) => e.into_mut(),
196            Entry::Vacant(e) => {
197                let proc = Proc::new(&self.subscriber, e.key().clone())?;
198                e.insert(proc)
199            }
200        }
201        .clone();
202        Ok(Box::pin(async move { (id, proc.call(args.drain(..)).await) }) as PendingCall)
203    }
204
205    async fn process_from_client(
206        &mut self,
207        tx: &mut SplitSink<WebSocket, Message>,
208        queued: &mut Vec<result::Result<Message, warp::Error>>,
209        calls_pending: &mut FuturesUnordered<PendingCall>,
210        timeout: Option<Duration>,
211    ) -> Result<()> {
212        let mut batch = self.publisher.start_batch();
213        for r in queued.drain(..) {
214            let m = r?;
215            if m.is_ping() {
216                continue;
217            }
218            match m.to_str() {
219                Err(_) => err(tx, "expected text", timeout).await?,
220                Ok(txt) => match serde_json::from_str::<Request>(txt) {
221                    Err(e) => {
222                        err(tx, format!("could not parse message {}", e), timeout).await?
223                    }
224                    Ok(req) => match req {
225                        Request::Subscribe { path } => {
226                            let id = self.subscribe(path);
227                            reply(tx, &Response::Subscribed { id }, timeout).await?
228                        }
229                        Request::Unsubscribe { id } => match self.unsubscribe(id) {
230                            Err(e) => err(tx, e.to_string(), timeout).await?,
231                            Ok(()) => reply(tx, &Response::Unsubscribed, timeout).await?,
232                        },
233                        Request::Write { id, val } => match self.write(id, val) {
234                            Err(e) => err(tx, e.to_string(), timeout).await?,
235                            Ok(()) => reply(tx, &Response::Wrote, timeout).await?,
236                        },
237                        Request::Publish { path, init } => match self.publish(path, init)
238                        {
239                            Err(e) => err(tx, e.to_string(), timeout).await?,
240                            Ok(id) => {
241                                reply(tx, &Response::Published { id }, timeout).await?
242                            }
243                        },
244                        Request::Unpublish { id } => match self.unpublish(id) {
245                            Err(e) => err(tx, e.to_string(), timeout).await?,
246                            Ok(()) => reply(tx, &Response::Unpublished, timeout).await?,
247                        },
248                        Request::Update { updates } => {
249                            match self.update(&mut batch, updates) {
250                                Err(e) => err(tx, e.to_string(), timeout).await?,
251                                Ok(()) => reply(tx, &Response::Updated, timeout).await?,
252                            }
253                        }
254                        Request::Call { id, path, args } => {
255                            match self.call(id, path, args) {
256                                Ok(pending) => calls_pending.push(pending),
257                                Err(e) => {
258                                    let error = format!("rpc call failed {}", e);
259                                    reply(
260                                        tx,
261                                        &Response::CallFailed { id, error },
262                                        timeout,
263                                    )
264                                    .await?
265                                }
266                            }
267                        }
268                        Request::Unknown => err(tx, "unknown request", timeout).await?,
269                    },
270                },
271            }
272        }
273        Ok(batch.commit(timeout).await)
274    }
275}
276
277async fn handle_client(
278    publisher: Publisher,
279    subscriber: Subscriber,
280    ws: WebSocket,
281    timeout: Option<Duration>,
282) -> Result<()> {
283    static UPDATES: LazyLock<Pool<Vec<Update>>> = LazyLock::new(|| Pool::new(50, 10000));
284    let (tx_up, mut rx_up) = mpsc::channel::<Pooled<Vec<(SubId, Event)>>>(3);
285    let mut ctx = ClientCtx::new(publisher, subscriber, tx_up);
286    let (mut tx_ws, rx_ws) = ws.split();
287    let mut queued: Vec<result::Result<Message, warp::Error>> = Vec::new();
288    let mut rx_ws = Batched::new(rx_ws.fuse(), 10_000);
289    let mut calls_pending: FuturesUnordered<PendingCall> = FuturesUnordered::new();
290    calls_pending.push(Box::pin(async { future::pending().await }) as PendingCall);
291    loop {
292        select_biased! {
293            (id, res) = calls_pending.select_next_some() => match res {
294                Ok(result) => {
295                    reply(&mut tx_ws, &Response::CallSuccess { id, result }, timeout).await?
296                }
297                Err(e) => {
298                    let error = format!("rpc call failed {}", e);
299                    reply(&mut tx_ws, &Response::CallFailed { id, error }, timeout).await?
300                }
301            },
302            r = rx_ws.select_next_some() => match r {
303                BatchItem::InBatch(r) => queued.push(r),
304                BatchItem::EndBatch => {
305                    ctx.process_from_client(
306                        &mut tx_ws,
307                        &mut queued,
308                        &mut calls_pending,
309                        timeout
310                    ).await?
311                }
312            },
313            mut batch = rx_up.select_next_some() => {
314                let mut updates = UPDATES.take();
315                for (id, event) in batch.drain(..) {
316                    updates.push(Update {id, event});
317                }
318                reply(&mut tx_ws, &Response::Update { updates }, timeout).await?
319            },
320        }
321    }
322}
323
324/// If you want to integrate the netidx api server into your own warp project
325/// this will return the filter path will be the http path where the websocket
326/// lives
327pub fn filter(
328    publisher: Publisher,
329    subscriber: Subscriber,
330    path: &'static str,
331    timeout: Option<Duration>,
332) -> BoxedFilter<(impl Reply,)> {
333    warp::path(path)
334        .and(warp::ws())
335        .map(move |ws: Ws| {
336            let (publisher, subscriber) = (publisher.clone(), subscriber.clone());
337            ws.on_upgrade(move |ws| {
338                let (publisher, subscriber) = (publisher.clone(), subscriber.clone());
339                async move {
340                    if let Err(e) =
341                        handle_client(publisher, subscriber, ws, timeout).await
342                    {
343                        warn!("client handler exited: {}", e)
344                    }
345                }
346            })
347        })
348        .boxed()
349}
350
351/// If you want to embed the websocket api in your own process, but you don't
352/// want to serve any other warp filters then you can just call this in a task.
353/// This will not return unless the server crashes, you should
354/// probably run it in a task.
355pub async fn run(
356    config: config::Config,
357    publisher: Publisher,
358    subscriber: Subscriber,
359    timeout: Option<Duration>,
360) -> Result<()> {
361    let routes = filter(publisher, subscriber, "ws", timeout);
362    match (&config.cert, &config.key) {
363        (_, None) | (None, _) => {
364            warp::serve(routes).run(config.listen.parse::<SocketAddr>()?).await
365        }
366        (Some(cert), Some(key)) => {
367            warp::serve(routes)
368                .tls()
369                .cert_path(cert)
370                .key_path(key)
371                .run(config.listen.parse::<SocketAddr>()?)
372                .await
373        }
374    }
375    Ok(())
376}