1use std::io;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6#[cfg(feature = "tokio-rustls")]
7use tokio::net::lookup_host;
8use tokio::net::TcpStream;
9#[cfg(feature = "tokio-rustls")]
10use tokio_rustls::client::TlsStream;
11#[cfg(feature = "tokio-rustls")]
12use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName};
13#[cfg(feature = "tokio-rustls")]
14use tokio_rustls::TlsConnector;
15use tracing::{debug, error, info};
16
17use crate::common::{Certificate, NoExtension, PrivateKey};
18pub use crate::connection::Connector;
19use crate::connection::{self, EppConnection};
20use crate::error::Error;
21use crate::hello::{Greeting, GreetingDocument, HelloDocument};
22use crate::request::{Command, CommandDocument, Extension, Transaction};
23use crate::response::{Response, ResponseDocument, ResponseStatus};
24use crate::xml;
25
26pub struct EppClient<C: Connector> {
71 connection: EppConnection<C>,
72}
73
#[cfg(feature = "tokio-rustls")]
impl EppClient<RustlsConnector> {
    /// Builds a [`RustlsConnector`] for `server` and creates a client on top of it.
    ///
    /// `identity`, when present, is the client certificate chain and private key handed to
    /// [`RustlsConnector::new`]. `registry` and `timeout` are forwarded unchanged to
    /// [`EppClient::new`].
    ///
    /// # Errors
    ///
    /// Propagates any error from building the connector or from [`EppClient::new`].
    pub async fn connect(
        registry: String,
        server: (String, u16),
        identity: Option<(Vec<Certificate>, PrivateKey)>,
        timeout: Duration,
    ) -> Result<Self, Error> {
        Self::new(RustlsConnector::new(server, identity).await?, registry, timeout).await
    }
}
95
96impl<C: Connector> EppClient<C> {
97 pub async fn new(connector: C, registry: String, timeout: Duration) -> Result<Self, Error> {
99 Ok(Self {
100 connection: EppConnection::new(connector, registry, timeout).await?,
101 })
102 }
103
104 pub async fn hello(&mut self) -> Result<Greeting, Error> {
106 let xml = xml::serialize(&HelloDocument::default())?;
107
108 debug!("{}: hello: {}", self.connection.registry, &xml);
109 let response = self.connection.transact(&xml)?.await?;
110 debug!("{}: greeting: {}", self.connection.registry, &response);
111
112 Ok(xml::deserialize::<GreetingDocument>(&response)?.data)
113 }
114
115 pub async fn transact<'c, 'e, Cmd, Ext>(
116 &mut self,
117 data: impl Into<RequestData<'c, 'e, Cmd, Ext>>,
118 id: &str,
119 ) -> Result<Response<Cmd::Response, Ext::Response>, Error>
120 where
121 Cmd: Transaction<Ext> + Command + 'c,
122 Ext: Extension + 'e,
123 {
124 let data = data.into();
125 let document = CommandDocument::new(data.command, data.extension, id);
126 let xml = xml::serialize(&document)?;
127
128 debug!("{}: request: {}", self.connection.registry, &xml);
129 let response = self.connection.transact(&xml)?.await?;
130 debug!("{}: response: {}", self.connection.registry, &response);
131
132 let rsp =
133 match xml::deserialize::<ResponseDocument<Cmd::Response, Ext::Response>>(&response) {
134 Ok(rsp) => rsp,
135 Err(e) => {
136 error!(%response, "failed to deserialize response for transaction: {e}");
137 return Err(e);
138 }
139 };
140
141 if rsp.data.result.code.is_success() {
142 return Ok(rsp.data);
143 }
144
145 let err = crate::error::Error::Command(Box::new(ResponseStatus {
146 result: rsp.data.result,
147 tr_ids: rsp.data.tr_ids,
148 }));
149
150 error!(%response, "Failed to deserialize response for transaction: {}", err);
151 Err(err)
152 }
153
154 pub async fn transact_xml(&mut self, xml: &str) -> Result<String, Error> {
157 self.connection.transact(xml)?.await
158 }
159
160 pub fn xml_greeting(&self) -> String {
162 String::from(&self.connection.greeting)
163 }
164
165 pub fn greeting(&self) -> Result<Greeting, Error> {
167 xml::deserialize::<GreetingDocument>(&self.connection.greeting).map(|obj| obj.data)
168 }
169
170 pub async fn reconnect(&mut self) -> Result<(), Error> {
171 self.connection.reconnect().await
172 }
173
174 pub async fn shutdown(mut self) -> Result<(), Error> {
175 self.connection.shutdown().await
176 }
177}
178
/// A borrowed command plus an optional borrowed extension, sent together as one request.
///
/// `'c` is the lifetime of the command borrow, `'e` the lifetime of the extension borrow.
#[derive(Debug)]
pub struct RequestData<'c, 'e, C, E> {
    /// The command to send.
    pub(crate) command: &'c C,
    /// The extension to attach, if any.
    pub(crate) extension: Option<&'e E>,
}
184
185impl<'c, C: Command> From<&'c C> for RequestData<'c, 'static, C, NoExtension> {
186 fn from(command: &'c C) -> Self {
187 Self {
188 command,
189 extension: None,
190 }
191 }
192}
193
194impl<'c, 'e, C: Command, E: Extension> From<(&'c C, &'e E)> for RequestData<'c, 'e, C, E> {
195 fn from((command, extension): (&'c C, &'e E)) -> Self {
196 Self {
197 command,
198 extension: Some(extension),
199 }
200 }
201}
202
203impl<'c, 'e, C, E> Clone for RequestData<'c, 'e, C, E> {
205 fn clone(&self) -> Self {
206 Self {
207 command: self.command,
208 extension: self.extension,
209 }
210 }
211}
212
213impl<'c, 'e, C, E> Copy for RequestData<'c, 'e, C, E> {}
215
/// A [`Connector`] that opens TLS connections using rustls.
#[cfg(feature = "tokio-rustls")]
pub struct RustlsConnector {
    /// Host name and port to connect to.
    server: (String, u16),
    /// Server name handed to the TLS handshake; built from `server.0`.
    domain: ServerName,
    /// TLS connector carrying the client configuration.
    inner: TlsConnector,
}
222
// Gated like `RustlsConnector` itself and the rustls imports it uses; without this the crate
// fails to compile when the `tokio-rustls` feature is disabled.
#[cfg(feature = "tokio-rustls")]
impl RustlsConnector {
    /// Builds a TLS connector for `server` (host name and port).
    ///
    /// Server certificates are verified against the `webpki_roots` trust anchors. When
    /// `identity` is given, its certificate chain and private key are configured for client
    /// authentication; otherwise the client config uses no client authentication.
    ///
    /// No network I/O happens here; the connection is opened later by [`Connector::connect`].
    ///
    /// # Errors
    ///
    /// - `Error::Other` if rustls rejects the supplied certificate chain / private key.
    /// - An I/O error of kind `InvalidInput` if `server.0` is not a valid server name.
    pub async fn new(
        server: (String, u16),
        identity: Option<(Vec<Certificate>, PrivateKey)>,
    ) -> Result<Self, Error> {
        let mut roots = RootCertStore::empty();
        roots.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
            OwnedTrustAnchor::from_subject_spki_name_constraints(
                ta.subject,
                ta.spki,
                ta.name_constraints,
            )
        }));

        let builder = ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots);

        let config = match identity {
            Some((certs, key)) => {
                let certs = certs
                    .into_iter()
                    .map(|cert| tokio_rustls::rustls::Certificate(cert.0))
                    .collect();
                builder
                    .with_single_cert(certs, tokio_rustls::rustls::PrivateKey(key.0))
                    .map_err(|e| Error::Other(e.into()))?
            }
            None => builder.with_no_client_auth(),
        };

        let domain = server.0.as_str().try_into().map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Invalid domain: {}", server.0),
            )
        })?;

        Ok(Self {
            inner: TlsConnector::from(Arc::new(config)),
            domain,
            server,
        })
    }
}
268
#[cfg(feature = "tokio-rustls")]
#[async_trait]
impl Connector for RustlsConnector {
    type Connection = TlsStream<TcpStream>;

    /// Resolves the server, opens a TCP connection to the first resolved address and performs
    /// the TLS handshake.
    ///
    /// `timeout` bounds the whole sequence (lookup, TCP connect and handshake), so an
    /// unreachable host cannot stall the caller for the operating system's TCP connect timeout.
    ///
    /// # Errors
    ///
    /// - An I/O error of kind `InvalidInput` naming the host if it resolves to no address.
    /// - Any lookup, TCP or TLS I/O error, or a timeout error from `connection::timeout`.
    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error> {
        info!("Connecting to server: {}:{}", self.server.0, self.server.1);

        // Kept as an `io::Result` future, the same output type the handshake future has, so
        // `connection::timeout` performs the error conversion exactly as before.
        let attempt = async {
            let addr = lookup_host(&self.server).await?.next().ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Invalid host: {}", &self.server.0),
                )
            })?;

            let stream = TcpStream::connect(addr).await?;
            self.inner.connect(self.domain.clone(), stream).await
        };

        connection::timeout(timeout, attempt).await
    }
}
290}