1use crate::{
2 DiscoveryKey, Key, Message, discovery_key,
3 message::ChannelMessage,
4 schema::*,
5 util::{map_channel_err, pretty_hash},
6};
7use async_channel::{Receiver, Sender, TrySendError};
8use futures_lite::{ready, stream::Stream};
9use std::{
10 collections::HashMap,
11 fmt,
12 io::{Error, ErrorKind, Result},
13 pin::Pin,
14 sync::{
15 Arc,
16 atomic::{AtomicBool, Ordering},
17 },
18 task::Poll,
19};
20use tracing::instrument;
21
22#[derive(Clone)]
26pub struct Channel {
27 inbound_rx: Option<Receiver<Message>>,
28 direct_inbound_tx: Sender<Message>,
29 outbound_tx: Sender<Vec<ChannelMessage>>,
30 key: Key,
31 discovery_key: DiscoveryKey,
32 local_id: usize,
33 closed: Arc<AtomicBool>,
34}
35
36impl PartialEq for Channel {
37 fn eq(&self, other: &Self) -> bool {
38 self.key == other.key
39 && self.discovery_key == other.discovery_key
40 && self.local_id == other.local_id
41 }
42}
43
44impl fmt::Debug for Channel {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.debug_struct("Channel")
47 .field("discovery_key", &pretty_hash(&self.discovery_key))
48 .finish()
49 }
50}
51
52impl Channel {
53 fn new(
54 inbound_rx: Option<Receiver<Message>>,
55 direct_inbound_tx: Sender<Message>,
56 outbound_tx: Sender<Vec<ChannelMessage>>,
57 discovery_key: DiscoveryKey,
58 key: Key,
59 local_id: usize,
60 closed: Arc<AtomicBool>,
61 ) -> Self {
62 Self {
63 inbound_rx,
64 direct_inbound_tx,
65 outbound_tx,
66 key,
67 discovery_key,
68 local_id,
69 closed,
70 }
71 }
72 pub fn discovery_key(&self) -> &[u8; 32] {
74 &self.discovery_key
75 }
76
77 pub fn key(&self) -> &[u8; 32] {
79 &self.key
80 }
81
82 pub fn id(&self) -> usize {
84 self.local_id
85 }
86
87 pub fn closed(&self) -> bool {
89 self.closed.load(Ordering::SeqCst)
90 }
91
92 pub async fn send(&self, message: Message) -> Result<()> {
94 if self.closed() {
95 return Err(Error::new(
96 ErrorKind::ConnectionAborted,
97 "Channel is closed",
98 ));
99 }
100 let message = ChannelMessage::new(self.local_id as u64, message);
101 self.outbound_tx
102 .send(vec![message])
103 .await
104 .map_err(map_channel_err)
105 }
106
107 pub fn send_batch(&self, messages: &[Message]) -> impl Future<Output = Result<()>> + use<> {
109 let closed = self.closed();
121
122 let messages = if !closed {
124 messages
125 .iter()
126 .map(|message| ChannelMessage::new(self.local_id as u64, message.clone()))
127 .collect()
128 } else {
129 vec![]
130 };
131
132 let outbound_tx = self.outbound_tx.clone();
133 async move {
134 if closed {
135 return Err(Error::new(
136 ErrorKind::ConnectionAborted,
137 "Channel is closed",
138 ));
139 }
140
141 outbound_tx.send(messages).await.map_err(map_channel_err)
142 }
143 }
144
145 pub fn take_receiver(&mut self) -> Option<Receiver<Message>> {
150 self.inbound_rx.take()
151 }
152
153 pub fn local_sender(&self) -> Sender<Message> {
159 self.direct_inbound_tx.clone()
160 }
161
162 pub async fn close(&self) -> Result<()> {
164 if self.closed() {
165 return Ok(());
166 }
167 let close = Close {
168 channel: self.local_id as u64,
169 };
170 self.send(Message::Close(close)).await?;
171 self.closed.store(true, Ordering::SeqCst);
172 Ok(())
173 }
174
175 pub async fn signal_local_protocol(&self, name: &str, data: Vec<u8>) -> Result<()> {
178 self.send(Message::LocalSignal((name.to_string(), data)))
179 .await?;
180 Ok(())
181 }
182}
183
184impl Stream for Channel {
185 type Item = Message;
186 fn poll_next(
187 self: Pin<&mut Self>,
188 cx: &mut std::task::Context<'_>,
189 ) -> std::task::Poll<Option<Self::Item>> {
190 let this = self.get_mut();
191 match this.inbound_rx.as_mut() {
192 None => Poll::Ready(None),
193 Some(ref mut inbound_rx) => {
194 let message = ready!(Pin::new(inbound_rx).poll_next(cx));
195 Poll::Ready(message)
196 }
197 }
198 }
199}
200
201#[derive(Clone, Debug)]
203pub(crate) struct ChannelHandle {
204 discovery_key: DiscoveryKey,
205 local_state: Option<LocalState>,
206 remote_state: Option<RemoteState>,
207 inbound_tx: Option<Sender<Message>>,
208 closed: Arc<AtomicBool>,
209}
210
211#[derive(Clone, Debug)]
212struct LocalState {
213 key: Key,
214 local_id: usize,
215}
216
/// Remote side of a channel: the id the peer uses and the capability it sent, if any.
#[derive(Debug, Clone)]
struct RemoteState {
    /// Channel id chosen by the remote peer.
    remote_id: usize,
    /// Capability bytes received from the peer, if it sent one.
    remote_capability: Option<Vec<u8>>,
}
222
223impl ChannelHandle {
224 fn new(discovery_key: DiscoveryKey) -> Self {
225 Self {
226 discovery_key,
227 local_state: None,
228 remote_state: None,
229 inbound_tx: None,
230 closed: Arc::new(AtomicBool::new(false)),
231 }
232 }
233 fn new_local(local_id: usize, discovery_key: DiscoveryKey, key: Key) -> Self {
234 let mut this = Self::new(discovery_key);
235 this.attach_local(local_id, key);
236 this
237 }
238
239 fn new_remote(
240 remote_id: usize,
241 discovery_key: DiscoveryKey,
242 remote_capability: Option<Vec<u8>>,
243 ) -> Self {
244 let mut this = Self::new(discovery_key);
245 this.attach_remote(remote_id, remote_capability);
246 this
247 }
248
249 pub(crate) fn discovery_key(&self) -> &[u8; 32] {
250 &self.discovery_key
251 }
252
253 pub(crate) fn local_id(&self) -> Option<usize> {
254 self.local_state.as_ref().map(|s| s.local_id)
255 }
256
257 pub(crate) fn remote_id(&self) -> Option<usize> {
258 self.remote_state.as_ref().map(|s| s.remote_id)
259 }
260
261 #[instrument(skip_all, fields(local_id = local_id))]
262 pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) {
263 let local_state = LocalState { local_id, key };
264 self.local_state = Some(local_state);
265 }
266
267 pub(crate) fn attach_remote(&mut self, remote_id: usize, remote_capability: Option<Vec<u8>>) {
268 let remote_state = RemoteState {
269 remote_id,
270 remote_capability,
271 };
272 self.remote_state = Some(remote_state);
273 }
274
275 pub(crate) fn is_connected(&self) -> bool {
276 self.local_state.is_some() && self.remote_state.is_some()
277 }
278
279 pub(crate) fn prepare_to_verify(&self) -> Result<(&Key, Option<&Vec<u8>>)> {
280 if !self.is_connected() {
281 return Err(error("Channel is not opened from both local and remote"));
282 }
283 let local_state = self.local_state.as_ref().unwrap();
285 let remote_state = self.remote_state.as_ref().unwrap();
286 Ok((&local_state.key, remote_state.remote_capability.as_ref()))
287 }
288
289 #[instrument(skip_all)]
290 pub(crate) fn open(&mut self, outbound_tx: Sender<Vec<ChannelMessage>>) -> Channel {
291 let local_state = self
292 .local_state
293 .as_ref()
294 .expect("May not open channel that is not locally attached");
295
296 let (inbound_tx, inbound_rx) = async_channel::unbounded();
297 let channel = Channel::new(
298 Some(inbound_rx),
299 inbound_tx.clone(),
300 outbound_tx,
301 self.discovery_key,
302 local_state.key,
303 local_state.local_id,
304 self.closed.clone(),
305 );
306
307 self.inbound_tx = Some(inbound_tx);
308 channel
309 }
310
311 pub(crate) fn try_send_inbound(&mut self, message: Message) -> std::io::Result<()> {
312 if let Some(inbound_tx) = self.inbound_tx.as_mut() {
313 inbound_tx
314 .try_send(message)
315 .map_err(|e| error(format!("Sending to channel failed: {e}").as_str()))
316 } else {
317 Err(error("Channel is not open"))
318 }
319 }
320
321 pub(crate) fn try_send_inbound_tolerate_closed(
322 &mut self,
323 message: Message,
324 ) -> std::io::Result<()> {
325 if let Some(inbound_tx) = self.inbound_tx.as_mut()
326 && let Err(err) = inbound_tx.try_send(message)
327 {
328 match err {
329 TrySendError::Full(e) => {
330 return Err(error(format!("Sending to channel failed: {e}").as_str()));
331 }
332 TrySendError::Closed(_) => {}
333 }
334 }
335 Ok(())
336 }
337}
338
339impl Drop for ChannelHandle {
340 fn drop(&mut self) {
341 self.closed.store(true, Ordering::SeqCst);
342 }
343}
344
345#[derive(Debug)]
347pub(crate) struct ChannelMap {
348 channels: HashMap<String, ChannelHandle>,
349 local_id: Vec<Option<String>>,
350 remote_id: Vec<Option<String>>,
351}
352
353impl ChannelMap {
354 pub(crate) fn new() -> Self {
355 Self {
356 channels: HashMap::new(),
357 local_id: vec![None],
360 remote_id: vec![],
361 }
362 }
363
364 pub(crate) fn attach_local(&mut self, key: Key) -> &ChannelHandle {
365 let discovery_key = discovery_key(&key);
366 let hdkey = hex::encode(discovery_key);
367 let local_id = self.alloc_local();
368
369 self.channels
370 .entry(hdkey.clone())
371 .and_modify(|channel| channel.attach_local(local_id, key))
372 .or_insert_with(|| ChannelHandle::new_local(local_id, discovery_key, key));
373
374 self.local_id[local_id] = Some(hdkey.clone());
375 self.channels.get(&hdkey).unwrap()
376 }
377
378 pub(crate) fn attach_remote(
379 &mut self,
380 discovery_key: DiscoveryKey,
381 remote_id: usize,
382 remote_capability: Option<Vec<u8>>,
383 ) -> &ChannelHandle {
384 let hdkey = hex::encode(discovery_key);
385 self.alloc_remote(remote_id);
386 self.channels
387 .entry(hdkey.clone())
388 .and_modify(|channel| channel.attach_remote(remote_id, remote_capability.clone()))
389 .or_insert_with(|| {
390 ChannelHandle::new_remote(remote_id, discovery_key, remote_capability)
391 });
392 self.remote_id[remote_id] = Some(hdkey.clone());
393 self.channels.get(&hdkey).unwrap()
394 }
395
396 pub(crate) fn get_remote_mut(&mut self, remote_id: usize) -> Option<&mut ChannelHandle> {
397 if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
398 self.channels.get_mut(hdkey)
399 } else {
400 None
401 }
402 }
403
404 pub(crate) fn get_remote(&self, remote_id: usize) -> Option<&ChannelHandle> {
405 if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
406 self.channels.get(hdkey)
407 } else {
408 None
409 }
410 }
411
412 pub(crate) fn get_local_mut(&mut self, local_id: usize) -> Option<&mut ChannelHandle> {
413 if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
414 self.channels.get_mut(hdkey)
415 } else {
416 None
417 }
418 }
419
420 pub(crate) fn get_local(&self, local_id: usize) -> Option<&ChannelHandle> {
421 if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
422 self.channels.get(hdkey)
423 } else {
424 None
425 }
426 }
427
428 pub(crate) fn has_channel(&self, discovery_key: &[u8]) -> bool {
429 let hdkey = hex::encode(discovery_key);
430 self.channels.contains_key(&hdkey)
431 }
432
433 pub(crate) fn remove(&mut self, discovery_key: &[u8]) {
434 let hdkey = hex::encode(discovery_key);
435 let channel = self.channels.get(&hdkey);
436 if let Some(channel) = channel {
437 if let Some(local_id) = channel.local_id() {
438 self.local_id[local_id] = None;
439 }
440 if let Some(remote_id) = channel.remote_id() {
441 self.remote_id[remote_id] = None;
442 }
443 }
444 self.channels.remove(&hdkey);
445 }
446
447 #[instrument(skip(self))]
448 pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec<u8>>)> {
449 let channel_handle = self
450 .get_local(local_id)
451 .ok_or_else(|| error("Channel not found"))?;
452 channel_handle.prepare_to_verify()
453 }
454
455 pub(crate) fn accept(
456 &mut self,
457 local_id: usize,
458 outbound_tx: Sender<Vec<ChannelMessage>>,
459 ) -> Result<Channel> {
460 let channel_handle = self
461 .get_local_mut(local_id)
462 .ok_or_else(|| error("Channel not found"))?;
463 if !channel_handle.is_connected() {
464 return Err(error("Channel is not opened from remote"));
465 }
466 let channel = channel_handle.open(outbound_tx);
467 Ok(channel)
468 }
469
470 pub(crate) fn forward_inbound_message(
471 &mut self,
472 remote_id: usize,
473 message: Message,
474 ) -> Result<()> {
475 if let Some(channel_handle) = self.get_remote_mut(remote_id) {
476 channel_handle.try_send_inbound(message)?;
477 }
478 Ok(())
479 }
480
481 pub(crate) fn forward_inbound_message_tolerate_closed(
482 &mut self,
483 remote_id: usize,
484 message: Message,
485 ) -> Result<()> {
486 if let Some(channel_handle) = self.get_remote_mut(remote_id) {
487 channel_handle.try_send_inbound_tolerate_closed(message)?;
488 }
489 Ok(())
490 }
491
492 fn alloc_local(&mut self) -> usize {
493 let empty_id = self
494 .local_id
495 .iter()
496 .skip(1)
497 .position(|x| x.is_none())
498 .map(|position| position + 1);
499 match empty_id {
500 Some(empty_id) => empty_id,
501 None => {
502 self.local_id.push(None);
503 self.local_id.len() - 1
504 }
505 }
506 }
507
508 fn alloc_remote(&mut self, id: usize) {
509 if self.remote_id.len() > id {
510 self.remote_id[id] = None;
511 } else {
512 self.remote_id.resize(id + 1, None)
513 }
514 }
515
516 pub(crate) fn iter(&self) -> impl Iterator<Item = &ChannelHandle> {
517 self.channels.values()
518 }
519}
520
/// Build an [`Error`] of kind `Other` whose message is `message`.
fn error(message: &str) -> Error {
    let owned: String = message.to_owned();
    Error::other(owned)
}