1use std::collections::HashMap;
8
9use tokio::task::JoinHandle;
10
11use crate::transport::{TransportAddr, TransportError};
12
13use super::addr::BleAddr;
14
15pub struct BleConnection<S> {
17 pub stream: S,
19 pub recv_task: Option<JoinHandle<()>>,
21 pub send_mtu: u16,
23 pub recv_mtu: u16,
25 pub established_at: tokio::time::Instant,
27 pub is_static: bool,
29 pub addr: BleAddr,
31}
32
33impl<S> BleConnection<S> {
34 pub fn effective_mtu(&self) -> u16 {
36 self.send_mtu.min(self.recv_mtu)
37 }
38}
39
40impl<S> Drop for BleConnection<S> {
41 fn drop(&mut self) {
42 if let Some(task) = self.recv_task.take() {
43 task.abort();
44 }
45 }
46}
47
48pub struct ConnectionPool<S> {
50 connections: HashMap<TransportAddr, BleConnection<S>>,
51 max_connections: usize,
52}
53
54impl<S> ConnectionPool<S> {
55 pub fn new(max_connections: usize) -> Self {
57 Self {
58 connections: HashMap::new(),
59 max_connections,
60 }
61 }
62
63 pub fn len(&self) -> usize {
65 self.connections.len()
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.connections.is_empty()
71 }
72
73 pub fn is_full(&self) -> bool {
75 self.connections.len() >= self.max_connections
76 }
77
78 pub fn max_connections(&self) -> usize {
80 self.max_connections
81 }
82
83 pub fn get(&self, addr: &TransportAddr) -> Option<&BleConnection<S>> {
85 self.connections.get(addr)
86 }
87
88 pub fn get_mut(&mut self, addr: &TransportAddr) -> Option<&mut BleConnection<S>> {
90 self.connections.get_mut(addr)
91 }
92
93 pub fn contains(&self, addr: &TransportAddr) -> bool {
95 self.connections.contains_key(addr)
96 }
97
98 pub fn insert(
103 &mut self,
104 addr: TransportAddr,
105 conn: BleConnection<S>,
106 ) -> Result<Option<TransportAddr>, TransportError> {
107 use std::collections::hash_map::Entry;
108
109 if let Entry::Occupied(mut e) = self.connections.entry(addr.clone()) {
111 e.insert(conn);
112 return Ok(None);
113 }
114
115 if !self.is_full() {
117 self.connections.insert(addr, conn);
118 return Ok(None);
119 }
120
121 let evicted = self.find_eviction_candidate(conn.is_static)?;
123 self.connections.remove(&evicted);
124 self.connections.insert(addr, conn);
125 Ok(Some(evicted))
126 }
127
128 pub fn remove(&mut self, addr: &TransportAddr) -> Option<BleConnection<S>> {
130 self.connections.remove(addr)
131 }
132
133 pub fn addrs(&self) -> Vec<TransportAddr> {
135 self.connections.keys().cloned().collect()
136 }
137
138 fn find_eviction_candidate(
143 &self,
144 new_is_static: bool,
145 ) -> Result<TransportAddr, TransportError> {
146 if new_is_static {
147 self.connections
149 .iter()
150 .filter(|(_, c)| !c.is_static)
151 .min_by_key(|(_, c)| c.established_at)
152 .map(|(addr, _)| addr.clone())
153 .ok_or_else(|| {
154 TransportError::NotSupported("BLE pool full: all connections are static".into())
155 })
156 } else {
157 self.connections
159 .iter()
160 .filter(|(_, c)| !c.is_static)
161 .min_by_key(|(_, c)| c.established_at)
162 .map(|(addr, _)| addr.clone())
163 .ok_or_else(|| {
164 TransportError::NotSupported("BLE pool full: all connections are static".into())
165 })
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    fn test_addr(n: u8) -> TransportAddr {
        TransportAddr::from_string(&format!("hci0/AA:BB:CC:DD:EE:{n:02X}"))
    }

    fn test_ble_addr(n: u8) -> BleAddr {
        BleAddr {
            adapter: "hci0".to_string(),
            device: [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, n],
        }
    }

    fn test_conn(n: u8, is_static: bool) -> BleConnection<()> {
        BleConnection {
            stream: (),
            recv_task: None,
            send_mtu: 2048,
            recv_mtu: 2048,
            established_at: tokio::time::Instant::now(),
            is_static,
            addr: test_ble_addr(n),
        }
    }

    /// Like `test_conn`, but pins `established_at` to `base + offset_secs`.
    /// Back-to-back `Instant::now()` calls may compare equal, which would make
    /// the eviction order depend on hash map iteration order.
    fn test_conn_at(
        n: u8,
        is_static: bool,
        base: tokio::time::Instant,
        offset_secs: u64,
    ) -> BleConnection<()> {
        let mut conn = test_conn(n, is_static);
        conn.established_at = base + Duration::from_secs(offset_secs);
        conn
    }

    #[test]
    fn test_pool_basic_insert() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(7);
        assert!(pool.is_empty());

        let evicted = pool.insert(test_addr(1), test_conn(1, false)).unwrap();
        assert!(evicted.is_none());
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());
        assert!(pool.contains(&test_addr(1)));
    }

    #[test]
    fn test_pool_remove() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(7);
        pool.insert(test_addr(1), test_conn(1, false)).unwrap();
        assert!(pool.remove(&test_addr(1)).is_some());
        assert!(pool.is_empty());
        assert!(pool.remove(&test_addr(1)).is_none());
    }

    #[test]
    fn test_pool_full_eviction() {
        let base = tokio::time::Instant::now();
        let mut pool: ConnectionPool<()> = ConnectionPool::new(3);
        for n in 1..=3u8 {
            pool.insert(test_addr(n), test_conn_at(n, false, base, u64::from(n)))
                .unwrap();
        }
        assert!(pool.is_full());

        let evicted = pool
            .insert(test_addr(4), test_conn_at(4, false, base, 4))
            .unwrap();
        // The oldest non-static entry is the one evicted.
        assert!(evicted == Some(test_addr(1)));
        assert_eq!(pool.len(), 3);
        assert!(!pool.contains(&test_addr(1)));
        assert!(pool.contains(&test_addr(4)));
    }

    #[test]
    fn test_pool_static_evicts_nonstatic() {
        let base = tokio::time::Instant::now();
        let mut pool: ConnectionPool<()> = ConnectionPool::new(2);
        pool.insert(test_addr(1), test_conn_at(1, false, base, 1)).unwrap();
        pool.insert(test_addr(2), test_conn_at(2, false, base, 2)).unwrap();

        let evicted = pool
            .insert(test_addr(3), test_conn_at(3, true, base, 3))
            .unwrap();
        assert!(evicted == Some(test_addr(1)));
        assert_eq!(pool.len(), 2);
        assert!(pool.contains(&test_addr(2)));
        assert!(pool.contains(&test_addr(3)));
    }

    #[test]
    fn test_pool_static_never_evicted_even_if_oldest() {
        let base = tokio::time::Instant::now();
        let mut pool: ConnectionPool<()> = ConnectionPool::new(2);
        pool.insert(test_addr(1), test_conn_at(1, true, base, 0)).unwrap();
        pool.insert(test_addr(2), test_conn_at(2, false, base, 5)).unwrap();

        let evicted = pool
            .insert(test_addr(3), test_conn_at(3, false, base, 6))
            .unwrap();
        assert!(evicted == Some(test_addr(2)));
        assert!(pool.contains(&test_addr(1)));
        assert!(pool.contains(&test_addr(3)));
    }

    #[test]
    fn test_pool_all_static_rejects() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(2);
        pool.insert(test_addr(1), test_conn(1, true)).unwrap();
        pool.insert(test_addr(2), test_conn(2, true)).unwrap();

        assert!(pool.insert(test_addr(3), test_conn(3, false)).is_err());
        // A static newcomer cannot displace static entries either.
        assert!(pool.insert(test_addr(4), test_conn(4, true)).is_err());
        // Rejected inserts leave the pool unchanged.
        assert_eq!(pool.len(), 2);
        assert!(!pool.contains(&test_addr(3)));
        assert!(!pool.contains(&test_addr(4)));
    }

    #[test]
    fn test_pool_zero_capacity_rejects() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(0);
        assert!(pool.insert(test_addr(1), test_conn(1, false)).is_err());
        assert!(pool.is_empty());
    }

    #[test]
    fn test_pool_replace_existing() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(2);
        pool.insert(test_addr(1), test_conn(1, false)).unwrap();

        let evicted = pool.insert(test_addr(1), test_conn(1, true)).unwrap();
        assert!(evicted.is_none());
        assert_eq!(pool.len(), 1);
        assert!(pool.get(&test_addr(1)).unwrap().is_static);
    }

    #[test]
    fn test_pool_replace_existing_when_full() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(1);
        pool.insert(test_addr(1), test_conn(1, true)).unwrap();
        assert!(pool.is_full());

        // Replacing an existing address never needs eviction.
        let evicted = pool.insert(test_addr(1), test_conn(1, false)).unwrap();
        assert!(evicted.is_none());
        assert_eq!(pool.len(), 1);
        assert!(!pool.get(&test_addr(1)).unwrap().is_static);
    }

    #[test]
    fn test_pool_effective_mtu() {
        let mut conn = test_conn(1, false);
        conn.send_mtu = 1024;
        conn.recv_mtu = 2048;
        assert_eq!(conn.effective_mtu(), 1024);

        conn.send_mtu = 4096;
        assert_eq!(conn.effective_mtu(), 2048);
    }

    #[test]
    fn test_pool_addrs() {
        let mut pool: ConnectionPool<()> = ConnectionPool::new(7);
        pool.insert(test_addr(1), test_conn(1, false)).unwrap();
        pool.insert(test_addr(2), test_conn(2, false)).unwrap();

        let addrs = pool.addrs();
        assert_eq!(addrs.len(), 2);
        assert!(addrs.contains(&test_addr(1)));
        assert!(addrs.contains(&test_addr(2)));
    }
}