use std::time::Duration;
use tracing::{debug, error};
use crate::common::NoExtension;
#[cfg(feature = "rustls")]
use crate::common::{Certificate, PrivateKey};
pub use crate::connection::Connector;
use crate::connection::EppConnection;
use crate::error::Error;
use crate::hello::{Greeting, Hello};
use crate::request::{Command, CommandWrapper, Extension, Transaction};
use crate::response::{Response, ResponseStatus};
use crate::xml;
/// An EPP client that sends requests over one [`EppConnection`].
///
/// Generic over the [`Connector`] used to establish the underlying transport.
pub struct EppClient<C: Connector> {
    connection: EppConnection<C>,
}
#[cfg(feature = "rustls")]
impl EppClient<RustlsConnector> {
    /// Opens a TLS connection to `server` (host, port) via [`RustlsConnector`] and wraps it
    /// in a client.
    ///
    /// When `identity` is `Some`, the certificate chain and private key are used for TLS
    /// client authentication. `registry` and `timeout` are forwarded unchanged to
    /// [`EppClient::new`].
    ///
    /// # Errors
    ///
    /// Propagates any error from building the connector or from [`EppClient::new`].
    pub async fn connect(
        registry: String,
        server: (String, u16),
        identity: Option<(Vec<Certificate>, PrivateKey)>,
        timeout: Duration,
    ) -> Result<Self, Error> {
        Self::new(RustlsConnector::new(server, identity).await?, registry, timeout).await
    }
}
impl<C: Connector> EppClient<C> {
pub async fn new(connector: C, registry: String, timeout: Duration) -> Result<Self, Error> {
Ok(Self {
connection: EppConnection::new(connector, registry, timeout).await?,
})
}
pub async fn hello(&mut self) -> Result<Greeting, Error> {
let xml = xml::serialize(Hello)?;
debug!("{}: hello: {}", self.connection.registry, &xml);
let response = self.connection.transact(&xml)?.await?;
debug!("{}: greeting: {}", self.connection.registry, &response);
xml::deserialize::<Greeting>(&response)
}
pub async fn transact<'c, 'e, Cmd, Ext>(
&mut self,
data: impl Into<RequestData<'c, 'e, Cmd, Ext>>,
id: &str,
) -> Result<Response<Cmd::Response, Ext::Response>, Error>
where
Cmd: Transaction<Ext> + Command + 'c,
Ext: Extension + 'e,
{
let data = data.into();
let document = CommandWrapper::new(data.command, data.extension, id);
let xml = xml::serialize(&document)?;
debug!("{}: request: {}", self.connection.registry, &xml);
let response = self.connection.transact(&xml)?.await?;
debug!("{}: response: {}", self.connection.registry, &response);
let rsp = match xml::deserialize::<Response<Cmd::Response, Ext::Response>>(&response) {
Ok(rsp) => rsp,
Err(e) => {
error!(%response, "failed to deserialize response for transaction: {e}");
return Err(e);
}
};
if rsp.result.code.is_success() {
return Ok(rsp);
}
let err = crate::error::Error::Command(Box::new(ResponseStatus {
result: rsp.result,
tr_ids: rsp.tr_ids,
}));
Err(err)
}
pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
self.connection.transact(xml)?.await
}
pub fn xml_greeting(&self) -> String {
String::from(&self.connection.greeting)
}
pub fn greeting(&self) -> Result<Greeting, Error> {
xml::deserialize::<Greeting>(&self.connection.greeting)
}
pub async fn reconnect(&mut self) -> Result<(), Error> {
self.connection.reconnect().await
}
pub async fn shutdown(mut self) -> Result<(), Error> {
self.connection.shutdown().await
}
}
/// A command to send, optionally paired with an extension.
///
/// Normally obtained through `From`, from either `&command` or
/// `(&command, &extension)`, and passed to [`EppClient::transact`].
#[derive(Debug)]
pub struct RequestData<'c, 'e, C, E> {
    /// The command to serialize into the request.
    pub(crate) command: &'c C,
    /// The extension to attach to the command, if any.
    pub(crate) extension: Option<&'e E>,
}
impl<'c, C: Command> From<&'c C> for RequestData<'c, 'static, C, NoExtension> {
fn from(command: &'c C) -> Self {
Self {
command,
extension: None,
}
}
}
impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c, 'e, C, E> {
fn from((command, extension): (&'c C, &'e E)) -> Self {
Self {
command,
extension: Some(extension),
}
}
}
// Implemented by hand rather than derived: `#[derive(Clone)]` would add
// `C: Clone` and `E: Clone` bounds, yet the struct only holds references.
impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> {
    fn clone(&self) -> Self {
        // `RequestData` is `Copy`, so the canonical clone is a plain copy.
        *self
    }
}
// Manual impl so that no `C: Copy` / `E: Copy` bounds are required; both fields
// are shared references (or an `Option` of one), which are always `Copy`.
impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
#[cfg(feature = "rustls")]
pub use rustls_connector::RustlsConnector;
#[cfg(feature = "rustls")]
mod rustls_connector {
    use std::io;
    use std::sync::Arc;
    use std::time::Duration;

