1use std::any::Any;
15#[cfg(feature = "keygen")]
16use std::net::TcpListener;
17
18#[cfg(feature = "keygen")]
19use addr::NodeAddr;
20use addr::NodeId;
21use amplify::Bipolar;
22#[cfg(feature = "keygen")]
23use inet2_addr::InetSocketAddr;
24#[cfg(feature = "zmq")]
25use inet2_addr::ServiceAddr;
26
27use super::{Decrypt, Encrypt, Transcode};
28use crate::session::noise::FramingProtocol;
29use crate::session::{noise, PlainTranscoder};
30use crate::transport::{
31 encrypted, unencrypted, DuplexConnection, Error, RecvFrame, RoutedFrame,
32 SendFrame,
33};
34#[cfg(feature = "zmq")]
35use crate::zeromq;
36use crate::{NoiseDecryptor, NoiseTranscoder};
37
38pub trait SendRecvMessage {
42 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error>;
43 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error>;
44 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error>;
45 fn send_routed_message(
46 &mut self,
47 source: &[u8],
48 route: &[u8],
49 dest: &[u8],
50 raw: &[u8],
51 ) -> Result<usize, Error>;
52 fn into_any(self: Box<Self>) -> Box<dyn Any>;
53}
54
55pub trait Split {
56 fn split(
57 self,
58 ) -> (Box<dyn RecvMessage + Send>, Box<dyn SendMessage + Send>);
59}
60
61pub trait RecvMessage {
62 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error>;
63 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
64 panic!("Multipeer sockets are not possible with the chosen transport")
68 }
69}
70
71pub trait SendMessage {
72 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error>;
73 fn send_routed_message(
74 &mut self,
75 _source: &[u8],
76 _route: &[u8],
77 _dest: &[u8],
78 _raw: &[u8],
79 ) -> Result<usize, Error> {
80 panic!("Multipeer sockets are not possible with the chosen transport")
84 }
85}
86
87pub struct Session<T, C>
88where
89 T: Transcode,
90 T::Left: Decrypt,
91 T::Right: Encrypt,
92 C: DuplexConnection + Bipolar,
93 C::Left: RecvFrame,
94 C::Right: SendFrame,
95{
96 pub(self) transcoder: T,
97 pub(self) connection: C,
98}
99
100pub struct Receiver<D, R>
101where
102 D: Decrypt,
103 R: RecvFrame,
104{
105 pub(self) decryptor: D,
106 pub(self) input: R,
107}
108
109pub struct Sender<E, S>
110where
111 E: Encrypt,
112 S: SendFrame,
113{
114 pub(self) encryptor: E,
115 pub(self) output: S,
116}
117
118trait InternalSession {
120 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error>;
121 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error>;
122 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error>;
123 fn send_routed_message(
124 &mut self,
125 source: &[u8],
126 route: &[u8],
127 dest: &[u8],
128 raw: &[u8],
129 ) -> Result<usize, Error>;
130}
131
132impl<T, C> InternalSession for Session<T, C>
133where
134 T: Transcode + 'static,
135 T::Left: Decrypt,
136 T::Right: Encrypt,
137 C: DuplexConnection + Bipolar + 'static,
138 C::Left: RecvFrame,
139 C::Right: SendFrame,
140 Error: From<T::Error> + From<<T::Left as Decrypt>::Error>,
141{
142 #[inline]
143 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
144 let reader = self.connection.as_receiver();
145 Ok(self.transcoder.decrypt(reader.recv_frame()?)?)
146 }
147
148 #[inline]
149 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error> {
150 let writer = self.connection.as_sender();
151 writer.send_frame(&self.transcoder.encrypt(raw))
152 }
153
154 #[inline]
155 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
156 let reader = self.connection.as_receiver();
157 let mut routed_frame = reader.recv_routed()?;
158 routed_frame.msg = self.transcoder.decrypt(routed_frame.msg)?;
159 Ok(routed_frame)
160 }
161
162 #[inline]
163 fn send_routed_message(
164 &mut self,
165 source: &[u8],
166 route: &[u8],
167 dest: &[u8],
168 raw: &[u8],
169 ) -> Result<usize, Error> {
170 let writer = self.connection.as_sender();
171 writer.send_routed(source, route, dest, &self.transcoder.encrypt(raw))
172 }
173}
174
175impl SendRecvMessage for Session<PlainTranscoder, unencrypted::Connection> {
176 #[inline]
177 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
178 InternalSession::recv_raw_message(self)
179 }
180 #[inline]
181 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error> {
182 InternalSession::send_raw_message(self, raw)
183 }
184 #[inline]
185 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
186 InternalSession::recv_routed_message(self)
187 }
188 #[inline]
189 fn send_routed_message(
190 &mut self,
191 source: &[u8],
192 route: &[u8],
193 dest: &[u8],
194 raw: &[u8],
195 ) -> Result<usize, Error> {
196 InternalSession::send_routed_message(self, source, route, dest, raw)
197 }
198 #[inline]
199 fn into_any(self: Box<Self>) -> Box<dyn Any> { self }
200}
201
202fn recv_noise_message<const LEN_SIZE: usize>(
203 reader: &mut dyn RecvFrame,
204 decrypt: &mut NoiseDecryptor<LEN_SIZE>,
205) -> Result<Vec<u8>, Error> {
206 let encrypted_len = reader.recv_frame()?;
208 decrypt.decrypt(encrypted_len)?;
209 let len = decrypt.pending_message_len().ok_or(Error::NoNoiseHeader)?;
210 let encrypted_payload = reader.recv_raw(len + noise::chacha::TAG_SIZE)?;
212 let payload = decrypt.decrypt(encrypted_payload)?;
213 Ok(payload)
214}
215
216impl<const LEN_SIZE: usize> SendRecvMessage
217 for Session<NoiseTranscoder<LEN_SIZE>, encrypted::Connection<LEN_SIZE>>
218{
219 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
220 let reader = self.connection.as_receiver();
221 recv_noise_message(reader, &mut self.transcoder.decryptor)
222 }
223
224 #[inline]
225 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error> {
226 InternalSession::send_raw_message(self, raw)
227 }
228 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
229 unimplemented!(
230 "to route brontide messages use presentation-level onion routing"
231 )
232 }
233 fn send_routed_message(
234 &mut self,
235 source: &[u8],
236 route: &[u8],
237 dest: &[u8],
238 raw: &[u8],
239 ) -> Result<usize, Error> {
240 unimplemented!(
241 "to route brontide messages use presentation-level onion routing"
242 )
243 }
244 #[inline]
245 fn into_any(self: Box<Self>) -> Box<dyn Any> { self }
246}
247
/// Unencrypted ZMQ session: every operation forwards to the generic
/// [`InternalSession`] implementation.
#[cfg(feature = "zmq")]
impl SendRecvMessage for Session<PlainTranscoder, zeromq::Connection> {
    #[inline]
    fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
        <Self as InternalSession>::recv_raw_message(self)
    }

    #[inline]
    fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error> {
        <Self as InternalSession>::send_raw_message(self, raw)
    }

    #[inline]
    fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
        <Self as InternalSession>::recv_routed_message(self)
    }

    #[inline]
    fn send_routed_message(
        &mut self,
        source: &[u8],
        route: &[u8],
        dest: &[u8],
        raw: &[u8],
    ) -> Result<usize, Error> {
        <Self as InternalSession>::send_routed_message(
            self, source, route, dest, raw,
        )
    }

    #[inline]
    fn into_any(self: Box<Self>) -> Box<dyn Any> { self }
}
275
276impl<T, C> Split for Session<T, C>
277where
278 T: Transcode,
279 T::Left: Decrypt + Send + 'static,
280 T::Right: Encrypt + Send + 'static,
281 C: DuplexConnection + Bipolar,
282 C::Left: RecvFrame + Send + 'static,
283 C::Right: SendFrame + Send + 'static,
284 Receiver<T::Left, C::Left>: RecvMessage,
285 Error: From<T::Error> + From<<T::Left as Decrypt>::Error>,
286{
287 #[inline]
288 fn split(
289 self,
290 ) -> (Box<dyn RecvMessage + Send>, Box<dyn SendMessage + Send>) {
291 let (decryptor, encryptor) = self.transcoder.split();
292 let (input, output) = Bipolar::split(self.connection);
293 (
294 Box::new(Receiver { decryptor, input }),
295 Box::new(Sender { encryptor, output }),
296 )
297 }
298}
299
300pub type BrontideSession = Session<
301 NoiseTranscoder<{ FramingProtocol::Brontide.message_len_size() }>,
302 encrypted::Connection<2>,
303>;
304pub type BrontozaurSession = Session<
305 NoiseTranscoder<{ FramingProtocol::Brontozaur.message_len_size() }>,
306 encrypted::Connection<3>,
307>;
308#[cfg(feature = "zmq")]
309pub type LocalSession = Session<PlainTranscoder, zeromq::Connection>;
310#[cfg(feature = "zmq")]
311pub type RpcSession = Session<
312 NoiseTranscoder<{ FramingProtocol::Brontozaur.message_len_size() }>,
313 zeromq::Connection,
314>;
315
316impl<const LEN_SIZE: usize>
317 Session<NoiseTranscoder<LEN_SIZE>, encrypted::Connection<LEN_SIZE>>
318{
319 #[inline]
320 pub fn remote_id(&self) -> NodeId { self.transcoder.remote_pubkey().into() }
321}
322
/// Public TCP constructors for [`BrontideSession`].
///
/// This mirrors the constructor set of `BrontozaurSession`, including
/// `connect_with`. The underlying generic constructors are shared by both
/// session types.
#[cfg(feature = "keygen")]
impl BrontideSession {
    /// Wraps an existing TCP `stream` from `remote_addr` and runs the Noise
    /// handshake as the responder.
    pub fn with(
        stream: std::net::TcpStream,
        local_key: secp256k1::SecretKey,
        remote_addr: InetSocketAddr,
    ) -> Result<Self, Error> {
        BrontideSession::with_tcp_encrypted(stream, local_key, remote_addr)
    }

