1use crate::pb;
6use anyhow::{anyhow, Error};
7use log::{debug, trace, warn};
8use std::collections;
9use std::sync::{
10 atomic::{AtomicUsize, Ordering},
11 Arc,
12};
13use tokio::sync::{broadcast, mpsc, Mutex};
14
15#[derive(Debug)]
16pub struct Stage {
17 requests: Mutex<collections::HashMap<u32, Request>>,
18 notify: broadcast::Sender<Request>,
19 hsm_connections: Arc<AtomicUsize>,
20}
21
22#[derive(Clone, Debug)]
23pub struct Request {
24 pub request: pb::HsmRequest,
25 pub response: mpsc::Sender<pb::HsmResponse>,
26 pub start_time: tokio::time::Instant,
27}
28
29impl Stage {
30 pub fn new() -> Self {
31 let (notify, _) = broadcast::channel(1000);
32 Stage {
33 requests: Mutex::new(collections::HashMap::new()),
34 notify: notify,
35 hsm_connections: Arc::new(AtomicUsize::new(0)),
36 }
37 }
38
39 pub async fn send(
40 &self,
41 request: pb::HsmRequest,
42 ) -> Result<mpsc::Receiver<pb::HsmResponse>, Error> {
43 let mut requests = self.requests.lock().await;
44 let (response, receiver): (
45 mpsc::Sender<pb::HsmResponse>,
46 mpsc::Receiver<pb::HsmResponse>,
47 ) = mpsc::channel(1);
48
49 let r = Request {
50 request,
51 response,
52 start_time: tokio::time::Instant::now(),
53 };
54
55 requests.insert(r.request.request_id, r.clone());
56
57 if let Err(_) = self.notify.send(r) {
58 warn!("Error notifying hsmd request stream, likely lost connection.");
59 }
60
61 Ok(receiver)
62 }
63
64 pub async fn mystream(&self) -> StageStream {
65 let requests = self.requests.lock().await;
66 self.hsm_connections.fetch_add(1, Ordering::Relaxed);
67 StageStream {
68 backlog: requests.values().map(|e| e.clone()).collect(),
69 bcast: self.notify.subscribe(),
70 hsm_connections: self.hsm_connections.clone(),
71 }
72 }
73
74 pub async fn respond(&self, response: pb::HsmResponse) -> Result<(), Error> {
75 let mut requests = self.requests.lock().await;
76 match requests.remove(&response.request_id) {
77 Some(req) => {
78 debug!(
79 "Response for request_id={}, signer_rtt={}s, outstanding requests count={}",
80 response.request_id,
81 req.start_time.elapsed().as_secs_f64(),
82 requests.len()
83 );
84 if let Err(e) = req.response.send(response).await {
85 Err(anyhow!("Error sending request to requester: {:?}", e))
86 } else {
87 Ok(())
88 }
89 }
90 None => {
91 trace!(
92 "Request {} not found, is this a duplicate result?",
93 response.request_id
94 );
95 Ok(())
96 }
97 }
98 }
99
100 pub async fn is_stuck(&self) -> bool {
101 let sticky = self
102 .requests
103 .lock()
104 .await
105 .values()
106 .filter(|r| r.request.raw[0..2] == [0u8, 5])
107 .count();
108
109 trace!("Found {sticky} sticky requests.");
110 sticky != 0
111 }
112}
113
114pub struct StageStream {
115 backlog: Vec<Request>,
116 bcast: broadcast::Receiver<Request>,
117 hsm_connections: Arc<AtomicUsize>,
118}
119
120impl StageStream {
121 pub async fn next(&mut self) -> Result<Request, Error> {
122 if self.backlog.len() > 0 {
123 let req = self.backlog[0].clone();
124 self.backlog.remove(0);
125 Ok(req)
126 } else {
127 match self.bcast.recv().await {
128 Ok(r) => Ok(r),
129 Err(e) => Err(anyhow!("error waiting for more requests: {:?}", e)),
130 }
131 }
132 }
133}
134
135impl Drop for StageStream {
136 fn drop(&mut self) {
137 self.hsm_connections.fetch_sub(1, Ordering::Relaxed);
138 }
139}
140
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    /// Builds an otherwise-empty request that carries only `request_id`.
    fn request(request_id: u32) -> pb::HsmRequest {
        pb::HsmRequest {
            request_id,
            context: None,
            raw: vec![],
            signer_state: vec![],
            requests: vec![],
        }
    }

    /// Simulates an hsmd.
    ///
    /// Answers every request from `stream` directly on the request's response
    /// channel until the stream errors (e.g. because the `Stage` was dropped).
    fn spawn_hsmd(id: usize, mut stream: StageStream) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            while let Ok(r) = stream.next().await {
                eprintln!("hsmd {} is handling request {}", id, r.request.request_id);
                let response = pb::HsmResponse {
                    request_id: r.request.request_id,
                    raw: vec![],
                    signer_state: vec![],
                };
                // Both hsmds answer every request, so the slower answer may
                // hit an already-dropped receiver; that is expected here.
                if let Err(e) = r.response.send(response).await {
                    eprintln!("error {:?}", e);
                }
                sleep(Duration::from_millis(10)).await;
            }
        })
    }

    #[tokio::test]
    async fn test_live_stream() {
        let stage = Stage::new();
        let mut responses = Vec::with_capacity(100);

        // Staged before any hsmd attaches: must be replayed from the backlog.
        for i in 0..10 {
            responses.push(stage.send(request(i)).await.unwrap());
        }

        let s1 = stage.mystream().await;
        let s2 = stage.mystream().await;
        let f1 = spawn_hsmd(1, s1);
        let f2 = spawn_hsmd(2, s2);

        // Staged after the hsmds attached: delivered via the broadcast channel.
        for i in 10..100 {
            responses.push(stage.send(request(i)).await.unwrap());
        }

        for (expected, mut r) in (0u32..).zip(responses) {
            let resp = r.recv().await.unwrap();
            eprintln!("Got response {:?}", resp);
            assert_eq!(resp.request_id, expected);
        }

        // Dropping the stage closes the broadcast channel, ending both loops.
        drop(stage);
        f1.await.unwrap();
        f2.await.unwrap();
    }
}