cosmic_hyperlane_tcp/
lib.rs

1#![allow(warnings)]
2
3#[macro_use]
4extern crate async_trait;
5
6use std::io::{Empty, Read};
7use std::iter;
8use std::net::{SocketAddr, ToSocketAddrs};
9use std::pin::Pin;
10use std::str::FromStr;
11use std::string::FromUtf8Error;
12use std::sync::Arc;
13use std::time::Duration;
14
15use rcgen::{Certificate, generate_simple_self_signed, RcgenError};
16use rustls::internal::msgs::codec::Codec;
17use rustls::{ClientConfig, RootCertStore, server, ServerConfig, ServerName};
18use tls_api_rustls::TlsConnectorBuilder;
19use tokio::fs::File;
20use tokio::io;
21use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
22use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf, ReadHalf, WriteHalf};
23use tokio::net::{TcpListener, TcpStream};
24use tokio::sync::{broadcast, mpsc, oneshot};
25use tokio::time::error::Elapsed;
26use tokio_rustls::{TlsAcceptor, TlsConnector, TlsStream};
27
28use cosmic_hyperlane::{
29    HyperConnectionDetails, HyperConnectionStatus, HyperGate, HyperGateSelector, HyperwayEndpoint,
30    HyperwayEndpointFactory, VersionGate,
31};
32use cosmic_space::err::SpaceErr;
33use cosmic_space::hyper::Knock;
34use cosmic_space::log::PointLogger;
35use cosmic_space::substance::Substance;
36use cosmic_space::wave::{Ping, UltraWave, Wave};
37use cosmic_space::VERSION;
38
39pub struct HyperlaneTcpClient {
40    host: String,
41    cert_dir: String,
42    knock: Knock,
43    logger: PointLogger,
44    verify: bool,
45}
46
47impl HyperlaneTcpClient {
48    pub fn new<H, S>(host: H, cert_dir: S, knock: Knock, verify: bool, logger: PointLogger) -> Self
49    where
50        S: ToString,
51        H: ToString,
52    {
53        Self {
54            host: host.to_string(),
55            cert_dir: cert_dir.to_string(),
56            knock,
57            verify,
58            logger,
59        }
60    }
61}
62
63#[async_trait]
64impl HyperwayEndpointFactory for HyperlaneTcpClient {
65    async fn create(
66        &self,
67        status_tx: mpsc::Sender<HyperConnectionDetails>,
68    ) -> Result<HyperwayEndpoint, SpaceErr> {
69        let mut root_certs = RootCertStore::empty();
70
71        let ca_file = format!("{}/cert.der", self.cert_dir);
72
73        let mut ca_file = File::open(ca_file).await?;
74        let mut ca_buffer = Vec::new();
75        ca_file.read_to_end(&mut ca_buffer).await?;
76
77        root_certs.add_parsable_certificates(&mut [ca_buffer]);
78
79        let client_config = Arc::new(
80            ClientConfig::builder()
81                .with_safe_default_cipher_suites()
82                .with_safe_default_kx_groups()
83                .with_safe_default_protocol_versions()
84                .unwrap()
85                .with_root_certificates(root_certs)
86                .with_no_client_auth(),
87        );
88
89        let mut connector: TlsConnector = TlsConnector::from(client_config);
90        let stream = tokio::net::TcpStream::connect(self.host.as_str()).await?;
91
92        let host = self.host.split(":").next().unwrap().to_string();
93        let server_name = rustls::ServerName::try_from(host.as_str()).unwrap();
94        let tokio_tls_connector = connector.connect(server_name, stream).await?;
95
96        let mut stream = FrameStream::new(tokio_tls_connector.into());
97
98        let endpoint =
99            FrameMuxer::handshake(stream, status_tx.clone(), self.logger.clone()).await?;
100
101        let wave: Wave<Ping> = self.knock.clone().into();
102        let wave = wave.to_ultra();
103        endpoint.tx.send(wave).await?;
104
105        Ok(endpoint)
106    }
107}
108
109pub struct CertGenerator {
110    certs: Vec<u8>,
111    key: Vec<u8>,
112}
113
114impl CertGenerator {
115    pub fn gen(subject_alt_names: Vec<String>) -> Result<Self, RcgenError> {
116        let cert = generate_simple_self_signed(subject_alt_names)?;
117        let certs = cert.serialize_der()?;
118        let key = cert.serialize_private_key_der();
119        Ok(Self { certs, key })
120    }
121
122    pub async fn read_from_dir(dir: String) -> Result<Self, Error> {
123        let mut certs_data = vec![];
124        let mut certs = File::open(format!("{}/cert.der", dir)).await?;
125        certs.read_to_end(&mut certs_data).await?;
126
127        let mut key_data = vec![];
128        let mut key = File::open(format!("{}/key.der", dir)).await?;
129        key.read_to_end(&mut key_data).await?;
130
131        Ok(Self {
132            certs: certs_data,
133            key: key_data,
134        })
135    }
136
137    pub fn certs(&self) -> Vec<u8> {
138        self.certs.clone()
139    }
140
141    pub fn private_key(&self) -> Vec<u8> {
142        self.key.clone()
143    }
144
145    pub async fn write_to_dir(&self, dir: String) -> io::Result<()> {
146        let mut certs = File::create(format!("{}/cert.der", dir)).await?;
147        certs.write_all(&self.certs()).await?;
148        let mut key = File::create(format!("{}/key.der", dir)).await?;
149        key.write_all(&self.private_key()).await?;
150        Ok(())
151    }
152}
153
154#[derive(Clone)]
155pub struct Frame {
156    pub data: Vec<u8>,
157}
158
159impl Frame {
160    pub fn from_string(string: String) -> Frame {
161        Frame {
162            data: string.as_bytes().to_vec(),
163        }
164    }
165
166    pub fn to_string(self) -> Result<String, SpaceErr> {
167        Ok(String::from_utf8(self.data)?)
168    }
169
170    pub fn from_version(version: &semver::Version) -> Frame {
171        Frame {
172            data: version.to_string().as_bytes().to_vec(),
173        }
174    }
175
176    pub fn to_version(self) -> Result<semver::Version, SpaceErr> {
177        Ok(semver::Version::from_str(
178            String::from_utf8(self.data)?.as_str(),
179        )?)
180    }
181
182    pub async fn from_stream<'a>(read: &'a mut TlsStream<TcpStream>) -> Result<Frame, SpaceErr> {
183        let size = read.read_u32().await? as usize;
184        let mut data = Vec::with_capacity(size as usize);
185
186        while data.len() < size {
187            read.read_buf(&mut data).await?;
188        }
189
190        Ok(Self { data })
191    }
192
193    pub async fn to_stream<'a>(&self, write: &'a mut TlsStream<TcpStream>) -> Result<(), SpaceErr> {
194        write.write_u32(self.data.len() as u32).await?;
195        write.write_all(self.data.as_slice()).await?;
196        write.flush().await?;
197        Ok(())
198    }
199
200    pub fn to_wave(self) -> Result<UltraWave, SpaceErr> {
201        Ok(bincode::deserialize(self.data.as_slice())?)
202    }
203
204    pub fn from_wave(wave: UltraWave) -> Result<Self, SpaceErr> {
205        Ok(Self {
206            data: bincode::serialize(&wave)?,
207        })
208    }
209}
210
211pub struct FrameMuxer {
212    stream: FrameStream,
213    tx: mpsc::Sender<UltraWave>,
214    rx: mpsc::Receiver<UltraWave>,
215    terminate_rx: mpsc::Receiver<()>,
216    logger: PointLogger,
217}
218impl FrameMuxer {
219    pub async fn handshake(
220        mut stream: FrameStream,
221        status_tx: mpsc::Sender<HyperConnectionDetails>,
222        logger: PointLogger,
223    ) -> Result<HyperwayEndpoint, SpaceErr> {
224        stream.write_version(&VERSION.clone()).await?;
225        let in_version =
226            tokio::time::timeout(Duration::from_secs(30), stream.read_version()).await??;
227
228        if in_version == *VERSION {
229            //            logger.info("version match");
230
231            stream.write_string("Ok".to_string()).await?;
232        } else {
233            logger.warn("version mismatch");
234            status_tx
235                .send(HyperConnectionDetails::new(
236                    HyperConnectionStatus::Handshake,
237                    "version mismatch",
238                ))
239                .await?;
240            let msg = format!(
241                "Err(\"expected version {}. encountered version {}\")",
242                VERSION.to_string(),
243                in_version.to_string()
244            );
245            stream.write_string(msg.clone()).await?;
246            return Err(msg.into());
247        }
248
249        let result = tokio::time::timeout(Duration::from_secs(30), stream.read_string()).await??;
250        if "Ok".to_string() != result {
251            return logger.result(Err(format!(
252                "remote did not indicate Ok. expected: 'Ok' encountered '{}'",
253                result
254            )
255            .into()));
256        }
257
258        Ok(Self::new(stream, logger))
259    }
260
261    pub fn new(stream: FrameStream, logger: PointLogger) -> HyperwayEndpoint {
262        let (in_tx, in_rx) = mpsc::channel(1024);
263        let (out_tx, out_rx) = mpsc::channel(1024);
264        let (terminate_tx, mut terminate_rx) = mpsc::channel(1);
265        let mut muxer = Self {
266            stream,
267            tx: in_tx,
268            rx: out_rx,
269            terminate_rx,
270            logger: logger.clone(),
271        };
272        {
273            let logger = logger.clone();
274            tokio::spawn(async move {
275                logger.result(muxer.mux().await).unwrap();
276            });
277        }
278
279        let (oneshot_terminate_tx, mut oneshot_terminate_rx) = oneshot::channel();
280        tokio::spawn(async move {
281            oneshot_terminate_rx.await.unwrap_or_default();
282            terminate_tx.send(()).await.unwrap_or_default();
283        });
284        HyperwayEndpoint::new_with_drop(out_tx, in_rx, oneshot_terminate_tx, logger)
285    }
286
287    pub async fn mux(mut self) -> Result<(), SpaceErr> {
288        loop {
289            tokio::select! {
290                wave = self.rx.recv() => {
291                    match wave {
292                        None => {
293                            self.logger.warn("rx discon");
294                            break
295                        },
296                        Some(wave) => {
297                           self.stream.write_wave(wave.clone()).await?;
298                        }
299                    }
300                }
301                wave = self.stream.read_wave() => {
302                    match wave {
303                       Ok(wave) => {
304                            self.tx.send(wave).await?;
305                       },
306                       Err(err) => {
307                            self.logger.error(format!("read stream err: {}",err.to_string()));
308                            break;
309                       }
310                    }
311                }
312                _ = self.terminate_rx.recv() => {
313                     self.logger.warn(format!("terminated"));
314                     return Ok(())
315                    }
316            }
317        }
318        Ok(())
319    }
320}
321
322pub struct FrameStream {
323    stream: TlsStream<TcpStream>,
324}
325
326impl FrameStream {
327    pub fn new(stream: TlsStream<TcpStream>) -> Self {
328        Self { stream }
329    }
330
331    pub async fn frame(&mut self) -> Result<Frame, SpaceErr> {
332        Frame::from_stream(&mut self.stream).await
333    }
334
335    pub async fn read_version(&mut self) -> Result<semver::Version, SpaceErr> {
336        self.frame().await?.to_version()
337    }
338
339    pub async fn read_string(&mut self) -> Result<String, SpaceErr> {
340        self.frame().await?.to_string()
341    }
342
343    pub async fn read_wave(&mut self) -> Result<UltraWave, SpaceErr> {
344        self.frame().await?.to_wave()
345    }
346
347    pub async fn write_frame(&mut self, frame: Frame) -> Result<(), SpaceErr> {
348        frame.to_stream(&mut self.stream).await
349    }
350
351    pub async fn write_string(&mut self, string: String) -> Result<(), SpaceErr> {
352        self.write_frame(Frame::from_string(string)).await
353    }
354
355    pub async fn write_version(&mut self, version: &semver::Version) -> Result<(), SpaceErr> {
356        self.write_frame(Frame::from_version(version)).await
357    }
358
359    pub async fn write_wave(&mut self, wave: UltraWave) -> Result<(), SpaceErr> {
360        self.write_frame(Frame::from_wave(wave)?).await
361    }
362}
363
/// Handle returned by `HyperlaneTcpServer::start`.
///
/// Currently carries no state and exposes no operations.
pub struct HyperlaneTcpServerApi {}