    /// Wraps an already-connected TCP `stream` to `remote_node` and runs the
    /// Noise handshake as the initiator.
    pub fn connect_with(
        stream: std::net::TcpStream,
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
    ) -> Result<Self, Error> {
        BrontideSession::connect_with_tcp_encrypted(
            stream,
            local_key,
            remote_node,
        )
    }

    /// Opens a new TCP connection to `remote_node` and runs the Noise
    /// handshake as the initiator.
    pub fn connect(
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
    ) -> Result<Self, Error> {
        BrontideSession::connect_tcp_encrypted(local_key, remote_node)
    }

    /// Accepts one incoming connection on `listener` and runs the Noise
    /// handshake as the responder.
    pub fn accept(
        local_key: secp256k1::SecretKey,
        listener: &TcpListener,
    ) -> Result<Self, Error> {
        BrontideSession::accept_tcp_encrypted(local_key, listener)
    }
}
347
/// Public TCP constructors for [`BrontozaurSession`].
#[cfg(feature = "keygen")]
impl BrontozaurSession {
    /// Wraps an existing TCP `stream` from `remote_addr` and runs the Noise
    /// handshake as the responder.
    pub fn with(
        stream: std::net::TcpStream,
        local_key: secp256k1::SecretKey,
        remote_addr: InetSocketAddr,
    ) -> Result<Self, Error> {
        Self::with_tcp_encrypted(stream, local_key, remote_addr)
    }

