bs_gl_plugin/
stager.rs

//! A simple staging mechanism for incoming requests so we can invert from
//! pull to push. Used by `hsmproxy` to stage requests that can then be
//! retrieved asynchronously and processed by one or more client devices.
5use crate::pb;
6use anyhow::{anyhow, Error};
7use log::{debug, trace, warn};
8use std::collections;
9use std::sync::{
10    atomic::{AtomicUsize, Ordering},
11    Arc,
12};
13use tokio::sync::{broadcast, mpsc, Mutex};
14
15#[derive(Debug)]
16pub struct Stage {
17    requests: Mutex<collections::HashMap<u32, Request>>,
18    notify: broadcast::Sender<Request>,
19    hsm_connections: Arc<AtomicUsize>,
20}
21
22#[derive(Clone, Debug)]
23pub struct Request {
24    pub request: pb::HsmRequest,
25    pub response: mpsc::Sender<pb::HsmResponse>,
26    pub start_time: tokio::time::Instant,
27}
28
29impl Stage {
30    pub fn new() -> Self {
31        let (notify, _) = broadcast::channel(1000);
32        Stage {
33            requests: Mutex::new(collections::HashMap::new()),
34            notify: notify,
35            hsm_connections: Arc::new(AtomicUsize::new(0)),
36        }
37    }
38
39    pub async fn send(
40        &self,
41        request: pb::HsmRequest,
42    ) -> Result<mpsc::Receiver<pb::HsmResponse>, Error> {
43        let mut requests = self.requests.lock().await;
44        let (response, receiver): (
45            mpsc::Sender<pb::HsmResponse>,
46            mpsc::Receiver<pb::HsmResponse>,
47        ) = mpsc::channel(1);
48
49        let r = Request {
50            request,
51            response,
52            start_time: tokio::time::Instant::now(),
53        };
54
55        requests.insert(r.request.request_id, r.clone());
56
57        if let Err(_) = self.notify.send(r) {
58            warn!("Error notifying hsmd request stream, likely lost connection.");
59        }
60
61        Ok(receiver)
62    }
63
64    pub async fn mystream(&self) -> StageStream {
65        let requests = self.requests.lock().await;
66        self.hsm_connections.fetch_add(1, Ordering::Relaxed);
67        StageStream {
68            backlog: requests.values().map(|e| e.clone()).collect(),
69            bcast: self.notify.subscribe(),
70            hsm_connections: self.hsm_connections.clone(),
71        }
72    }
73
74    pub async fn respond(&self, response: pb::HsmResponse) -> Result<(), Error> {
75        let mut requests = self.requests.lock().await;
76        match requests.remove(&response.request_id) {
77            Some(req) => {
78                debug!(
79                    "Response for request_id={}, signer_rtt={}s, outstanding requests count={}",
80                    response.request_id,
81                    req.start_time.elapsed().as_secs_f64(),
82                    requests.len()
83                );
84                if let Err(e) = req.response.send(response).await {
85                    Err(anyhow!("Error sending request to requester: {:?}", e))
86                } else {
87                    Ok(())
88                }
89            }
90            None => {
91                trace!(
92                    "Request {} not found, is this a duplicate result?",
93                    response.request_id
94                );
95                Ok(())
96            }
97        }
98    }
99
100    pub async fn is_stuck(&self) -> bool {
101        let sticky = self
102            .requests
103            .lock()
104            .await
105            .values()
106            .filter(|r| r.request.raw[0..2] == [0u8, 5])
107            .count();
108
109        trace!("Found {sticky} sticky requests.");
110        sticky != 0
111    }
112}
113
114pub struct StageStream {
115    backlog: Vec<Request>,
116    bcast: broadcast::Receiver<Request>,
117    hsm_connections: Arc<AtomicUsize>,
118}
119
120impl StageStream {
121    pub async fn next(&mut self) -> Result<Request, Error> {
122        if self.backlog.len() > 0 {
123            let req = self.backlog[0].clone();
124            self.backlog.remove(0);
125            Ok(req)
126        } else {
127            match self.bcast.recv().await {
128                Ok(r) => Ok(r),
129                Err(e) => Err(anyhow!("error waiting for more requests: {:?}", e)),
130            }
131        }
132    }
133}
134
135impl Drop for StageStream {
136    fn drop(&mut self) {
137        self.hsm_connections.fetch_sub(1, Ordering::Relaxed);
138    }
139}
140
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep as delay_for;

    /// Build a request with the given id and raw payload; other fields empty.
    fn request(request_id: u32, raw: Vec<u8>) -> pb::HsmRequest {
        pb::HsmRequest {
            request_id,
            context: None,
            raw,
            signer_state: vec![],
            requests: vec![],
        }
    }

    /// Build an empty response for `request_id`.
    fn response(request_id: u32) -> pb::HsmResponse {
        pb::HsmResponse {
            request_id,
            raw: vec![],
            signer_state: vec![],
        }
    }

    /// Spawn a fake hsmd that answers every request it receives directly on
    /// the request's response channel until its stream ends.
    fn spawn_hsmd(id: usize, mut stream: StageStream) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            while let Ok(r) = stream.next().await {
                eprintln!("hsmd {} is handling request {}", id, r.request.request_id);
                if let Err(e) = r.response.send(response(r.request.request_id)).await {
                    // Can happen when the requester already dropped its
                    // receiver after taking the other hsmd's answer.
                    eprintln!("error {:?}", e);
                }
                delay_for(Duration::from_millis(10)).await;
            }
        })
    }

    #[tokio::test]
    async fn test_live_stream() {
        let stage = Stage::new();

        // Requests staged before any hsmd connects land in each backlog.
        let mut responses = vec![];
        for i in 0..10 {
            responses.push(stage.send(request(i, vec![])).await.unwrap());
        }

        let f1 = spawn_hsmd(1, stage.mystream().await);
        let f2 = spawn_hsmd(2, stage.mystream().await);

        // Requests staged afterwards arrive through the broadcast channel.
        for i in 10..100 {
            responses.push(stage.send(request(i, vec![])).await.unwrap());
        }

        for mut r in responses {
            let resp = r.recv().await.unwrap();
            eprintln!("Got response {:?}", resp);
        }

        // Dropping the stage drops the broadcast sender, ending both loops.
        drop(stage);
        f1.await.unwrap();
        f2.await.unwrap();
    }

    #[tokio::test]
    async fn test_respond_and_is_stuck() {
        let stage = Stage::new();
        assert!(!stage.is_stuck().await);

        let mut rx = stage.send(request(1, vec![0u8, 5, 42])).await.unwrap();
        let _other = stage.send(request(2, vec![0u8, 1])).await.unwrap();
        assert!(stage.is_stuck().await);

        stage.respond(response(1)).await.unwrap();
        assert_eq!(rx.recv().await.unwrap().request_id, 1);
        assert!(!stage.is_stuck().await);

        // A duplicate response for an already-answered request is ignored.
        stage.respond(response(1)).await.unwrap();
    }
}