epp_client/
connection.rs

1//! Manages registry connections and reading/writing to them
2
3use std::future::Future;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::Duration;
7use std::{io, mem, str, u32};
8
9use async_trait::async_trait;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
11use tracing::{debug, info};
12
13use crate::error::Error;
14
15/// EPP Connection struct with some metadata for the connection
16pub(crate) struct EppConnection<C: Connector> {
17    pub registry: String,
18    connector: C,
19    stream: C::Connection,
20    pub greeting: String,
21    timeout: Duration,
22    // A request that is currently in flight
23    //
24    // Because the code here currently depends on only one request being in flight at a time,
25    // this needs to be finished (written, and response read) before we start another one.
26    current: Option<RequestState>,
27    // The next request to be sent
28    //
29    // If we get a request while another request is in flight (because its future was dropped),
30    // we will store it here until the current request is finished.
31    next: Option<RequestState>,
32}
33
34impl<C: Connector> EppConnection<C> {
35    pub(crate) async fn new(
36        connector: C,
37        registry: String,
38        timeout: Duration,
39    ) -> Result<Self, Error> {
40        let mut this = Self {
41            registry,
42            stream: connector.connect(timeout).await?,
43            connector,
44            greeting: String::new(),
45            timeout,
46            current: None,
47            next: None,
48        };
49
50        this.read_greeting().await?;
51        Ok(this)
52    }
53
54    async fn read_greeting(&mut self) -> Result<(), Error> {
55        assert!(self.current.is_none());
56        self.current = Some(RequestState::ReadLength {
57            read: 0,
58            buf: vec![0; 256],
59        });
60
61        self.greeting = RequestFuture { conn: self }.await?;
62        Ok(())
63    }
64
65    pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
66        debug!("{}: reconnecting", self.registry);
67        let _ = self.current.take();
68        let _ = self.next.take();
69        self.stream = self.connector.connect(self.timeout).await?;
70        self.read_greeting().await?;
71        Ok(())
72    }
73
74    /// Sends an EPP XML request to the registry and returns the response
75    pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
76        let new = RequestState::new(command)?;
77
78        // If we have a request currently in flight, finish that first
79        // If another request was queued up behind the one in flight, just replace it
80        match self.current.is_some() {
81            true => {
82                debug!(
83                    "{}: Queueing up request in order to finish in-flight request",
84                    self.registry
85                );
86                self.next = Some(new);
87            }
88            false => self.current = Some(new),
89        }
90
91        Ok(RequestFuture { conn: self })
92    }
93
94    /// Closes the socket and shuts down the connection
95    pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
96        info!("{}: Closing connection", self.registry);
97        timeout(self.timeout, self.stream.shutdown()).await?;
98        Ok(())
99    }
100
101    fn handle(
102        &mut self,
103        mut state: RequestState,
104        cx: &mut Context<'_>,
105    ) -> Result<Transition, Error> {
106        match &mut state {
107            RequestState::Writing { mut start, buf } => {
108                let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
109                    Poll::Ready(Ok(wrote)) => wrote,
110                    Poll::Ready(Err(err)) => return Err(err.into()),
111                    Poll::Pending => return Ok(Transition::Pending(state)),
112                };
113
114                if wrote == 0 {
115                    return Err(io::Error::new(
116                        io::ErrorKind::UnexpectedEof,
117                        format!("{}: Unexpected EOF while writing", self.registry),
118                    )
119                    .into());
120                }
121
122                start += wrote;
123                debug!(
124                    "{}: Wrote {} bytes, {} out of {} done",
125                    self.registry,
126                    wrote,
127                    start,
128                    buf.len()
129                );
130
131                // Transition to reading the response's frame header once
132                // we've written the entire request
133                if start < buf.len() {
134                    return Ok(Transition::Next(state));
135                }
136
137                Ok(Transition::Next(RequestState::ReadLength {
138                    read: 0,
139                    buf: vec![0; 256],
140                }))
141            }
142            RequestState::ReadLength { mut read, buf } => {
143                let mut read_buf = ReadBuf::new(&mut buf[read..]);
144                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
145                    Poll::Ready(Ok(())) => {}
146                    Poll::Ready(Err(err)) => return Err(err.into()),
147                    Poll::Pending => return Ok(Transition::Pending(state)),
148                };
149
150                let filled = read_buf.filled();
151                if filled.is_empty() {
152                    return Err(io::Error::new(
153                        io::ErrorKind::UnexpectedEof,
154                        format!("{}: Unexpected EOF while reading length", self.registry),
155                    )
156                    .into());
157                }
158
159                // We're looking for the frame header which tells us how long the response will be.
160                // The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
161                // have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.
162
163                read += filled.len();
164                if read < 4 {
165                    return Ok(Transition::Next(state));
166                }
167
168                let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
169                debug!("{}: Expected response length: {}", self.registry, expected);
170                buf.resize(expected, 0);
171                Ok(Transition::Next(RequestState::Reading {
172                    read,
173                    buf: mem::take(buf),
174                    expected,
175                }))
176            }
177            RequestState::Reading {
178                mut read,
179                buf,
180                expected,
181            } => {
182                let mut read_buf = ReadBuf::new(&mut buf[read..]);
183                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
184                    Poll::Ready(Ok(())) => {}
185                    Poll::Ready(Err(err)) => return Err(err.into()),
186                    Poll::Pending => return Ok(Transition::Pending(state)),
187                }
188
189                let filled = read_buf.filled();
190                if filled.is_empty() {
191                    return Err(io::Error::new(
192                        io::ErrorKind::UnexpectedEof,
193                        format!("{}: Unexpected EOF while reading", self.registry),
194                    )
195                    .into());
196                }
197
198                read += filled.len();
199                debug!(
200                    "{}: Read {} bytes, {} out of {} done",
201                    self.registry,
202                    filled.len(),
203                    read,
204                    expected
205                );
206
207                //
208
209                Ok(if read < *expected {
210                    // If we haven't received the entire response yet, stick to the `Reading` state.
211                    Transition::Next(state)
212                } else if let Some(next) = self.next.take() {
213                    // Otherwise, if we were just pushing through this request because it was already
214                    // in flight when we started a new one, ignore this response and move to the
215                    // next request (the one this `RequestFuture` is actually for).
216                    Transition::Next(next)
217                } else {
218                    // Otherwise, drain off the frame header and convert the rest to a `String`.
219                    buf.drain(..4);
220                    Transition::Done(String::from_utf8(mem::take(buf))?)
221                })
222            }
223        }
224    }
225}
226
227pub(crate) struct RequestFuture<'a, C: Connector> {
228    conn: &'a mut EppConnection<C>,
229}
230
231impl<'a, C: Connector> Future for RequestFuture<'a, C> {
232    type Output = Result<String, Error>;
233
234    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235        let this = self.get_mut();
236        loop {
237            let state = this.conn.current.take().unwrap();
238            match this.conn.handle(state, cx) {
239                Ok(Transition::Next(next)) => {
240                    this.conn.current = Some(next);
241                    continue;
242                }
243                Ok(Transition::Pending(state)) => {
244                    this.conn.current = Some(state);
245                    return Poll::Pending;
246                }
247                Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
248                Err(err) => {
249                    // Assume the error means the connection can no longer be used
250                    this.conn.next = None;
251                    return Poll::Ready(Err(err));
252                }
253            }
254        }
255    }
256}
257
258// Transitions between `RequestState`s
259enum Transition {
260    Pending(RequestState),
261    Next(RequestState),
262    Done(String),
263}
264
/// Where a single request/response exchange currently stands.
#[derive(Debug)]
enum RequestState {
    /// Sending the framed request command to the peer.
    Writing {
        /// Number of bytes of `buf` already written.
        start: usize,
        /// The complete frame: length header followed by the XML request.
        buf: Vec<u8>,
    },
    /// Receiving the frame header (a 32-bit big-endian unsigned integer).
    ReadLength {
        /// Number of bytes already read into `buf`.
        read: usize,
        /// Destination buffer for the incoming bytes.
        buf: Vec<u8>,
    },
    /// Receiving the remainder of the frame.
    Reading {
        /// Number of bytes already read into `buf`.
        read: usize,
        /// Destination buffer for the incoming bytes.
        ///
        /// The frame header is still at the front and has to be stripped before the
        /// response is handed to the caller.
        buf: Vec<u8>,
        /// Total frame length announced by the frame header.
        expected: usize,
    },
}
294
295impl RequestState {
296    fn new(command: &str) -> Result<Self, Error> {
297        let len = command.len();
298
299        let buf_size = len + 4;
300        let mut buf: Vec<u8> = vec![0u8; buf_size];
301
302        let len = len + 4;
303        let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?);
304
305        buf[..4].clone_from_slice(&len_u32);
306        buf[4..].clone_from_slice(command.as_bytes());
307        Ok(Self::Writing { start: 0, buf })
308    }
309}
310
311pub(crate) async fn timeout<T, E: Into<Error>>(
312    timeout: Duration,
313    fut: impl Future<Output = Result<T, E>>,
314) -> Result<T, Error> {
315    match tokio::time::timeout(timeout, fut).await {
316        Ok(Ok(t)) => Ok(t),
317        Ok(Err(e)) => Err(e.into()),
318        Err(_) => Err(Error::Timeout),
319    }
320}
321
322#[async_trait]
323pub trait Connector {
324    type Connection: AsyncRead + AsyncWrite + Unpin;
325
326    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
327}