1use crate::message::ChannelMessage;
2use crate::schema::*;
3use crate::util::{map_channel_err, pretty_hash};
4use crate::Message;
5use crate::{discovery_key, DiscoveryKey, Key};
6use async_channel::{Receiver, Sender, TrySendError};
7use futures_lite::ready;
8use futures_lite::stream::Stream;
9use std::collections::HashMap;
10use std::fmt;
11use std::io::{Error, ErrorKind, Result};
12use std::pin::Pin;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Arc;
15use std::task::Poll;
16use tracing::debug;
17
18#[derive(Clone)]
22pub struct Channel {
23 inbound_rx: Option<Receiver<Message>>,
24 direct_inbound_tx: Sender<Message>,
25 outbound_tx: Sender<Vec<ChannelMessage>>,
26 key: Key,
27 discovery_key: DiscoveryKey,
28 local_id: usize,
29 closed: Arc<AtomicBool>,
30}
31
32impl PartialEq for Channel {
33 fn eq(&self, other: &Self) -> bool {
34 self.key == other.key
35 && self.discovery_key == other.discovery_key
36 && self.local_id == other.local_id
37 }
38}
39
40impl fmt::Debug for Channel {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("Channel")
43 .field("discovery_key", &pretty_hash(&self.discovery_key))
44 .finish()
45 }
46}
47
48impl Channel {
49 fn new(
50 inbound_rx: Option<Receiver<Message>>,
51 direct_inbound_tx: Sender<Message>,
52 outbound_tx: Sender<Vec<ChannelMessage>>,
53 discovery_key: DiscoveryKey,
54 key: Key,
55 local_id: usize,
56 closed: Arc<AtomicBool>,
57 ) -> Self {
58 Self {
59 inbound_rx,
60 direct_inbound_tx,
61 outbound_tx,
62 key,
63 discovery_key,
64 local_id,
65 closed,
66 }
67 }
68 pub fn discovery_key(&self) -> &[u8; 32] {
70 &self.discovery_key
71 }
72
73 pub fn key(&self) -> &[u8; 32] {
75 &self.key
76 }
77
78 pub fn id(&self) -> usize {
80 self.local_id
81 }
82
83 pub fn closed(&self) -> bool {
85 self.closed.load(Ordering::SeqCst)
86 }
87
88 pub async fn send(&mut self, message: Message) -> Result<()> {
90 if self.closed() {
91 return Err(Error::new(
92 ErrorKind::ConnectionAborted,
93 "Channel is closed",
94 ));
95 }
96 debug!("TX:\n{message:?}\n");
97 let message = ChannelMessage::new(self.local_id as u64, message);
98 self.outbound_tx
99 .send(vec![message])
100 .await
101 .map_err(map_channel_err)
102 }
103
104 pub async fn send_batch(&mut self, messages: &[Message]) -> Result<()> {
106 if self.closed() {
117 return Err(Error::new(
118 ErrorKind::ConnectionAborted,
119 "Channel is closed",
120 ));
121 }
122
123 let messages = messages
124 .iter()
125 .map(|message| {
126 debug!("TX:\n{message:?}\n");
127 ChannelMessage::new(self.local_id as u64, message.clone())
128 })
129 .collect();
130 self.outbound_tx
131 .send(messages)
132 .await
133 .map_err(map_channel_err)
134 }
135
136 pub fn take_receiver(&mut self) -> Option<Receiver<Message>> {
141 self.inbound_rx.take()
142 }
143
144 pub fn local_sender(&self) -> Sender<Message> {
150 self.direct_inbound_tx.clone()
151 }
152
153 pub async fn close(&mut self) -> Result<()> {
155 if self.closed() {
156 return Ok(());
157 }
158 let close = Close {
159 channel: self.local_id as u64,
160 };
161 self.send(Message::Close(close)).await?;
162 self.closed.store(true, Ordering::SeqCst);
163 Ok(())
164 }
165
166 pub async fn signal_local_protocol(&mut self, name: &str, data: Vec<u8>) -> Result<()> {
169 self.send(Message::LocalSignal((name.to_string(), data)))
170 .await?;
171 Ok(())
172 }
173}
174
175impl Stream for Channel {
176 type Item = Message;
177 fn poll_next(
178 self: Pin<&mut Self>,
179 cx: &mut std::task::Context<'_>,
180 ) -> std::task::Poll<Option<Self::Item>> {
181 let this = self.get_mut();
182 match this.inbound_rx.as_mut() {
183 None => Poll::Ready(None),
184 Some(ref mut inbound_rx) => {
185 let message = ready!(Pin::new(inbound_rx).poll_next(cx));
186 Poll::Ready(message)
187 }
188 }
189 }
190}
191
192#[derive(Clone, Debug)]
194pub(crate) struct ChannelHandle {
195 discovery_key: DiscoveryKey,
196 local_state: Option<LocalState>,
197 remote_state: Option<RemoteState>,
198 inbound_tx: Option<Sender<Message>>,
199 closed: Arc<AtomicBool>,
200}
201
202#[derive(Clone, Debug)]
203struct LocalState {
204 key: Key,
205 local_id: usize,
206}
207
/// What is known about a channel once it has been attached from the remote side.
#[derive(Clone, Debug)]
struct RemoteState {
    /// Id the remote side uses for the channel.
    remote_id: usize,
    /// Capability bytes supplied with the remote attachment, if any.
    remote_capability: Option<Vec<u8>>,
}
213
214impl ChannelHandle {
215 fn new(discovery_key: DiscoveryKey) -> Self {
216 Self {
217 discovery_key,
218 local_state: None,
219 remote_state: None,
220 inbound_tx: None,
221 closed: Arc::new(AtomicBool::new(false)),
222 }
223 }
224 fn new_local(local_id: usize, discovery_key: DiscoveryKey, key: Key) -> Self {
225 let mut this = Self::new(discovery_key);
226 this.attach_local(local_id, key);
227 this
228 }
229
230 fn new_remote(
231 remote_id: usize,
232 discovery_key: DiscoveryKey,
233 remote_capability: Option<Vec<u8>>,
234 ) -> Self {
235 let mut this = Self::new(discovery_key);
236 this.attach_remote(remote_id, remote_capability);
237 this
238 }
239
240 pub(crate) fn discovery_key(&self) -> &[u8; 32] {
241 &self.discovery_key
242 }
243
244 pub(crate) fn local_id(&self) -> Option<usize> {
245 self.local_state.as_ref().map(|s| s.local_id)
246 }
247
248 pub(crate) fn remote_id(&self) -> Option<usize> {
249 self.remote_state.as_ref().map(|s| s.remote_id)
250 }
251
252 pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) {
253 let local_state = LocalState { local_id, key };
254 self.local_state = Some(local_state);
255 }
256
257 pub(crate) fn attach_remote(&mut self, remote_id: usize, remote_capability: Option<Vec<u8>>) {
258 let remote_state = RemoteState {
259 remote_id,
260 remote_capability,
261 };
262 self.remote_state = Some(remote_state);
263 }
264
265 pub(crate) fn is_connected(&self) -> bool {
266 self.local_state.is_some() && self.remote_state.is_some()
267 }
268
269 pub(crate) fn prepare_to_verify(&self) -> Result<(&Key, Option<&Vec<u8>>)> {
270 if !self.is_connected() {
271 return Err(error("Channel is not opened from both local and remote"));
272 }
273 let local_state = self.local_state.as_ref().unwrap();
275 let remote_state = self.remote_state.as_ref().unwrap();
276 Ok((&local_state.key, remote_state.remote_capability.as_ref()))
277 }
278
279 pub(crate) fn open(&mut self, outbound_tx: Sender<Vec<ChannelMessage>>) -> Channel {
280 let local_state = self
281 .local_state
282 .as_ref()
283 .expect("May not open channel that is not locally attached");
284
285 let (inbound_tx, inbound_rx) = async_channel::unbounded();
286 let channel = Channel::new(
287 Some(inbound_rx),
288 inbound_tx.clone(),
289 outbound_tx,
290 self.discovery_key,
291 local_state.key,
292 local_state.local_id,
293 self.closed.clone(),
294 );
295
296 self.inbound_tx = Some(inbound_tx);
297 channel
298 }
299
300 pub(crate) fn try_send_inbound(&mut self, message: Message) -> std::io::Result<()> {
301 if let Some(inbound_tx) = self.inbound_tx.as_mut() {
302 inbound_tx
303 .try_send(message)
304 .map_err(|e| error(format!("Sending to channel failed: {e}").as_str()))
305 } else {
306 Err(error("Channel is not open"))
307 }
308 }
309
310 pub(crate) fn try_send_inbound_tolerate_closed(
311 &mut self,
312 message: Message,
313 ) -> std::io::Result<()> {
314 if let Some(inbound_tx) = self.inbound_tx.as_mut() {
315 if let Err(err) = inbound_tx.try_send(message) {
316 match err {
317 TrySendError::Full(e) => {
318 return Err(error(format!("Sending to channel failed: {e}").as_str()))
319 }
320 TrySendError::Closed(_) => {}
321 }
322 }
323 }
324 Ok(())
325 }
326}
327
328impl Drop for ChannelHandle {
329 fn drop(&mut self) {
330 self.closed.store(true, Ordering::SeqCst);
331 }
332}
333
334#[derive(Debug)]
336pub(crate) struct ChannelMap {
337 channels: HashMap<String, ChannelHandle>,
338 local_id: Vec<Option<String>>,
339 remote_id: Vec<Option<String>>,
340}
341
342impl ChannelMap {
343 pub(crate) fn new() -> Self {
344 Self {
345 channels: HashMap::new(),
346 local_id: vec![None],
349 remote_id: vec![],
350 }
351 }
352
353 pub(crate) fn attach_local(&mut self, key: Key) -> &ChannelHandle {
354 let discovery_key = discovery_key(&key);
355 let hdkey = hex::encode(discovery_key);
356 let local_id = self.alloc_local();
357
358 self.channels
359 .entry(hdkey.clone())
360 .and_modify(|channel| channel.attach_local(local_id, key))
361 .or_insert_with(|| ChannelHandle::new_local(local_id, discovery_key, key));
362
363 self.local_id[local_id] = Some(hdkey.clone());
364 self.channels.get(&hdkey).unwrap()
365 }
366
367 pub(crate) fn attach_remote(
368 &mut self,
369 discovery_key: DiscoveryKey,
370 remote_id: usize,
371 remote_capability: Option<Vec<u8>>,
372 ) -> &ChannelHandle {
373 let hdkey = hex::encode(discovery_key);
374 self.alloc_remote(remote_id);
375 self.channels
376 .entry(hdkey.clone())
377 .and_modify(|channel| channel.attach_remote(remote_id, remote_capability.clone()))
378 .or_insert_with(|| {
379 ChannelHandle::new_remote(remote_id, discovery_key, remote_capability)
380 });
381 self.remote_id[remote_id] = Some(hdkey.clone());
382 self.channels.get(&hdkey).unwrap()
383 }
384
385 pub(crate) fn get_remote_mut(&mut self, remote_id: usize) -> Option<&mut ChannelHandle> {
386 if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
387 self.channels.get_mut(hdkey)
388 } else {
389 None
390 }
391 }
392
393 pub(crate) fn get_remote(&self, remote_id: usize) -> Option<&ChannelHandle> {
394 if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() {
395 self.channels.get(hdkey)
396 } else {
397 None
398 }
399 }
400
401 pub(crate) fn get_local_mut(&mut self, local_id: usize) -> Option<&mut ChannelHandle> {
402 if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
403 self.channels.get_mut(hdkey)
404 } else {
405 None
406 }
407 }
408
409 pub(crate) fn get_local(&self, local_id: usize) -> Option<&ChannelHandle> {
410 if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() {
411 self.channels.get(hdkey)
412 } else {
413 None
414 }
415 }
416
417 pub(crate) fn has_channel(&self, discovery_key: &[u8]) -> bool {
418 let hdkey = hex::encode(discovery_key);
419 self.channels.contains_key(&hdkey)
420 }
421
422 pub(crate) fn remove(&mut self, discovery_key: &[u8]) {
423 let hdkey = hex::encode(discovery_key);
424 let channel = self.channels.get(&hdkey);
425 if let Some(channel) = channel {
426 if let Some(local_id) = channel.local_id() {
427 self.local_id[local_id] = None;
428 }
429 if let Some(remote_id) = channel.remote_id() {
430 self.remote_id[remote_id] = None;
431 }
432 }
433 self.channels.remove(&hdkey);
434 }
435
436 pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec<u8>>)> {
437 let channel_handle = self
438 .get_local(local_id)
439 .ok_or_else(|| error("Channel not found"))?;
440 channel_handle.prepare_to_verify()
441 }
442
443 pub(crate) fn accept(
444 &mut self,
445 local_id: usize,
446 outbound_tx: Sender<Vec<ChannelMessage>>,
447 ) -> Result<Channel> {
448 let channel_handle = self
449 .get_local_mut(local_id)
450 .ok_or_else(|| error("Channel not found"))?;
451 if !channel_handle.is_connected() {
452 return Err(error("Channel is not opened from remote"));
453 }
454 let channel = channel_handle.open(outbound_tx);
455 Ok(channel)
456 }
457
458 pub(crate) fn forward_inbound_message(
459 &mut self,
460 remote_id: usize,
461 message: Message,
462 ) -> Result<()> {
463 if let Some(channel_handle) = self.get_remote_mut(remote_id) {
464 channel_handle.try_send_inbound(message)?;
465 }
466 Ok(())
467 }
468
469 pub(crate) fn forward_inbound_message_tolerate_closed(
470 &mut self,
471 remote_id: usize,
472 message: Message,
473 ) -> Result<()> {
474 if let Some(channel_handle) = self.get_remote_mut(remote_id) {
475 channel_handle.try_send_inbound_tolerate_closed(message)?;
476 }
477 Ok(())
478 }
479
480 fn alloc_local(&mut self) -> usize {
481 let empty_id = self
482 .local_id
483 .iter()
484 .skip(1)
485 .position(|x| x.is_none())
486 .map(|position| position + 1);
487 match empty_id {
488 Some(empty_id) => empty_id,
489 None => {
490 self.local_id.push(None);
491 self.local_id.len() - 1
492 }
493 }
494 }
495
496 fn alloc_remote(&mut self, id: usize) {
497 if self.remote_id.len() > id {
498 self.remote_id[id] = None;
499 } else {
500 self.remote_id.resize(id + 1, None)
501 }
502 }
503
504 pub(crate) fn iter(&self) -> impl Iterator<Item = &ChannelHandle> {
505 self.channels.values()
506 }
507}
508
/// Builds an `std::io::Error` of kind `ErrorKind::Other` whose text is `message`.
fn error(message: &str) -> Error {
    let kind = ErrorKind::Other;
    Error::new(kind, message)
}