    /// Wraps an already-connected TCP `stream` to `remote_node` and runs the
    /// Noise handshake as the initiator.
    pub fn connect_with(
        stream: std::net::TcpStream,
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
    ) -> Result<Self, Error> {
        Self::connect_with_tcp_encrypted(stream, local_key, remote_node)
    }

    /// Opens a new TCP connection to `remote_node` and runs the Noise
    /// handshake as the initiator.
    pub fn connect(
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
    ) -> Result<Self, Error> {
        Self::connect_tcp_encrypted(local_key, remote_node)
    }

    /// Accepts one incoming connection on `listener` and runs the Noise
    /// handshake as the responder.
    pub fn accept(
        local_key: secp256k1::SecretKey,
        listener: &TcpListener,
    ) -> Result<Self, Error> {
        Self::accept_tcp_encrypted(local_key, listener)
    }
}
384
/// Public ZMQ constructors for [`LocalSession`].
#[cfg(feature = "zmq")]
impl LocalSession {
    /// Creates an unencrypted ZMQ session of type `zmq_type` towards `remote`.
    ///
    /// `local` and `identity` are optional and are passed through to
    /// `zeromq::Connection::connect` unchanged; see that function for their
    /// exact semantics.
    pub fn connect(
        zmq_type: zeromq::ZmqSocketType,
        remote: &ServiceAddr,
        local: Option<&ServiceAddr>,
        identity: Option<&[u8]>,
        context: &zmq::Context,
    ) -> Result<Self, Error> {
        Self::connect_zmq_unencrypted(
            zmq_type, remote, local, identity, context,
        )
    }

