use std::net::SocketAddr;
use openssl::ssl::SslConnector;
use tokio::net::TcpStream;
use tonic::transport::{Endpoint, Uri};
use tower::Service;
pub fn new_endpoint() -> Endpoint {
tonic::transport::Endpoint::from_static("http://[::]:50051")
}
/// Builds a `tower` service that opens a TLS-over-TCP connection to `uri`.
///
/// The `Uri` handed to the service on each call is ignored; every call connects to the
/// `uri` captured here. Each call:
/// 1. resolves `uri` to socket addresses (see `dns_resolve`);
/// 2. picks the TLS name — `domain` if given, otherwise the host part of `uri`;
/// 3. connects over TCP (see `connect_tcp`) and performs the TLS handshake.
///
/// The resulting stream is wrapped in [`hyper_util::rt::TokioIo`].
///
/// # Errors
///
/// Resolution, TCP, OpenSSL configuration and handshake errors are converted into
/// `crate::Error`. If `domain` is `None` and `uri` has no host, an `InvalidInput` I/O error
/// is returned; this replaces the former panic.
pub fn connector(
    uri: Uri,
    ssl_conn: SslConnector,
    domain: Option<String>,
) -> impl Service<
    Uri,
    Response = impl hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
    Future = impl Send + 'static,
    Error = crate::Error,
> {
    tower::service_fn(move |_: Uri| {
        // Each future must own its data, so clone the captured values per call.
        let domain = domain.clone();
        let uri = uri.clone();
        let ssl_conn = ssl_conn.clone();
        async move {
            let addrs = dns_resolve(&uri).await?;
            let tls_name: &str = match domain.as_deref() {
                Some(d) => d,
                None => uri
                    .host()
                    // `Uri::host` keeps the brackets around IPv6 literals ("[::1]"). OpenSSL
                    // needs the bare address to recognise it as an IP, not a DNS name.
                    .map(|h| h.trim_start_matches('[').trim_end_matches(']'))
                    .ok_or_else(|| {
                        std::io::Error::new(std::io::ErrorKind::InvalidInput, "uri has no host")
                    })?,
            };
            let ssl = ssl_conn.configure()?.into_ssl(tls_name)?;
            let io = connect_tcp(addrs).await?;
            let mut stream = tokio_openssl::SslStream::new(ssl, io)?;
            std::pin::Pin::new(&mut stream).connect().await?;
            Ok::<_, crate::Error>(hyper_util::rt::TokioIo::new(stream))
        }
    })
}
/// Resolves the host and port of `uri` to one or more socket addresses.
///
/// IP-literal hosts (including bracketed IPv6 literals) are returned directly, without a
/// lookup. Any other host is resolved with [`tokio::net::lookup_host`]. When the URI has no
/// explicit port, the scheme default is used: 443 for `https`, 80 for `http`. Any userinfo
/// (`user@`) in the authority is ignored.
///
/// # Errors
///
/// Returns an `InvalidInput` error in two cases:
/// - the URI has no host;
/// - the URI has no port and its scheme is neither `http` nor `https`.
///
/// Otherwise it propagates errors from the lookup.
async fn dns_resolve(uri: &Uri) -> std::io::Result<Vec<SocketAddr>> {
    let invalid =
        |msg: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidInput, msg);

    let host = uri.host().ok_or_else(|| invalid("uri has no host"))?;
    let port = match uri.port_u16() {
        Some(p) => p,
        None => match uri.scheme_str() {
            Some("https") => 443,
            Some("http") => 80,
            _ => return Err(invalid("uri has no port and no known default for its scheme")),
        },
    };

    // `Uri::host` keeps the brackets around IPv6 literals; strip them so the address parses
    // as an IP, and so a lookup never receives a bracketed name.
    let host = host.trim_start_matches('[').trim_end_matches(']');
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return Ok(vec![SocketAddr::new(ip, port)]);
    }
    Ok(tokio::net::lookup_host((host, port)).await?.collect())
}
/// Tries each address in `addrs` in order and returns the first TCP stream that connects.
///
/// # Errors
///
/// Returns the error from the last failed attempt. If `addrs` is empty, returns an
/// `AddrNotAvailable` error.
async fn connect_tcp(addrs: Vec<SocketAddr>) -> std::io::Result<TcpStream> {
    let mut last_failure: Option<std::io::Error> = None;
    for target in addrs.into_iter() {
        let attempt = TcpStream::connect(target).await;
        if let Ok(stream) = attempt {
            return Ok(stream);
        }
        last_failure = attempt.err();
    }
    Err(last_failure
        .unwrap_or_else(|| std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)))
}