impl HyperlaneTcpServerApi {
    /// Creates an (empty) API handle.
    pub fn new() -> Self {
        HyperlaneTcpServerApi {}
    }
}
371
372pub struct HyperlaneTcpServer {
373    gate: Arc<HyperGateSelector>,
374    listener: TcpListener,
375    logger: PointLogger,
376    acceptor: TlsAcceptor,
377    server_kill_tx: broadcast::Sender<()>,
378    server_kill_rx: broadcast::Receiver<()>,
379}
380
381impl HyperlaneTcpServer {
382    pub async fn new(
383        port: u16,
384        cert_dir: String,
385        gate: Arc<HyperGateSelector>,
386        logger: PointLogger,
387    ) -> Result<Self, Error> {
388        let (server_kill_tx, server_kill_rx) = broadcast::channel(1);
389
390        // load certificate
391        let cert_path = format!("{}/cert.der", cert_dir);
392        let key_path = format!("{}/key.der", cert_dir);
393
394        let mut cert_data = vec![];
395        let mut key_data = vec![];
396
397        let mut file = std::fs::File::open(cert_path)?;
398        file.read_to_end(&mut cert_data)?;
399
400        let mut file = std::fs::File::open(key_path)?;
401        file.read_to_end(&mut key_data)?;
402
403        // I highly doubt this works
404        let mut ca_certs = Vec::<rustls::Certificate>::new();
405        ca_certs.push(rustls::Certificate(cert_data));
406
407        let private_key = rustls::PrivateKey(key_data);
408
409        let server_config = Arc::new(
410            ServerConfig::builder()
411                .with_safe_default_cipher_suites()
412                .with_safe_default_kx_groups()
413                .with_safe_default_protocol_versions()
414                .unwrap()
415                .with_no_client_auth()
416                .with_single_cert(ca_certs, private_key)
417                .expect("bad certificate/key"),
418        );
419
420        let mut acceptor = TlsAcceptor::from(server_config);
421        let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
422            .await
423            .unwrap();
424
425        Ok(Self {
426            acceptor,
427            gate,
428            listener,
429            logger,
430            server_kill_tx,
431            server_kill_rx,
432        })
433    }
434
435    pub fn start(mut self) -> Result<HyperlaneTcpServerApi, Error> {
436        tokio::spawn(async move {
437            self.run().await;
438        });
439        Ok(HyperlaneTcpServerApi::new())
440    }
441
442    async fn run(mut self) {
443        loop {
444            let stream = self.listener.accept().await.unwrap().0;
445            let acceptor = self.acceptor.clone();
446            let gate = self.gate.clone();
447            let logger = self.logger.clone();
448            let mut server_kill_rx = self.server_kill_tx.subscribe();
449
450            tokio::spawn(async move {
451                async fn serve(
452                    stream: TcpStream,
453                    acceptor: TlsAcceptor,
454                    gate: Arc<HyperGateSelector>,
455                    server_kill_rx: broadcast::Receiver<()>,
456                    logger: PointLogger,
457                ) -> Result<(), Error> {
458                    let mut stream = acceptor.accept(stream).await.unwrap();
459
460                    let mut stream = FrameStream::new(stream.into());
461
462                    let (status_tx, mut status_rx): (
463                        mpsc::Sender<HyperConnectionDetails>,
464                        mpsc::Receiver<HyperConnectionDetails>,
465                    ) = mpsc::channel(1024);
466                    {
467                        let logger = logger.clone();
468                        tokio::spawn(async move {
469                            while let Some(details) = status_rx.recv().await {
470                                /*                                logger.info(format!(
471                                    "{} | {}",
472                                    details.status.to_string(),
473                                    details.info
474                                ))*/
475                            }
476                        });
477                    }
478                    let mut mux = FrameMuxer::handshake(stream, status_tx, logger.clone()).await?;
479
480                    let knock = tokio::time::timeout(Duration::from_secs(30), mux.rx.recv())
481                        .await?
482                        .ok_or("expected wave")?;
483                    let knock = knock.to_directed()?;
484                    if let Substance::Knock(knock) = knock.body() {
485                        let mut endpoint = gate.knock(knock.clone()).await?;
486                        mux.connect(endpoint);
487                    } else {
488                        let msg = format!(
489                            "expected client Substance::Knock(Knock) encountered '{}'",
490                            knock.body().kind().to_string()
491                        );
492                        return logger.result(Err(SpaceErr::str(msg).into()));
493                    }
494
495                    Ok(())
496                }
497                serve(stream, acceptor, gate, server_kill_rx, logger).await;
498            });
499        }
500    }
501}
502
/// Returns the sum of `left` and `right`.
pub fn add(left: usize, right: usize) -> usize {
    let sum = left + right;
    sum
}
506
/// Error type of this crate: a plain, human-readable message.
///
/// Implements [`std::fmt::Display`] (and therefore `ToString`) as well as
/// [`std::error::Error`], so it can be used with `{}` formatting and boxed as
/// a `dyn Error`.
#[derive(Debug, Clone)]
pub struct Error {
    /// Description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for Error {
    /// Writes the message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for Error {}

impl Error {
    /// Creates an error whose message is the string form of `m`.
    pub fn new<S: ToString>(m: S) -> Self {
        Self {
            message: m.to_string(),
        }
    }
}
525impl From<Elapsed> for Error {
526    fn from(e: Elapsed) -> Self {
527        Self::new(e)
528    }
529}
530
531impl From<FromUtf8Error> for Error {
532    fn from(e: FromUtf8Error) -> Self {
533        Self::new(e)
534    }
535}
536
537impl From<std::io::Error> for Error {
538    fn from(e: std::io::Error) -> Self {
539        Error::new(e)
540    }
541}
542
543impl From<SpaceErr> for Error {
544    fn from(e: SpaceErr) -> Self {
545        Error::new(e)
546    }
547}
548
549impl From<RcgenError> for Error {
550    fn from(e: RcgenError) -> Self {
551        Error::new(e)
552    }
553}
554
555impl From<String> for Error {
556    fn from(e: String) -> Self {
557        Error::new(e)
558    }
559}
560
561impl From<&str> for Error {
562    fn from(e: &str) -> Self {
563        Error::new(e)
564    }
565}
566
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use cosmic_hyperlane::test_util::{
        FAE, LargeFrameTest, LESS, SingleInterchangePlatform, WaveTest,
    };
    use cosmic_space::loc::ToSurface;
    use cosmic_space::log::RootLogger;

