rr_mux/
wire.rs

1
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::hash::Hash;
5use std::sync::{Arc, Mutex};
6use std::clone::Clone;
7use std::fmt::Debug;
8use std::pin::Pin;
9
10use futures::prelude::*;
11use futures::channel::{mpsc, oneshot};
12use futures::task::{Context, Poll};
13use async_trait::async_trait;
14
15use crate::connector::Connector;
16
17/// Wire provides an interconnect to support integration testing of Mux based implementations
18pub struct Wire <ReqId, Target, Req, Resp, E, Ctx> {
19    connectors: Arc<Mutex<HashMap<Target, WireMux<ReqId, Target, Req, Resp, E, Ctx>>>>,
20
21    requests: Arc<Mutex<HashMap<(Target, Target, ReqId), oneshot::Sender<Resp>>>>,
22
23    _e: PhantomData<E>, 
24    _ctx: PhantomData<Ctx>,
25}
26
27impl <ReqId, Target, Req, Resp, E, Ctx> Clone for Wire<ReqId, Target, Req, Resp, E, Ctx>
28where
29    ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
30    Target: Clone + Hash + PartialEq + Eq + Send + 'static,
31    Req: PartialEq + Debug + Send + 'static,
32    Resp: PartialEq + Debug + Send + 'static,
33    E: PartialEq + Debug + Send + 'static,
34    Ctx: Clone + PartialEq + Debug + Send + 'static,
35{
36    fn clone(&self) -> Self {
37        Wire {
38            connectors: self.connectors.clone(),
39            requests: self.requests.clone(),
40
41            _e: PhantomData,
42            _ctx: PhantomData,
43        }
44    }
45}
46
47impl <ReqId, Target, Req, Resp, E, Ctx> Wire<ReqId, Target, Req, Resp, E, Ctx> 
48where
49    ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
50    Target: Clone + Hash + PartialEq + Eq + Send + 'static,
51    Req: PartialEq + Debug + Send + 'static,
52    Resp: PartialEq + Debug + Send + 'static,
53    E: PartialEq + Debug + Send + 'static,
54    Ctx: Clone + PartialEq + Debug + Send + 'static,
55{
56    /// Create a new Wire interconnect
57    pub fn new() -> Wire<ReqId, Target, Req, Resp, E, Ctx>  {
58        Wire{
59            connectors: Arc::new(Mutex::new(HashMap::new())),
60            requests: Arc::new(Mutex::new(HashMap::new())),
61
62            _e: PhantomData,
63            _ctx: PhantomData,
64        }
65    }
66
67    /// Create a new connector for the provided target address
68    pub fn connector(&mut self, target: Target) -> WireMux<ReqId, Target, Req, Resp, E, Ctx> {
69        let w = WireMux::new(self.clone(), target.clone());
70
71        self.connectors.lock().unwrap().insert(target, w.clone());
72
73        w
74    }
75
76    async fn request(&mut self, _ctx: Ctx, to: Target, from: Target, id: ReqId, req: Req) -> Result<Resp, ()> {
77        // Fetch matching connector
78        let mut conn = {
79            let c = self.connectors.lock().unwrap();
80            c.get(&to.clone()).unwrap().clone()
81        };
82
83        // Bind response channel
84        let (tx, rx) = oneshot::channel();
85        self.requests.lock().unwrap().insert((to, from.clone(), id.clone()), tx);
86
87        // Forward request
88        conn.send(from, id, req).await.unwrap();
89
90        // Await response
91        let res = rx.await.unwrap();
92
93        Ok(res)
94    }
95
96    async fn respond(&mut self, _ctx: Ctx, to: Target, from: Target, id: ReqId, resp: Resp) -> Result<(), E> {
97        let pending = self.requests.lock().unwrap().remove(&(from, to, id)).unwrap();
98        
99        pending.send(resp).unwrap();
100        
101        Ok(())
102    }
103}
104
105pub struct WireMux<ReqId, Target, Req, Resp, E, Ctx> {
106    addr: Target,
107
108    connector: Wire<ReqId, Target, Req, Resp, E, Ctx>,
109
110    receiver_tx: Arc<Mutex<mpsc::Sender<(Target, ReqId, Req)>>>,
111    receiver_rx: Arc<Mutex<mpsc::Receiver<(Target, ReqId, Req)>>>,
112
113    _e: PhantomData<E>, 
114    _ctx: PhantomData<Ctx>,
115}
116
117impl <ReqId, Target, Req, Resp, E, Ctx> WireMux<ReqId, Target, Req, Resp, E, Ctx>
118where
119    ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
120    Target: Clone + Hash + PartialEq + Eq + Send + 'static,
121    Req: PartialEq + Debug + Send + 'static,
122    Resp: PartialEq + Debug + Send + 'static,
123    E: PartialEq + Debug + Send + 'static,
124    Ctx: Clone + PartialEq + Debug + Send + 'static,
125{
126    fn new(connector: Wire<ReqId, Target, Req, Resp, E, Ctx>, addr: Target) -> WireMux<ReqId, Target, Req, Resp, E, Ctx> {
127        let (tx, rx) = mpsc::channel(0);
128
129        WireMux{
130            addr,
131            connector,
132
133            receiver_rx: Arc::new(Mutex::new(rx)),
134            receiver_tx: Arc::new(Mutex::new(tx)),
135
136            _e: PhantomData,
137            _ctx: PhantomData,
138        }
139    }
140
141    async fn send(&mut self, from: Target, id: ReqId, req: Req) -> Result<(), E> {
142        let mut tx = self.receiver_tx.lock().unwrap().clone();
143        
144        match tx.send((from, id, req)).await {
145            Ok(_) => (),
146            Err(e) => panic!(e),
147        };
148
149        Ok(())
150    }
151}
152
153
154impl <ReqId, Target, Req, Resp, E, Ctx> Clone for WireMux<ReqId, Target, Req, Resp, E, Ctx> 
155where
156    ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
157    Target: Clone + Hash + PartialEq + Eq + Send + 'static,
158    Req: PartialEq + Debug + Send + 'static,
159    Resp: PartialEq + Debug + Send + 'static,
160    E: PartialEq + Debug + Send + 'static,
161    Ctx: Clone + PartialEq + Debug + Send + 'static,
162{
163    fn clone(&self) -> Self {
164        WireMux{
165            addr: self.addr.clone(),
166            connector: self.connector.clone(),
167
168            receiver_rx: self.receiver_rx.clone(),
169            receiver_tx: self.receiver_tx.clone(),
170
171            _e: PhantomData,
172            _ctx: PhantomData,
173        }
174    }
175}
176
177#[async_trait]
178impl <ReqId, Target, Req, Resp, E, Ctx> Connector<ReqId, Target, Req, Resp, E, Ctx> for WireMux <ReqId, Target, Req, Resp, E, Ctx> 
179where
180    ReqId: Clone + Hash + Eq + PartialEq + Debug + Send + 'static,
181    Target: Clone + Hash + PartialEq + Eq + Send + 'static,
182    Req: PartialEq + Debug + Send + 'static,
183    Resp: PartialEq + Debug + Send + 'static,
184    E: PartialEq + Debug + Send + 'static,
185    Ctx: Clone + PartialEq + Debug + Send + 'static,
186{
187    // Send a request and receive a response or error at some time in the future
188    async fn request(
189        &mut self, ctx: Ctx, req_id: ReqId, target: Target, req: Req,
190    ) -> Result<Resp, E> {
191        let addr = self.addr.clone();
192
193        // Send to connector and await response
194        let res = match self.connector.request(ctx, target, addr, req_id, req).await {
195            Ok(r) => r,
196            Err(e) => panic!(e),
197        };
198
199        Ok(res)
200    }
201
202    // Respond to a received request
203    async fn respond(
204        &mut self, ctx: Ctx, req_id: ReqId, target: Target, resp: Resp,
205    ) -> Result<(), E> {
206        let mut conn = self.connector.clone();
207        let addr = self.addr.clone();
208
209        match conn.respond(ctx, target, addr, req_id, resp).await {
210            Ok(_) => (),
211            Err(e) => panic!(e),
212        };
213
214        Ok(())
215    }
216}
217
218impl <ReqId, Target, Req, Resp, E, Ctx> Stream for WireMux <ReqId, Target, Req, Resp, E, Ctx> 
219where
220    ReqId: Hash + Eq + PartialEq + Debug + Send + 'static,
221    Target: Hash + PartialEq + Eq + Send + 'static,
222    Req: PartialEq + Debug + Send + 'static,
223    Resp: PartialEq + Debug + Send + 'static,
224    E: PartialEq + Debug + Send + 'static,
225    Ctx: Clone + PartialEq + Debug + Send + 'static,
226{
227    type Item = (Target, ReqId, Req);
228
229    // Poll to receive pending requests
230    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
231        let rx = self.receiver_rx.clone();
232        let mut rx = rx.lock().unwrap();
233        rx.poll_next_unpin(cx)
234    }
235}
236
#[cfg(test)]
mod tests {