    /// Wraps an existing ZMQ `socket` of type `zmq_type` in an unencrypted
    /// session.
    pub fn with_zmq_socket(
        zmq_type: zeromq::ZmqSocketType,
        socket: zmq::Socket,
    ) -> Self {
        Self::with_zmq_socket_unencrypted(zmq_type, socket)
    }
}
406
/// Generic TCP constructors shared by all Noise-encrypted session flavours.
#[cfg(feature = "keygen")]
impl<const LEN_SIZE: usize>
    Session<NoiseTranscoder<LEN_SIZE>, encrypted::Connection<LEN_SIZE>>
{
    /// Wraps an existing `stream` from `remote_addr` and performs the Noise
    /// handshake as the responder.
    fn with_tcp_encrypted(
        stream: std::net::TcpStream,
        local_key: secp256k1::SecretKey,
        remote_addr: InetSocketAddr,
    ) -> Result<Self, Error> {
        let connection = encrypted::Connection::with(stream, remote_addr);
        Self::init_tcp_encrypted(local_key, connection)
    }

    /// Wraps an already-connected `stream` to `remote_node` and performs the
    /// Noise handshake as the initiator.
    fn connect_with_tcp_encrypted(
        stream: std::net::TcpStream,
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
    ) -> Result<Self, Error> {
        let connection = encrypted::Connection::with(stream, remote_node.addr);
        Self::init_tcp_initiator(local_key, remote_node, connection)
    }

    /// Opens a TCP connection to `remote_node` and performs the Noise
    /// handshake as the initiator.
    fn connect_tcp_encrypted(
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
    ) -> Result<Self, Error> {
        let connection = encrypted::Connection::connect(remote_node.addr)?;
        Self::init_tcp_initiator(local_key, remote_node, connection)
    }

    /// Accepts one incoming connection on `listener` and performs the Noise
    /// handshake as the responder.
    fn accept_tcp_encrypted(
        local_key: secp256k1::SecretKey,
        listener: &TcpListener,
    ) -> Result<Self, Error> {
        let connection = encrypted::Connection::accept(listener)?;
        Self::init_tcp_encrypted(local_key, connection)
    }

    /// Runs the handshake as the initiator towards `remote_node`'s public key
    /// over `connection` and builds the session from the result.
    fn init_tcp_initiator(
        local_key: secp256k1::SecretKey,
        remote_node: NodeAddr,
        mut connection: encrypted::Connection<LEN_SIZE>,
    ) -> Result<Self, Error> {
        let transcoder = NoiseTranscoder::new_initiator(
            local_key,
            remote_node.public_key(),
            &mut connection,
        )?;
        Ok(Self {
            transcoder,
            connection,
        })
    }

    /// Runs the handshake as the responder over `connection` and builds the
    /// session from the result.
    fn init_tcp_encrypted(
        local_key: secp256k1::SecretKey,
        mut connection: encrypted::Connection<LEN_SIZE>,
    ) -> Result<Self, Error> {
        let transcoder =
            NoiseTranscoder::new_responder(local_key, &mut connection)?;
        Ok(Self {
            transcoder,
            connection,
        })
    }
}
502
/// Internal constructors for unencrypted ZMQ sessions.
#[cfg(feature = "zmq")]
impl Session<PlainTranscoder, zeromq::Connection> {
    /// Creates a ZMQ connection via `zeromq::Connection::connect` and pairs
    /// it with the pass-through `PlainTranscoder`.
    fn connect_zmq_unencrypted(
        zmq_type: zeromq::ZmqSocketType,
        remote: &ServiceAddr,
        local: Option<&ServiceAddr>,
        identity: Option<&[u8]>,
        context: &zmq::Context,
    ) -> Result<Self, Error> {
        let connection = zeromq::Connection::connect(
            zmq_type, remote, local, identity, context,
        )?;
        Ok(Self {
            transcoder: PlainTranscoder,
            connection,
        })
    }