    use async_trait::async_trait;
    use tokio::net::lookup_host;
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;
    use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerName};
    use tokio_rustls::TlsConnector;
    use tracing::info;

    use crate::common::{Certificate, PrivateKey};
    use crate::connection::{self, Connector};
    use crate::error::Error;

    /// A [`Connector`] that establishes TLS-over-TCP connections using rustls.
    pub struct RustlsConnector {
        /// TLS connector built from the client configuration.
        inner: TlsConnector,
        /// Server name presented during the TLS handshake.
        domain: ServerName,
        /// Host name and port to connect to.
        server: (String, u16),
    }

    impl RustlsConnector {
        /// Builds a connector for `server` (host, port).
        ///
        /// The platform's native root certificates are trusted. When `identity` is given,
        /// its certificate chain and private key are configured for TLS client
        /// authentication.
        ///
        /// # Errors
        ///
        /// Fails if the native roots cannot be loaded or added, if rustls rejects the
        /// client identity, or if the host name is not a valid server name
        /// (`io::ErrorKind::InvalidInput`).
        pub async fn new(
            server: (String, u16),
            identity: Option<(Vec<Certificate>, PrivateKey)>,
        ) -> Result<Self, Error> {
            let mut roots = RootCertStore::empty();
            // NOTE(review): loading native certs may do blocking filesystem I/O while on
            // the async executor; acceptable for one-time setup, but worth confirming.
            for cert in rustls_native_certs::load_native_certs()? {
                roots
                    .add(&tokio_rustls::rustls::Certificate(cert.0))
                    .map_err(|err| {
                        Box::new(err) as Box<dyn std::error::Error + Send + Sync + 'static>
                    })?;
            }

            let builder = ClientConfig::builder()
                .with_safe_defaults()
                .with_root_certificates(roots);

            let config = match identity {
                Some((certs, key)) => {
                    let certs = certs
                        .into_iter()
                        .map(|cert| tokio_rustls::rustls::Certificate(cert.0))
                        .collect();
                    builder
                        .with_client_auth_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
                        .map_err(|e| Error::Other(e.into()))?
                }
                None => builder.with_no_client_auth(),
            };

            let domain = server.0.as_str().try_into().map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Invalid domain: {}", server.0),
                )
            })?;

            Ok(Self {
                inner: TlsConnector::from(Arc::new(config)),
                domain,
                server,
            })
        }

        /// Resolves the server, opens a TCP connection and performs the TLS handshake.
        ///
        /// Every resolved address is tried in order and the first one that accepts a TCP
        /// connection is used. If none does, the last connection error is returned. If the
        /// host resolves to no addresses at all, an `InvalidInput` "Invalid host" error is
        /// returned.
        async fn establish(&self) -> io::Result<TlsStream<TcpStream>> {
            let mut last_err = None;
            for addr in lookup_host(&self.server).await? {
                match TcpStream::connect(addr).await {
                    Ok(stream) => return self.inner.connect(self.domain.clone(), stream).await,
                    Err(err) => last_err = Some(err),
                }
            }

            Err(last_err.unwrap_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Invalid host: {}", &self.server.0),
                )
            }))
        }
    }

    #[async_trait]
    impl Connector for RustlsConnector {
        type Connection = TlsStream<TcpStream>;

        /// Connects to the configured server, bounded by `timeout`.
        async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
            info!("Connecting to server: {}:{}", self.server.0, self.server.1);
            // The whole sequence (DNS lookup, TCP connect, TLS handshake) is bounded by
            // `timeout`, so an unresponsive host cannot stall past the caller's deadline
            // waiting for the OS-level TCP connect timeout.
            connection::timeout(timeout, self.establish()).await
        }
    }
}