epp_client/
client.rs

1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6#[cfg(feature = "tokio-rustls")]
7use tokio::net::lookup_host;
8use tokio::net::TcpStream;
9#[cfg(feature = "tokio-rustls")]
10use tokio_rustls::client::TlsStream;
11#[cfg(feature = "tokio-rustls")]
12use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
13#[cfg(feature = "tokio-rustls")]
14use tokio_rustls::TlsConnector;
15use tracing::{debug, error, info};
16
17use crate::common::{Certificate, NoExtension, PrivateKey};
18pub use crate::connection::Connector;
19use crate::connection::{self, EppConnection};
20use crate::error::Error;
21use crate::hello::{Greeting, GreetingDocument, HelloDocument};
22use crate::request::{Command, CommandDocument, Extension, Transaction};
23use crate::response::{Response, ResponseDocument, ResponseStatus};
24use crate::xml;
25
26/// An `EppClient` provides an interface to sending EPP requests to a registry
27///
28/// Once initialized, the EppClient instance can serialize EPP requests to XML and send them
29/// to the registry and deserialize the XML responses from the registry to local types.
30///
31/// # Examples
32///
33/// ```no_run
34/// # use std::collections::HashMap;
35/// # use std::net::ToSocketAddrs;
36/// # use std::time::Duration;
37/// #
38/// use epp_client::EppClient;
39/// use epp_client::domain::DomainCheck;
40/// use epp_client::common::NoExtension;
41///
42/// # #[tokio::main]
43/// # async fn main() {
44/// // Create an instance of EppClient
45/// let timeout = Duration::from_secs(5);
46/// let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await {
47///     Ok(client) => client,
48///     Err(e) => panic!("Failed to create EppClient: {}",  e)
49/// };
50///
51/// // Make a EPP Hello call to the registry
52/// let greeting = client.hello().await.unwrap();
53/// println!("{:?}", greeting);
54///
55/// // Execute an EPP Command against the registry with distinct request and response objects
56/// let domain_check = DomainCheck { domains: &["eppdev.com", "eppdev.net"] };
57/// let response = client.transact(&domain_check, "transaction-id").await.unwrap();
58/// response.res_data.unwrap().list
59///     .iter()
60///     .for_each(|chk| println!("Domain: {}, Available: {}", chk.id, chk.available));
61/// # }
62/// ```
63///
64/// The output would look like this:
65///
66/// ```text
67/// Domain: eppdev.com, Available: 1
68/// Domain: eppdev.net, Available: 1
69/// ```
70pub struct EppClient<C: Connector> {
71    connection: EppConnection<C>,
72}
73
#[cfg(feature = "tokio-rustls")]
impl EppClient<RustlsConnector> {
    /// Connects to `server` over TLS, using rustls as the TLS implementation.
    ///
    /// - `registry` is used as a name in internal logging.
    /// - `server` is the `(host name, port)` pair to connect to. The host name is also sent as
    ///   the TLS server name indication (SNI) and used to verify the server's certificate.
    /// - `identity` optionally provides a certificate chain and private key for TLS client
    ///   authentication.
    /// - `timeout` is passed on to the connection and limits the time spent on its underlying
    ///   network operations.
    ///
    /// Alternatively, use [`EppClient::new()`] with any other [`Connector`] implementation.
    ///
    /// # Errors
    ///
    /// Returns an error if the TLS configuration cannot be built, for example because of an
    /// invalid host name or a client key rustls rejects. Also returns an error if connecting
    /// to the registry fails.
    pub async fn connect(
        registry: String,
        server: (String, u16),
        identity: Option<(Vec<Certificate>, PrivateKey)>,
        timeout: Duration,
    ) -> Result<Self, Error> {
        let connector = RustlsConnector::new(server, identity).await?;
        Self::new(connector, registry, timeout).await
    }
}
95
96impl<C: Connector> EppClient<C> {
97    /// Create an `EppClient` from an already established connection
98    pub async fn new(connector: C, registry: String, timeout: Duration) -> Result<Self, Error> {
99        Ok(Self {
100            connection: EppConnection::new(connector, registry, timeout).await?,
101        })
102    }
103
104    /// Executes an EPP Hello call and returns the response as a `Greeting`
105    pub async fn hello(&mut self) -> Result<Greeting, Error> {
106        let xml = xml::serialize(&HelloDocument::default())?;
107
108        debug!("{}: hello: {}", self.connection.registry, &xml);
109        let response = self.connection.transact(&xml)?.await?;
110        debug!("{}: greeting: {}", self.connection.registry, &response);
111
112        Ok(xml::deserialize::<GreetingDocument>(&response)?.data)
113    }
114
115    pub async fn transact<'c, 'e, Cmd, Ext>(
116        &mut self,
117        data: impl Into<RequestData<'c, 'e, Cmd, Ext>>,
118        id: &str,
119    ) -> Result<Response<Cmd::Response, Ext::Response>, Error>
120    where
121        Cmd: Transaction<Ext> + Command + 'c,
122        Ext: Extension + 'e,
123    {
124        let data = data.into();
125        let document = CommandDocument::new(data.command, data.extension, id);
126        let xml = xml::serialize(&document)?;
127
128        debug!("{}: request: {}", self.connection.registry, &xml);
129        let response = self.connection.transact(&xml)?.await?;
130        debug!("{}: response: {}", self.connection.registry, &response);
131
132        let rsp =
133            match xml::deserialize::<ResponseDocument<Cmd::Response, Ext::Response>>(&response) {
134                Ok(rsp) => rsp,
135                Err(e) => {
136                    error!(%response, "failed to deserialize response for transaction: {e}");
137                    return Err(e);
138                }
139            };
140
141        if rsp.data.result.code.is_success() {
142            return Ok(rsp.data);
143        }
144
145        let err = crate::error::Error::Command(Box::new(ResponseStatus {
146            result: rsp.data.result,
147            tr_ids: rsp.data.tr_ids,
148        }));
149
150        error!(%response, "Failed to deserialize response for transaction: {}", err);
151        Err(err)
152    }
153
154    /// Accepts raw EPP XML and returns the raw EPP XML response to it.
155    /// Not recommended for direct use but sometimes can be useful for debugging
156    pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
157        self.connection.transact(xml)?.await
158    }
159
160    /// Returns the greeting received on establishment of the connection in raw xml form
161    pub fn xml_greeting(&self) -> String {
162        String::from(&self.connection.greeting)
163    }
164
165    /// Returns the greeting received on establishment of the connection as an `Greeting`
166    pub fn greeting(&self) -> Result<Greeting, Error> {
167        xml::deserialize::<GreetingDocument>(&self.connection.greeting).map(|obj| obj.data)
168    }
169
170    pub async fn reconnect(&mut self) -> Result<(), Error> {
171        self.connection.reconnect().await
172    }
173
174    pub async fn shutdown(mut self) -> Result<(), Error> {
175        self.connection.shutdown().await
176    }
177}
178
/// A command together with an optional extension, ready to be sent by [`EppClient::transact`].
///
/// Usually created implicitly through its `From` impls:
/// - a bare `&command` produces a request without an extension;
/// - a `(&command, &extension)` tuple attaches the extension.
#[derive(Debug)]
pub struct RequestData<'c, 'e, C, E> {
    /// The EPP command to send.
    pub(crate) command: &'c C,
    /// The extension to send alongside the command, if any.
    pub(crate) extension: Option<&'e E>,
}
185impl<'c, C: Command> From<&'c C> for RequestData<'c, 'static, C, NoExtension> {
186    fn from(command: &'c C) -> Self {
187        Self {
188            command,
189            extension: None,
190        }
191    }
192}
193
194impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c, 'e, C, E> {
195    fn from((command, extension): (&'c C, &'e E)) -> Self {
196        Self {
197            command,
198            extension: Some(extension),
199        }
200    }
201}
202
203// Manual impl because this does not depend on whether `C` and `E` are `Clone`
204impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> {
205    fn clone(&self) -> Self {
206        Self {
207            command: self.command,
208            extension: self.extension,
209        }
210    }
211}
212
213// Manual impl because this does not depend on whether `C` and `E` are `Copy`
214impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
215
/// A [`Connector`] that opens a TCP connection to the registry and secures it with TLS via
/// rustls.
#[cfg(feature = "tokio-rustls")]
pub struct RustlsConnector {
    /// rustls connector built from the root store and the optional client identity.
    inner: TlsConnector,
    /// TLS server name derived from the host name in `server`, used for SNI and verification.
    domain: ServerName,
    /// `(host name, port)` of the registry server.
    server: (String, u16),
}
222
// Gated like the struct, its `Connector` impl and the rustls imports it depends on; without
// this, builds that disable the `tokio-rustls` feature fail on unresolved names.
#[cfg(feature = "tokio-rustls")]
impl RustlsConnector {
    /// Builds a TLS connector for `server`.
    ///
    /// `server` is the `(host name, port)` pair to connect to. The host name is also used as
    /// the TLS server name (SNI) and for certificate verification. Server certificates are
    /// validated against the roots in `webpki_roots::TLS_SERVER_ROOTS`. When `identity` is
    /// given, its certificate chain and private key are presented for TLS client
    /// authentication.
    ///
    /// No network I/O is performed here; the body never awaits. The actual connection is
    /// made by [`Connector::connect`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] if rustls rejects the client certificate/key pair. Returns an
    /// error wrapping an `io::Error` of kind [`io::ErrorKind::InvalidInput`] if the host name
    /// is not a valid TLS server name.
    pub async fn new(
        server: (String, u16),
        identity: Option<(Vec<Certificate>, PrivateKey)>,
    ) -> Result<Self, Error> {
        // Trust the root certificates shipped in `webpki_roots::TLS_SERVER_ROOTS`.
        let mut roots = RootCertStore::empty();
        roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
            OwnedTrustAnchor::from_subject_spki_name_constraints(
                ta.subject,
                ta.spki,
                ta.name_constraints,
            )
        }));

        let builder = ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots);

        let config = match identity {
            // Present the given certificate chain and key for TLS client authentication.
            Some((certs, key)) => {
                let certs = certs
                    .into_iter()
                    .map(|cert| tokio_rustls::rustls::Certificate(cert.0))
                    .collect();
                builder
                    .with_single_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
                    .map_err(|e| Error::Other(e.into()))?
            }
            None => builder.with_no_client_auth(),
        };

        // The host name doubles as the TLS server name used for SNI and certificate checks.
        let domain = server.0.as_str().try_into().map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Invalid domain: {}", server.0),
            )
        })?;

        Ok(Self {
            inner: TlsConnector::from(Arc::new(config)),
            domain,
            server,
        })
    }
}
268
#[cfg(feature = "tokio-rustls")]
#[async_trait]
impl Connector for RustlsConnector {
    type Connection = TlsStream<TcpStream>;

    /// Resolves the configured server, opens a TCP connection and performs the TLS handshake.
    ///
    /// Only the first address returned by DNS resolution is tried. `timeout` bounds the TCP
    /// connect and the TLS handshake together; DNS resolution is not covered by it.
    ///
    /// # Errors
    ///
    /// Returns an I/O error of kind `InvalidInput` if the host resolves to no address.
    /// Otherwise returns whatever resolution, TCP connect, the handshake, or the timeout
    /// reports.
    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
        info!("Connecting to server: {}:{}", self.server.0, self.server.1);
        let addr = lookup_host(&self.server)
            .await?
            .next()
            .ok_or_else(|| {
                Error::Io(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Invalid host: {}", &self.server.0),
                ))
            })?;

        // Bound the TCP connect as well as the TLS handshake. Otherwise an unreachable host
        // can block on the OS-level connect timeout, ignoring `timeout` entirely.
        let establish = async {
            let stream = TcpStream::connect(addr).await?;
            self.inner.connect(self.domain.clone(), stream).await
        };
        connection::timeout(timeout, establish).await
    }
}