instant_epp/
client.rs

1use std::time::Duration;
2
3#[cfg(feature = "rustls")]
4use rustls_pki_types::{CertificateDer, PrivateKeyDer};
5use tracing::{debug, error};
6
7use crate::common::NoExtension;
8pub use crate::connection::Connector;
9use crate::connection::EppConnection;
10use crate::error::Error;
11use crate::hello::{Greeting, Hello};
12use crate::request::{Command, CommandWrapper, Extension, Transaction};
13use crate::response::{Response, ResponseStatus};
14use crate::xml;
15
16/// An `EppClient` provides an interface to sending EPP requests to a registry
17///
18/// Once initialized, the EppClient instance can serialize EPP requests to XML and send them
19/// to the registry and deserialize the XML responses from the registry to local types.
20///
21/// # Examples
22///
23/// ```no_run
24/// # use std::collections::HashMap;
25/// # use std::net::ToSocketAddrs;
26/// # use std::time::Duration;
27/// #
28/// use instant_epp::EppClient;
29/// use instant_epp::domain::DomainCheck;
30/// use instant_epp::common::NoExtension;
31///
32/// # #[cfg(feature = "rustls")]
33/// # #[tokio::main]
34/// # async fn main() {
35/// // Create an instance of EppClient
36/// let timeout = Duration::from_secs(5);
37/// let mut client = match EppClient::connect("registry_name".to_string(), ("example.com".to_owned(), 7000), None, timeout).await {
38///     Ok(client) => client,
39///     Err(e) => panic!("Failed to create EppClient: {}",  e)
40/// };
41///
42/// // Make a EPP Hello call to the registry
43/// let greeting = client.hello().await.unwrap();
44/// println!("{:?}", greeting);
45///
46/// // Execute an EPP Command against the registry with distinct request and response objects
47/// let domain_check = DomainCheck { domains: &["eppdev.com", "eppdev.net"] };
48/// let response = client.transact(&domain_check, "transaction-id").await.unwrap();
49/// response
50///     .res_data()
51///     .unwrap()
52///     .list
53///     .iter()
54///     .for_each(|chk| println!("Domain: {}, Available: {}", chk.inner.id, chk.inner.available));
55/// # }
56/// #
57/// # #[cfg(not(feature = "rustls"))]
58/// # fn main() {}
59/// ```
60///
61/// The output would look like this:
62///
63/// ```text
64/// Domain: eppdev.com, Available: 1
65/// Domain: eppdev.net, Available: 1
66/// ```
67pub struct EppClient<C: Connector> {
68    connection: EppConnection<C>,
69}
70
#[cfg(feature = "rustls")]
impl EppClient<RustlsConnector> {
    /// Connect to the EPP server given by `server` over TLS, using rustls as the TLS
    /// implementation
    ///
    /// - `registry` is used as a name in internal logging.
    /// - `server` provides the host name and port to connect to; the host name is also
    ///   used as the TLS server name.
    /// - `identity` optionally provides a client certificate chain and private key for
    ///   TLS client authentication.
    /// - `timeout` limits the time spent on underlying network operations.
    ///
    /// Alternatively, use [`EppClient::new`] with any other [`Connector`] implementation.
    ///
    /// # Errors
    ///
    /// Returns an error if the `RustlsConnector` cannot be built or if setting up the
    /// connection via [`EppClient::new`] fails.
    pub async fn connect(
        registry: String,
        server: (String, u16),
        identity: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
        timeout: Duration,
    ) -> Result<Self, Error> {
        let connector = RustlsConnector::new(server, identity).await?;
        Self::new(connector, registry, timeout).await
    }
}
92
93impl<C: Connector> EppClient<C> {
94    /// Create an `EppClient` from an already established connection
95    pub async fn new(connector: C, registry: String, timeout: Duration) -> Result<Self, Error> {
96        Ok(Self {
97            connection: EppConnection::new(connector, registry, timeout).await?,
98        })
99    }
100
101    /// Executes an EPP Hello call and returns the response as a `Greeting`
102    pub async fn hello(&mut self) -> Result<Greeting, Error> {
103        let xml = xml::serialize(Hello)?;
104
105        debug!("{}: hello: {}", self.connection.registry, &xml);
106        let response = self.connection.transact(&xml)?.await?;
107        debug!("{}: greeting: {}", self.connection.registry, &response);
108
109        xml::deserialize::<Greeting>(&response)
110    }
111
112    pub async fn transact<'c, 'e, Cmd, Ext>(
113        &mut self,
114        data: impl Into<RequestData<'c, 'e, Cmd, Ext>>,
115        id: &str,
116    ) -> Result<Response<Cmd::Response, Ext::Response>, Error>
117    where
118        Cmd: Transaction<Ext> + Command + 'c,
119        Ext: Extension + 'e,
120    {
121        let data = data.into();
122        let document = CommandWrapper::new(data.command, data.extension, id);
123        let xml = xml::serialize(&document)?;
124
125        debug!("{}: request: {}", self.connection.registry, &xml);
126        let response = self.connection.transact(&xml)?.await?;
127        debug!("{}: response: {}", self.connection.registry, &response);
128
129        let rsp = match xml::deserialize::<Response<Cmd::Response, Ext::Response>>(&response) {
130            Ok(rsp) => rsp,
131            Err(e) => {
132                error!(%response, "failed to deserialize response for transaction: {e}");
133                return Err(e);
134            }
135        };
136
137        if rsp.result.code.is_success() {
138            return Ok(rsp);
139        }
140
141        let err = crate::error::Error::Command(Box::new(ResponseStatus {
142            result: rsp.result,
143            tr_ids: rsp.tr_ids,
144        }));
145
146        Err(err)
147    }
148
149    /// Accepts raw EPP XML and returns the raw EPP XML response to it.
150    /// Not recommended for direct use but sometimes can be useful for debugging
151    pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
152        self.connection.transact(xml)?.await
153    }
154
155    /// Returns the greeting received on establishment of the connection in raw xml form
156    pub fn xml_greeting(&self) -> String {
157        String::from(&self.connection.greeting)
158    }
159
160    /// Returns the greeting received on establishment of the connection as an `Greeting`
161    pub fn greeting(&self) -> Result<Greeting, Error> {
162        xml::deserialize::<Greeting>(&self.connection.greeting)
163    }
164
165    pub async fn reconnect(&mut self) -> Result<(), Error> {
166        self.connection.reconnect().await
167    }
168
169    pub async fn shutdown(mut self) -> Result<(), Error> {
170        self.connection.shutdown().await
171    }
172}
173
/// The data sent in a single EPP transaction: a command and an optional extension
///
/// Normally constructed through `From`, either from `&command` (no extension) or from
/// `(&command, &extension)`.
#[derive(Debug)]
pub struct RequestData<'c, 'e, C, E> {
    /// The command to execute.
    pub(crate) command: &'c C,
    /// Extension sent alongside the command; `None` when the request has no extension.
    pub(crate) extension: Option<&'e E>,
}
179
180impl<'c, C: Command> From<&'c C> for RequestData<'c, 'static, C, NoExtension> {
181    fn from(command: &'c C) -> Self {
182        Self {
183            command,
184            extension: None,
185        }
186    }
187}
188
189impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c, 'e, C, E> {
190    fn from((command, extension): (&'c C, &'e E)) -> Self {
191        Self {
192            command,
193            extension: Some(extension),
194        }
195    }
196}
197
198// Manual impl because this does not depend on whether `C` and `E` are `Clone`
199impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> {
200    fn clone(&self) -> Self {
201        *self
202    }
203}
204
205// Manual impl because this does not depend on whether `C` and `E` are `Copy`
206impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
207
208#[cfg(feature = "rustls")]
209pub use rustls_connector::RustlsConnector;
210
#[cfg(feature = "rustls")]
mod rustls_connector {
    use std::io;
    use std::sync::Arc;
    use std::time::Duration;