    /// Wraps an existing ZMQ `socket` with the pass-through `PlainTranscoder`.
    fn with_zmq_socket_unencrypted(
        zmq_type: zeromq::ZmqSocketType,
        socket: zmq::Socket,
    ) -> Self {
        let connection = zeromq::Connection::with_socket(zmq_type, socket);
        Self {
            transcoder: PlainTranscoder,
            connection,
        }
    }
}
577
/// ZMQ-specific accessors available for any transcoder.
#[cfg(feature = "zmq")]
impl<T> Session<T, zeromq::Connection>
where
    T: Transcode,
    T::Left: Decrypt + Send + 'static,
    T::Right: Encrypt + Send + 'static,
{
    /// Borrows the underlying ZMQ socket.
    pub fn as_socket(&self) -> &zmq::Socket {
        let connection = &self.connection;
        connection.as_socket()
    }

    /// Sets the socket identity on the underlying connection, converting any
    /// connection error into [`Error`].
    pub fn set_identity(
        &mut self,
        identity: &impl AsRef<[u8]>,
        context: &zmq::Context,
    ) -> Result<(), Error> {
        self.connection.set_identity(identity, context)?;
        Ok(())
    }
}
597
598trait InternalInput {
600 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error>;
601 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error>;
602}
603
604impl<T, C> InternalInput for Receiver<T, C>
605where
606 T: Decrypt,
607 C: RecvFrame,
608 Error: From<T::Error>,
610{
611 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
612 Ok(self.decryptor.decrypt(self.input.recv_frame()?)?)
613 }
614 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
615 let mut routed_frame = self.input.recv_routed()?;
616 routed_frame.msg = self.decryptor.decrypt(routed_frame.msg)?;
617 Ok(routed_frame)
618 }
619}
620
621impl RecvMessage for Receiver<PlainTranscoder, unencrypted::Stream> {
622 #[inline]
623 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
624 InternalInput::recv_raw_message(self)
625 }
626 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
627 InternalInput::recv_routed_message(self)
628 }
629}
630
631impl<const LEN_SIZE: usize> RecvMessage
632 for Receiver<NoiseDecryptor<LEN_SIZE>, encrypted::Stream<LEN_SIZE>>
633{
634 #[inline]
635 fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
636 recv_noise_message(&mut self.input, &mut self.decryptor)
637 }
638 fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
639 InternalInput::recv_routed_message(self)
640 }
641}
642
/// Receiver over a wrapped ZMQ socket: forwards to [`InternalInput`].
#[cfg(feature = "zmq")]
impl RecvMessage for Receiver<PlainTranscoder, zeromq::WrappedSocket> {
    #[inline]
    fn recv_raw_message(&mut self) -> Result<Vec<u8>, Error> {
        <Self as InternalInput>::recv_raw_message(self)
    }

    fn recv_routed_message(&mut self) -> Result<RoutedFrame, Error> {
        <Self as InternalInput>::recv_routed_message(self)
    }
}
653
654impl<T, C> SendMessage for Sender<T, C>
655where
656 T: Encrypt,
657 C: SendFrame,
658{
659 fn send_raw_message(&mut self, raw: &[u8]) -> Result<usize, Error> {
660 self.output.send_frame(&self.encryptor.encrypt(raw))
661 }
662 fn send_routed_message(
663 &mut self,
664 source: &[u8],
665 route: &[u8],
666 dest: &[u8],
667 raw: &[u8],
668 ) -> Result<usize, Error> {
669 let encrypted = self.encryptor.encrypt(raw);
670 self.output.send_routed(source, route, dest, &encrypted)
671 }
672}
673
#[cfg(test)]
mod test {
    use super::*;

    /// Exchanges one message in each direction over an in-process ZMQ
    /// REP/REQ pair with no encryption, including an empty payload.
    #[test]
    #[cfg(feature = "zmq")]
    fn test_zmq_no_encryption() {
        let context = zmq::Context::new();
        let endpoint = ServiceAddr::Inproc(s!("test"));

        // The REP side is created first, matching the original ordering.
        let mut responder = Session::connect_zmq_unencrypted(
            zeromq::ZmqSocketType::Rep,
            &endpoint,
            None,
            None,
            &context,
        )
        .unwrap();
        let mut requester = Session::connect_zmq_unencrypted(
            zeromq::ZmqSocketType::Req,
            &endpoint,
            None,
            None,
            &context,
        )
        .unwrap();

        // Fully qualified calls: `InternalSession` is also in scope and has
        // methods with the same names.
        let request = b"Some message";
        SendRecvMessage::send_raw_message(&mut requester, request).unwrap();
        let received =
            SendRecvMessage::recv_raw_message(&mut responder).unwrap();
        assert_eq!(received, request);

        let reply = b"";
        SendRecvMessage::send_raw_message(&mut responder, reply).unwrap();
        let received =
            SendRecvMessage::recv_raw_message(&mut requester).unwrap();
        assert_eq!(received, reply);
    }
}