1use std::time::Duration;
2
3#[cfg(feature = "rustls")]
4use rustls_pki_types::{CertificateDer, PrivateKeyDer};
5use tracing::{debug, error};
6
7use crate::common::NoExtension;
8pub use crate::connection::Connector;
9use crate::connection::EppConnection;
10use crate::error::Error;
11use crate::hello::{Greeting, Hello};
12use crate::request::{Command, CommandWrapper, Extension, Transaction};
13use crate::response::{Response, ResponseStatus};
14use crate::xml;
15
/// An EPP client that sends commands to a single registry over one connection.
///
/// Generic over the [`Connector`] that opens the underlying transport.
pub struct EppClient<C: Connector> {
    /// The connection every request issued by this client is sent through.
    connection: EppConnection<C>,
}
70
#[cfg(feature = "rustls")]
impl EppClient<RustlsConnector> {
    /// Connects to `server` (host, port) over TLS using rustls and returns a ready client.
    ///
    /// `identity` is an optional client certificate chain plus private key. When it is
    /// `None`, no client certificate is presented. `registry` and `timeout` are forwarded
    /// unchanged to [`EppClient::new`].
    ///
    /// # Errors
    ///
    /// Fails if the TLS connector cannot be configured (see [`RustlsConnector::new`]) or
    /// if [`EppClient::new`] fails.
    pub async fn connect(
        registry: String,
        server: (String, u16),
        identity: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
        timeout: Duration,
    ) -> Result<Self, Error> {
        Self::new(RustlsConnector::new(server, identity).await?, registry, timeout).await
    }
}
92
93impl<C: Connector> EppClient<C> {
94 pub async fn new(connector: C, registry: String, timeout: Duration) -> Result<Self, Error> {
96 Ok(Self {
97 connection: EppConnection::new(connector, registry, timeout).await?,
98 })
99 }
100
101 pub async fn hello(&mut self) -> Result<Greeting, Error> {
103 let xml = xml::serialize(Hello)?;
104
105 debug!("{}: hello: {}", self.connection.registry, &xml);
106 let response = self.connection.transact(&xml)?.await?;
107 debug!("{}: greeting: {}", self.connection.registry, &response);
108
109 xml::deserialize::<Greeting>(&response)
110 }
111
112 pub async fn transact<'c, 'e, Cmd, Ext>(
113 &mut self,
114 data: impl Into<RequestData<'c, 'e, Cmd, Ext>>,
115 id: &str,
116 ) -> Result<Response<Cmd::Response, Ext::Response>, Error>
117 where
118 Cmd: Transaction<Ext> + Command + 'c,
119 Ext: Extension + 'e,
120 {
121 let data = data.into();
122 let document = CommandWrapper::new(data.command, data.extension, id);
123 let xml = xml::serialize(&document)?;
124
125 debug!("{}: request: {}", self.connection.registry, &xml);
126 let response = self.connection.transact(&xml)?.await?;
127 debug!("{}: response: {}", self.connection.registry, &response);
128
129 let rsp = match xml::deserialize::<Response<Cmd::Response, Ext::Response>>(&response) {
130 Ok(rsp) => rsp,
131 Err(e) => {
132 error!(%response, "failed to deserialize response for transaction: {e}");
133 return Err(e);
134 }
135 };
136
137 if rsp.result.code.is_success() {
138 return Ok(rsp);
139 }
140
141 let err = crate::error::Error::Command(Box::new(ResponseStatus {
142 result: rsp.result,
143 tr_ids: rsp.tr_ids,
144 }));
145
146 Err(err)
147 }
148
149 pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
152 self.connection.transact(xml)?.await
153 }
154
155 pub fn xml_greeting(&self) -> String {
157 String::from(&self.connection.greeting)
158 }
159
160 pub fn greeting(&self) -> Result<Greeting, Error> {
162 xml::deserialize::<Greeting>(&self.connection.greeting)
163 }
164
165 pub async fn reconnect(&mut self) -> Result<(), Error> {
166 self.connection.reconnect().await
167 }
168
169 pub async fn shutdown(mut self) -> Result<(), Error> {
170 self.connection.shutdown().await
171 }
172}
173
/// A borrowed command plus an optional borrowed extension, as accepted by
/// [`EppClient::transact`].
///
/// Usually built implicitly through `From`: pass `&command` for a bare command, or
/// `(&command, &extension)` to attach an extension.
#[derive(Debug)]
pub struct RequestData<'c, 'e, C, E> {
    /// The command to send.
    pub(crate) command: &'c C,
    /// The extension to send alongside the command, if any.
    pub(crate) extension: Option<&'e E>,
}
179
180impl<'c, C: Command> From<&'c C> for RequestData<'c, 'static, C, NoExtension> {
181 fn from(command: &'c C) -> Self {
182 Self {
183 command,
184 extension: None,
185 }
186 }
187}
188
189impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c, 'e, C, E> {
190 fn from((command, extension): (&'c C, &'e E)) -> Self {
191 Self {
192 command,
193 extension: Some(extension),
194 }
195 }
196}
197
198impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> {
200 fn clone(&self) -> Self {
201 *self
202 }
203}
204
205impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
207
208#[cfg(feature = "rustls")]
209pub use rustls_connector::RustlsConnector;
210
#[cfg(feature = "rustls")]
mod rustls_connector {
    use std::io;
    use std::sync::Arc;
    use std::time::Duration;