    use futures::executor::block_on;
    use futures::prelude::*;

    use super::*;

    /// A request from one connector is answered by another through the wire.
    #[test]
    fn test_wiring() {
        let mut i: Wire<u16, u64, u32, u32, (), ()> = Wire::new();

        let mut c1 = i.connector(0x11);
        let mut c2 = i.connector(0x22);

        // c1 makes requests (and checks the responses)
        let a = async move {
            let resp = c1.request((), 1, 0x22, 40).await.unwrap();
            assert_eq!(resp, 50);

            // Reusing a request id after the first exchange completed must work
            let resp = c1.request((), 1, 0x22, 7).await.unwrap();
            assert_eq!(resp, 17);
        }
        .boxed();

        // c2 answers every incoming request with `val + 10`
        let b = async move {
            while let Some((from, id, val)) = c2.next().await {
                c2.respond((), id, from, val + 10).await.unwrap();
            }
        }
        .boxed();

        // Run using select: `a` finishes, `b` polls forever. Require that `a`
        // is the one that completed; otherwise its assertions never ran.
        match block_on(future::select(a, b)) {
            future::Either::Left(_) => (),
            future::Either::Right(_) => {
                panic!("responder stream ended before the request completed")
            }
        }
    }

    /// Requests to an address with no registered connector are a setup bug.
    #[test]
    #[should_panic]
    fn test_request_unknown_target() {
        let mut i: Wire<u16, u64, u32, u32, (), ()> = Wire::new();
        let mut c1 = i.connector(0x11);

        let _ = block_on(c1.request((), 1, 0x33, 40));
    }
}