gl_plugin/
stager.rs

//! A simple staging mechanism for incoming requests so we can invert from
//! pull to push. Used by `hsmproxy` to stage requests that can then
//! asynchronously be retrieved and processed by one or more client
//! devices.
5use crate::pb;
6use anyhow::{anyhow, Error};
7use log::{debug, trace, warn};
8use std::collections;
9use std::sync::{
10    atomic::{AtomicUsize, Ordering},
11    Arc,
12};
13use tokio::sync::{broadcast, mpsc, Mutex};
14
15#[derive(Debug)]
16pub struct Stage {
17    requests: Mutex<collections::HashMap<u32, Request>>,
18    notify: broadcast::Sender<Request>,
19    hsm_connections: Arc<AtomicUsize>,
20}
21
22#[derive(Clone, Debug)]
23pub struct Request {
24    pub request: pb::HsmRequest,
25    pub response: mpsc::Sender<pb::HsmResponse>,
26    pub start_time: tokio::time::Instant,
27}
28
29impl Stage {
30    pub fn new() -> Self {
31        let (notify, _) = broadcast::channel(1000);
32        Stage {
33            requests: Mutex::new(collections::HashMap::new()),
34            notify: notify,
35            hsm_connections: Arc::new(AtomicUsize::new(0)),
36        }
37    }
38
39    pub async fn send(
40        &self,
41        request: pb::HsmRequest,
42    ) -> Result<mpsc::Receiver<pb::HsmResponse>, Error> {
43        let mut requests = self.requests.lock().await;
44        let (response, receiver): (
45            mpsc::Sender<pb::HsmResponse>,
46            mpsc::Receiver<pb::HsmResponse>,
47        ) = mpsc::channel(1);
48
49        let r = Request {
50            request,
51            response,
52            start_time: tokio::time::Instant::now(),
53        };
54
55        requests.insert(r.request.request_id, r.clone());
56
57        if let Err(_) = self.notify.send(r) {
58            warn!("Error notifying hsmd request stream, likely lost connection.");
59        }
60
61        Ok(receiver)
62    }
63
64    pub async fn mystream(&self) -> StageStream {
65        let requests = self.requests.lock().await;
66        self.hsm_connections.fetch_add(1, Ordering::Relaxed);
67        StageStream {
68            backlog: requests.values().map(|e| e.clone()).collect(),
69            bcast: self.notify.subscribe(),
70            hsm_connections: self.hsm_connections.clone(),
71        }
72    }
73
74    pub async fn respond(&self, response: pb::HsmResponse) -> Result<(), Error> {
75        let mut requests = self.requests.lock().await;
76        match requests.remove(&response.request_id) {
77            Some(req) => {
78                debug!(
79                    "Response for request_id={}, signer_rtt={}s, outstanding requests count={}",
80                    response.request_id,
81                    req.start_time.elapsed().as_secs_f64(),
82                    requests.len()
83                );
84                if let Err(e) = req.response.send(response).await {
85                    Err(anyhow!("Error sending request to requester: {:?}", e))
86                } else {
87                    Ok(())
88                }
89            }
90            None => {
91                trace!(
92                    "Request {} not found, is this a duplicate result?",
93                    response.request_id
94                );
95                Ok(())
96            }
97        }
98    }
99
100    pub async fn is_stuck(&self) -> bool {
101        let sticky = self
102            .requests
103            .lock()
104            .await
105            .values()
106            .filter(|r| r.request.raw[0..2] == [0u8, 5])
107            .count();
108
109        trace!("Found {sticky} sticky requests.");
110        sticky != 0
111    }
112}
113
114pub struct StageStream {
115    backlog: Vec<Request>,
116    bcast: broadcast::Receiver<Request>,
117    hsm_connections: Arc<AtomicUsize>,
118}
119
120impl StageStream {
121    pub async fn next(&mut self) -> Result<Request, Error> {
122        if self.backlog.len() > 0 {
123            let req = self.backlog[0].clone();
124            self.backlog.remove(0);
125            Ok(req)
126        } else {
127            match self.bcast.recv().await {
128                Ok(r) => Ok(r),
129                Err(e) => Err(anyhow!("error waiting for more requests: {:?}", e)),
130            }
131        }
132    }
133}
134
135impl Drop for StageStream {
136    fn drop(&mut self) {
137        self.hsm_connections.fetch_sub(1, Ordering::Relaxed);
138    }
139}
140
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep as delay_for;

    /// Builds a request with the given id and empty payloads.
    fn make_request(request_id: u32) -> pb::HsmRequest {
        pb::HsmRequest {
            request_id,
            context: None,
            raw: vec![],
            signer_state: vec![],
            requests: vec![],
        }
    }

    /// Simulates a signer: answers every request yielded by `stream` directly
    /// on its response channel until the stream reports an error (e.g. the
    /// stage was dropped).
    async fn run_signer(id: usize, mut stream: StageStream) {
        while let Ok(r) = stream.next().await {
            eprintln!("hsmd {} is handling request {}", id, r.request.request_id);
            if let Err(e) = r
                .response
                .send(pb::HsmResponse {
                    request_id: r.request.request_id,
                    raw: vec![],
                    signer_state: vec![],
                    error: "".into(),
                })
                .await
            {
                eprintln!("error {:?}", e);
            }
            delay_for(Duration::from_millis(10)).await;
        }
    }

    #[tokio::test]
    async fn test_live_stream() {
        let stage = Stage::new();

        let mut responses = vec![];

        // Requests staged before any signer connects end up in the backlog.
        for i in 0..10 {
            responses.push(stage.send(make_request(i)).await.unwrap());
        }

        // Open both streams before spawning, so both see the same backlog.
        let s1 = stage.mystream().await;
        let s2 = stage.mystream().await;
        let f1 = tokio::spawn(run_signer(1, s1));
        let f2 = tokio::spawn(run_signer(2, s2));

        // Requests staged afterwards reach the signers via broadcast.
        for i in 10..100 {
            responses.push(stage.send(make_request(i)).await.unwrap());
        }

        for mut r in responses {
            let resp = r.recv().await.unwrap();
            eprintln!("Got response {:?}", resp);
        }

        // Dropping the stage closes the broadcast channel and ends both loops.
        drop(stage);
        f1.await.unwrap();
        f2.await.unwrap();
    }
}