1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
use std::io; use std::net::SocketAddr; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; use std::time::Duration; use crate::{other, Stream}; use bytes::Bytes; use h2::client::{self, SendRequest}; use http::Request; use tokio::net::TcpStream; use tokio::time::sleep; use tokio_rustls::{rustls::ClientConfig, webpki::DNSNameRef, TlsConnector}; pub struct Connection { tls_config: Arc<ClientConfig>, addr: SocketAddr, domain_name: String, send_request: Option<SendRequest<Bytes>>, available: Arc<AtomicBool>, sleeps: usize, } impl Connection { pub async fn new( tls_config: ClientConfig, addr: SocketAddr, domain_name: String, ) -> io::Result<Connection> { let mut conn = Connection { tls_config: Arc::new(tls_config), addr, domain_name, send_request: None, available: Arc::new(AtomicBool::new(false)), sleeps: 0, }; conn.connect().await?; Ok(conn) } pub async fn new_stream(&mut self) -> io::Result<Stream> { if !self.available.load(Ordering::Relaxed) { match self.reconnect().await { Ok(()) => self.sleeps = 0, Err(e) => { self.sleeps += 1; return Err(e); } } } if let Some(send_request) = self.send_request.as_mut() { let (response, send_stream) = send_request .send_request(Request::new(()), false) .map_err(|e| { self.available.store(false, Ordering::Relaxed); log::error!("send stream error {:?}", e); other(&e.to_string()) })?; let recv_stream = response .await .map_err(|e| { self.available.store(false, Ordering::Relaxed); log::error!("response err {}", e); other(&e.to_string()) })? .into_body(); return Ok(Stream::new(send_stream, recv_stream)); } panic!("this should not happend"); } async fn connect(&mut self) -> io::Result<()> { self.reconnect().await } async fn reconnect(&mut self) -> io::Result<()> { if self.sleeps > 0 { let delay_ms = [50, 75, 100, 250, 500, 750, 1000] .get(self.sleeps as usize) .unwrap_or(&1000); sleep(Duration::from_millis(*delay_ms)).await; } self.available.store(false, Ordering::Relaxed); let tls_connector = TlsConnector::from(self.tls_config.clone()); let domain = DNSNameRef::try_from_ascii_str(&self.domain_name).map_err(|e| { log::error!("domain err {:?}", e); io::Error::new(io::ErrorKind::InvalidInput, "invalid domain name") })?; let stream = TcpStream::connect(self.addr).await?; stream.set_nodelay(true)?; let tls_stream = tls_connector.connect(domain, stream).await?; let (h2, connection) = client::handshake(tls_stream).await.map_err(|e| { log::error!("handshake err {:?}", e); other(&e.to_string()) })?; let available = self.available.clone(); tokio::spawn(async move { if let Err(e) = connection.await { log::error!("h2 underlay connection err {:?}", e); available.store(false, Ordering::Relaxed); } }); let h2 = h2.ready().await.map_err(|e| { log::error!("h2 ready err {:?}", e); other(&e.to_string()) })?; self.send_request = Some(h2); self.available.store(true, Ordering::Relaxed); Ok(()) } }