    use chrono::DateTime;
    use chrono::Utc;
    use cosmic_hyperlane::HyperClient;
    use cosmic_space::point::Point;
    use cosmic_space::settings::Timeouts;
    use cosmic_space::wave::exchange::asynch::Exchanger;
    use cosmic_space::wave::DirectedProto;

    use super::*;

    /// Exported under the unmangled symbol `cosmic_uuid`; yields a random
    /// (v4) UUID as text.
    /// NOTE(review): presumably resolved by the cosmic crates at link time —
    /// confirm.
    #[no_mangle]
    pub extern "C" fn cosmic_uuid() -> String {
        let id = uuid::Uuid::new_v4();
        id.to_string()
    }

    /// Exported under the unmangled symbol `cosmic_timestamp`; yields the
    /// current UTC time.
    #[no_mangle]
    pub extern "C" fn cosmic_timestamp() -> DateTime<Utc> {
        let now = Utc::now();
        now
    }

    /// Round-trips waves between two TCP clients through a local server on
    /// port 4344. Currently disabled (the test attribute is commented out).
    //#[tokio::test]
    async fn test_tcp() -> Result<(), Error> {
        let platform = SingleInterchangePlatform::new().await;

        // Fresh self-signed certificate for "localhost", stored in the
        // working directory where both server and clients look for it.
        let generated = CertGenerator::gen(vec!["localhost".to_string()])?;
        generated.write_to_dir(".".to_string()).await?;

        let root_logger = RootLogger::default();
        let logger = root_logger.point(Point::from_str("tcp-server")?);
        let port: u16 = 4344;
        let server =
            HyperlaneTcpServer::new(port, ".".to_string(), platform.gate.clone(), logger.clone())
                .await?;
        let _api = server.start()?;

        let less_log = logger.point(LESS.clone());
        let less = Box::new(HyperlaneTcpClient::new(
            format!("localhost:{}", port),
            ".",
            platform.knock(LESS.to_surface()),
            false,
            less_log,
        ));

        let fae_log = logger.point(FAE.clone());
        let fae = Box::new(HyperlaneTcpClient::new(
            format!("localhost:{}", port),
            ".",
            platform.knock(FAE.to_surface()),
            false,
            fae_log,
        ));

        let scenario = WaveTest::new(fae, less);
        scenario.go().await.unwrap();

        Ok(())
    }