    use async_trait::async_trait;
    use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
    use tokio::net::lookup_host;
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;
    use tokio_rustls::rustls::{ClientConfig, RootCertStore};
    use tokio_rustls::TlsConnector;
    use tracing::{info, warn};

    use crate::connection::{self, Connector};
    use crate::error::Error;

    /// A [`Connector`] that opens TCP connections and secures them with rustls, trusting
    /// the platform's native root certificates.
    pub struct RustlsConnector {
        /// rustls client configuration wrapped for use with tokio.
        inner: TlsConnector,
        /// Server name used for the TLS handshake, derived from the host in `server`.
        domain: ServerName<'static>,
        /// Host name and port to connect to.
        server: (String, u16),
    }

    impl RustlsConnector {
        /// Builds a connector for `server` (host name and port).
        ///
        /// Root certificates are loaded from the platform's native store; certificates
        /// that rustls cannot parse are skipped (with a warning) rather than failing the
        /// whole connector. `identity`, if given, is a client certificate chain and
        /// private key used for TLS client authentication.
        ///
        /// # Errors
        ///
        /// Returns an error if the native certificate store cannot be loaded, if rustls
        /// rejects the client identity, or if the host name in `server` is not a valid
        /// TLS server name.
        pub async fn new(
            server: (String, u16),
            identity: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
        ) -> Result<Self, Error> {
            let mut roots = RootCertStore::empty();
            // Native stores routinely contain a few ancient or malformed certificates;
            // one bad entry should not make every connection impossible.
            let (added, ignored) =
                roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
            if ignored > 0 {
                warn!("ignored {ignored} unparsable native root certificates ({added} added)");
            }

            let builder = ClientConfig::builder().with_root_certificates(roots);

            let config = match identity {
                Some((certs, key)) => builder
                    .with_client_auth_cert(certs, key)
                    .map_err(|e| Error::Other(e.into()))?,
                None => builder.with_no_client_auth(),
            };

            let domain = ServerName::try_from(server.0.as_str())
                .map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("invalid domain: {}", server.0),
                    )
                })?
                .to_owned();

            Ok(Self {
                inner: TlsConnector::from(Arc::new(config)),
                domain,
                server,
            })
        }
    }

    #[async_trait]
    impl Connector for RustlsConnector {
        type Connection = TlsStream<TcpStream>;

        /// Resolves the configured host, opens a TCP connection to the first resolved
        /// address and performs the TLS handshake.
        ///
        /// Each network step (DNS resolution, TCP connect, TLS handshake) is bounded by
        /// `timeout` via `connection::timeout`.
        async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
            info!("Connecting to server: {}:{}", self.server.0, self.server.1);

            // DNS resolution and the TCP connect can stall just like the handshake, so
            // they are bounded as well. The address iterator is a temporary and is
            // dropped before the next await point.
            let addr = connection::timeout(timeout, lookup_host(&self.server))
                .await?
                .next()
                .ok_or_else(|| {
                    Error::Io(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("Invalid host: {}", &self.server.0),
                    ))
                })?;

            let stream = connection::timeout(timeout, TcpStream::connect(addr)).await?;
            let handshake = self.inner.connect(self.domain.clone(), stream);
            connection::timeout(timeout, handshake).await
        }
    }
}