1
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::hash::Hash;
5use std::sync::{Arc, Mutex};
6use std::clone::Clone;
7use std::fmt::Debug;
8use std::pin::Pin;
9
10use futures::prelude::*;
11use futures::channel::{mpsc, oneshot};
12use futures::task::{Context, Poll};
13use async_trait::async_trait;
14
15use crate::connector::Connector;
16
17pub struct Wire <ReqId, Target, Req, Resp, E, Ctx> {
19 connectors: Arc<Mutex<HashMap<Target, WireMux<ReqId, Target, Req, Resp, E, Ctx>>>>,
20
21 requests: Arc<Mutex<HashMap<(Target, Target, ReqId), oneshot::Sender<Resp>>>>,
22
23 _e: PhantomData<E>,
24 _ctx: PhantomData<Ctx>,
25}
26
27impl <ReqId, Target, Req, Resp, E, Ctx> Clone for Wire<ReqId, Target, Req, Resp, E, Ctx>
28where
29 ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
30 Target: Clone + Hash + PartialEq + Eq + Send + 'static,
31 Req: PartialEq + Debug + Send + 'static,
32 Resp: PartialEq + Debug + Send + 'static,
33 E: PartialEq + Debug + Send + 'static,
34 Ctx: Clone + PartialEq + Debug + Send + 'static,
35{
36 fn clone(&self) -> Self {
37 Wire {
38 connectors: self.connectors.clone(),
39 requests: self.requests.clone(),
40
41 _e: PhantomData,
42 _ctx: PhantomData,
43 }
44 }
45}
46
47impl <ReqId, Target, Req, Resp, E, Ctx> Wire<ReqId, Target, Req, Resp, E, Ctx>
48where
49 ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
50 Target: Clone + Hash + PartialEq + Eq + Send + 'static,
51 Req: PartialEq + Debug + Send + 'static,
52 Resp: PartialEq + Debug + Send + 'static,
53 E: PartialEq + Debug + Send + 'static,
54 Ctx: Clone + PartialEq + Debug + Send + 'static,
55{
56 pub fn new() -> Wire<ReqId, Target, Req, Resp, E, Ctx> {
58 Wire{
59 connectors: Arc::new(Mutex::new(HashMap::new())),
60 requests: Arc::new(Mutex::new(HashMap::new())),
61
62 _e: PhantomData,
63 _ctx: PhantomData,
64 }
65 }
66
67 pub fn connector(&mut self, target: Target) -> WireMux<ReqId, Target, Req, Resp, E, Ctx> {
69 let w = WireMux::new(self.clone(), target.clone());
70
71 self.connectors.lock().unwrap().insert(target, w.clone());
72
73 w
74 }
75
76 async fn request(&mut self, _ctx: Ctx, to: Target, from: Target, id: ReqId, req: Req) -> Result<Resp, ()> {
77 let mut conn = {
79 let c = self.connectors.lock().unwrap();
80 c.get(&to.clone()).unwrap().clone()
81 };
82
83 let (tx, rx) = oneshot::channel();
85 self.requests.lock().unwrap().insert((to, from.clone(), id.clone()), tx);
86
87 conn.send(from, id, req).await.unwrap();
89
90 let res = rx.await.unwrap();
92
93 Ok(res)
94 }
95
96 async fn respond(&mut self, _ctx: Ctx, to: Target, from: Target, id: ReqId, resp: Resp) -> Result<(), E> {
97 let pending = self.requests.lock().unwrap().remove(&(from, to, id)).unwrap();
98
99 pending.send(resp).unwrap();
100
101 Ok(())
102 }
103}
104
105pub struct WireMux<ReqId, Target, Req, Resp, E, Ctx> {
106 addr: Target,
107
108 connector: Wire<ReqId, Target, Req, Resp, E, Ctx>,
109
110 receiver_tx: Arc<Mutex<mpsc::Sender<(Target, ReqId, Req)>>>,
111 receiver_rx: Arc<Mutex<mpsc::Receiver<(Target, ReqId, Req)>>>,
112
113 _e: PhantomData<E>,
114 _ctx: PhantomData<Ctx>,
115}
116
117impl <ReqId, Target, Req, Resp, E, Ctx> WireMux<ReqId, Target, Req, Resp, E, Ctx>
118where
119 ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
120 Target: Clone + Hash + PartialEq + Eq + Send + 'static,
121 Req: PartialEq + Debug + Send + 'static,
122 Resp: PartialEq + Debug + Send + 'static,
123 E: PartialEq + Debug + Send + 'static,
124 Ctx: Clone + PartialEq + Debug + Send + 'static,
125{
126 fn new(connector: Wire<ReqId, Target, Req, Resp, E, Ctx>, addr: Target) -> WireMux<ReqId, Target, Req, Resp, E, Ctx> {
127 let (tx, rx) = mpsc::channel(0);
128
129 WireMux{
130 addr,
131 connector,
132
133 receiver_rx: Arc::new(Mutex::new(rx)),
134 receiver_tx: Arc::new(Mutex::new(tx)),
135
136 _e: PhantomData,
137 _ctx: PhantomData,
138 }
139 }
140
141 async fn send(&mut self, from: Target, id: ReqId, req: Req) -> Result<(), E> {
142 let mut tx = self.receiver_tx.lock().unwrap().clone();
143
144 match tx.send((from, id, req)).await {
145 Ok(_) => (),
146 Err(e) => panic!(e),
147 };
148
149 Ok(())
150 }
151}
152
153
154impl <ReqId, Target, Req, Resp, E, Ctx> Clone for WireMux<ReqId, Target, Req, Resp, E, Ctx>
155where
156 ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
157 Target: Clone + Hash + PartialEq + Eq + Send + 'static,
158 Req: PartialEq + Debug + Send + 'static,
159 Resp: PartialEq + Debug + Send + 'static,
160 E: PartialEq + Debug + Send + 'static,
161 Ctx: Clone + PartialEq + Debug + Send + 'static,
162{
163 fn clone(&self) -> Self {
164 WireMux{
165 addr: self.addr.clone(),
166 connector: self.connector.clone(),
167
168 receiver_rx: self.receiver_rx.clone(),
169 receiver_tx: self.receiver_tx.clone(),
170
171 _e: PhantomData,
172 _ctx: PhantomData,
173 }
174 }
175}
176
177#[async_trait]
178impl <ReqId, Target, Req, Resp, E, Ctx> Connector<ReqId, Target, Req, Resp, E, Ctx> for WireMux <ReqId, Target, Req, Resp, E, Ctx>
179where
180 ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
181 Target: Clone + Hash + PartialEq + Eq + Send + 'static,
182 Req: PartialEq + Debug + Send + 'static,
183 Resp: PartialEq + Debug + Send + 'static,
184 E: PartialEq + Debug + Send + 'static,
185 Ctx: Clone + PartialEq + Debug + Send + 'static,
186{
187 async fn request(
189 &mut self, ctx: Ctx, req_id: ReqId, target: Target, req: Req,
190 ) -> Result<Resp, E> {
191 let addr = self.addr.clone();
192
193 let res = match self.connector.request(ctx, target, addr, req_id, req).await {
195 Ok(r) => r,
196 Err(e) => panic!(e),
197 };
198
199 Ok(res)
200 }
201
202 async fn respond(
204 &mut self, ctx: Ctx, req_id: ReqId, target: Target, resp: Resp,
205 ) -> Result<(), E> {
206 let mut conn = self.connector.clone();
207 let addr = self.addr.clone();
208
209 match conn.respond(ctx, target, addr, req_id, resp).await {
210 Ok(_) => (),
211 Err(e) => panic!(e),
212 };
213
214 Ok(())
215 }
216}
217
218impl <ReqId, Target, Req, Resp, E, Ctx> Stream for WireMux <ReqId, Target, Req, Resp, E, Ctx>
219where
220 ReqId: Hash + Eq + PartialEq + Debug + Send + 'static,
221 Target: Hash + PartialEq + Eq + Send + 'static,
222 Req: PartialEq + Debug + Send + 'static,
223 Resp: PartialEq + Debug + Send + 'static,
224 E: PartialEq + Debug + Send + 'static,
225 Ctx: Clone + PartialEq + Debug + Send + 'static,
226{
227 type Item = (Target, ReqId, Req);
228
229 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
231 let rx = self.receiver_rx.clone();
232 let mut rx = rx.lock().unwrap();
233 rx.poll_next_unpin(cx)
234 }
235}
236
#[cfg(test)]
mod tests {
    use futures::executor::block_on;
    use futures::prelude::*;

    use super::*;

    /// Round-trips one request between two endpoints on the same wire: the
    /// endpoint at 0x22 answers every request with `value + 10`.
    #[test]
    fn test_wiring() {
        let mut wire: Wire<u16, u64, u32, u32, (), ()> = Wire::new();

        let mut requester = wire.connector(0x11);
        let mut responder = wire.connector(0x22);

        let client = async move {
            let reply = requester.request((), 1, 0x22, 40).await.unwrap();
            assert_eq!(reply, 50);
        }
        .boxed();

        let server = async move {
            while let Some((from, id, val)) = responder.next().await {
                responder.respond((), id, from, val + 10).await.unwrap();
            }
        }
        .boxed();

        // Run both sides until whichever finishes first completes.
        let _ = block_on(future::select(client, server));
    }
}