    /// Same setup as `test_tcp` on port 4345, but runs the large-frame
    /// scenario. Currently disabled (the test attribute is commented out).
    //    #[tokio::test]
    async fn test_large_frame() -> Result<(), Error> {
        let platform = SingleInterchangePlatform::new().await;

        let generated = CertGenerator::gen(vec!["localhost".to_string()])?;
        generated.write_to_dir(".".to_string()).await?;

        let root_logger = RootLogger::default();
        let logger = root_logger.point(Point::from_str("tcp-server")?);
        let port: u16 = 4345;
        let server =
            HyperlaneTcpServer::new(port, ".".to_string(), platform.gate.clone(), logger.clone())
                .await?;
        let _api = server.start()?;

        let less_log = logger.point(LESS.clone());
        let less = Box::new(HyperlaneTcpClient::new(
            format!("localhost:{}", port),
            ".",
            platform.knock(LESS.to_surface()),
            false,
            less_log,
        ));

        let fae_log = logger.point(FAE.clone());
        let fae = Box::new(HyperlaneTcpClient::new(
            format!("localhost:{}", port),
            ".",
            platform.knock(FAE.to_surface()),
            false,
            fae_log,
        ));

        let scenario = LargeFrameTest::new(fae, less);
        scenario.go().await.unwrap();

        Ok(())
    }
}
676}