    use async_trait::async_trait;
    use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
    use tokio::net::lookup_host;
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;
    use tokio_rustls::rustls::{ClientConfig, RootCertStore};
    use tokio_rustls::TlsConnector;
    use tracing::{info, warn};

    use crate::connection::{self, Connector};
    use crate::error::Error;

    /// [`Connector`] that opens TLS-over-TCP connections with rustls.
    ///
    /// Server certificates are verified against the platform's native root store.
    pub struct RustlsConnector {
        /// Shared TLS client configuration, reused for every (re)connect.
        inner: TlsConnector,
        /// Server name handed to the TLS handshake.
        domain: ServerName<'static>,
        /// Host and port that are resolved and dialled on each connect.
        server: (String, u16),
    }

    impl RustlsConnector {
        /// Builds a connector for `server` (host, port).
        ///
        /// `identity` is an optional client certificate chain plus private key. When it
        /// is `None`, no client certificate is sent.
        ///
        /// # Errors
        ///
        /// Fails if any of the following happens:
        /// - the native certificate store cannot be loaded;
        /// - the store is non-empty but none of its certificates parse;
        /// - rustls rejects the client identity;
        /// - `server.0` is not a valid server name.
        pub async fn new(
            server: (String, u16),
            identity: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
        ) -> Result<Self, Error> {
            let builder = ClientConfig::builder().with_root_certificates(native_roots()?);

            let config = match identity {
                Some((certs, key)) => builder
                    .with_client_auth_cert(certs, key)
                    .map_err(|e| Error::Other(e.into()))?,
                None => builder.with_no_client_auth(),
            };

            let domain = ServerName::try_from(server.0.as_str())
                .map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("invalid domain: {}", server.0),
                    )
                })?
                .to_owned();

            Ok(Self {
                inner: TlsConnector::from(Arc::new(config)),
                domain,
                server,
            })
        }

        /// Resolves `self.server` and connects to the first resolved address.
        ///
        /// It then runs the TLS handshake on that TCP stream. No timeout is applied here;
        /// the caller is responsible for bounding the whole operation.
        async fn open(&self) -> io::Result<TlsStream<TcpStream>> {
            let addr = lookup_host(&self.server).await?.next().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Invalid host: {}", &self.server.0),
                )
            })?;
            let stream = TcpStream::connect(addr).await?;
            self.inner.connect(self.domain.clone(), stream).await
        }
    }

    /// Loads the platform's trusted root certificates into a rustls root store.
    ///
    /// OS trust stores can contain a few certificates rustls cannot parse. Those are
    /// skipped with a warning, so one bad entry does not make every connection fail.
    /// An error is returned only when the store is non-empty and nothing in it parsed.
    fn native_roots() -> Result<RootCertStore, Error> {
        let mut roots = RootCertStore::empty();
        let (valid, invalid) =
            roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
        if invalid > 0 {
            warn!("skipped {invalid} unparsable native root certificate(s)");
        }
        if valid == 0 && invalid > 0 {
            return Err(Error::Io(io::Error::new(
                io::ErrorKind::InvalidData,
                "no usable native root certificates found",
            )));
        }
        Ok(roots)
    }

    #[async_trait]
    impl Connector for RustlsConnector {
        type Connection = TlsStream<TcpStream>;

        /// Resolves the server, opens a TCP connection and performs the TLS handshake.
        ///
        /// `timeout` bounds the whole sequence: DNS lookup, TCP connect and handshake.
        /// Previously only the handshake was covered, so an unreachable host could block
        /// for the OS-level TCP connect timeout regardless of `timeout`.
        async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
            info!("Connecting to server: {}:{}", self.server.0, self.server.1);
            connection::timeout(timeout, self.open()).await
        }
    }
}