instant_epp/
connection.rs

1//! Manages registry connections and reading/writing to them
2//!
3//! See also [RFC 5734](https://tools.ietf.org/html/rfc5734).
4
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9use std::{io, mem, str, u32};
10
11use async_trait::async_trait;
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
13use tracing::{debug, info};
14
15use crate::error::Error;
16
17/// EPP Connection struct with some metadata for the connection
18pub(crate) struct EppConnection<C: Connector> {
19    pub(crate) registry: String,
20    connector: C,
21    stream: C::Connection,
22    pub(crate) greeting: String,
23    timeout: Duration,
24    // A request that is currently in flight
25    //
26    // Because the code here currently depends on only one request being in flight at a time,
27    // this needs to be finished (written, and response read) before we start another one.
28    current: Option<RequestState>,
29    // The next request to be sent
30    //
31    // If we get a request while another request is in flight (because its future was dropped),
32    // we will store it here until the current request is finished.
33    next: Option<RequestState>,
34}
35
36impl<C: Connector> EppConnection<C> {
37    pub(crate) async fn new(
38        connector: C,
39        registry: String,
40        timeout: Duration,
41    ) -> Result<Self, Error> {
42        let mut this = Self {
43            registry,
44            stream: connector.connect(timeout).await?,
45            connector,
46            greeting: String::new(),
47            timeout,
48            current: None,
49            next: None,
50        };
51
52        this.read_greeting().await?;
53        Ok(this)
54    }
55
56    async fn read_greeting(&mut self) -> Result<(), Error> {
57        assert!(self.current.is_none());
58        self.current = Some(RequestState::ReadLength {
59            read: 0,
60            buf: vec![0; 256],
61        });
62
63        self.greeting = RequestFuture { conn: self }.await?;
64        Ok(())
65    }
66
67    pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
68        debug!("{}: reconnecting", self.registry);
69        let _ = self.current.take();
70        let _ = self.next.take();
71        self.stream = self.connector.connect(self.timeout).await?;
72        self.read_greeting().await?;
73        Ok(())
74    }
75
76    /// Sends an EPP XML request to the registry and returns the response
77    pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
78        let new = RequestState::new(command)?;
79
80        // If we have a request currently in flight, finish that first
81        // If another request was queued up behind the one in flight, just replace it
82        match self.current.is_some() {
83            true => {
84                debug!(
85                    "{}: Queueing up request in order to finish in-flight request",
86                    self.registry
87                );
88                self.next = Some(new);
89            }
90            false => self.current = Some(new),
91        }
92
93        Ok(RequestFuture { conn: self })
94    }
95
96    /// Closes the socket and shuts down the connection
97    pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
98        info!("{}: Closing connection", self.registry);
99        timeout(self.timeout, self.stream.shutdown()).await?;
100        Ok(())
101    }
102
103    fn handle(
104        &mut self,
105        mut state: RequestState,
106        cx: &mut Context<'_>,
107    ) -> Result<Transition, Error> {
108        match &mut state {
109            RequestState::Writing { mut start, buf } => {
110                let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
111                    Poll::Ready(Ok(wrote)) => wrote,
112                    Poll::Ready(Err(err)) => return Err(err.into()),
113                    Poll::Pending => return Ok(Transition::Pending(state)),
114                };
115
116                if wrote == 0 {
117                    return Err(io::Error::new(
118                        io::ErrorKind::UnexpectedEof,
119                        format!("{}: Unexpected EOF while writing", self.registry),
120                    )
121                    .into());
122                }
123
124                start += wrote;
125                debug!(
126                    "{}: Wrote {} bytes, {} out of {} done",
127                    self.registry,
128                    wrote,
129                    start,
130                    buf.len()
131                );
132
133                // Transition to reading the response's frame header once
134                // we've written the entire request
135                if start < buf.len() {
136                    return Ok(Transition::Next(state));
137                }
138
139                Ok(Transition::Next(RequestState::ReadLength {
140                    read: 0,
141                    buf: vec![0; 256],
142                }))
143            }
144            RequestState::ReadLength { mut read, buf } => {
145                let mut read_buf = ReadBuf::new(&mut buf[read..]);
146                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
147                    Poll::Ready(Ok(())) => {}
148                    Poll::Ready(Err(err)) => return Err(err.into()),
149                    Poll::Pending => return Ok(Transition::Pending(state)),
150                };
151
152                let filled = read_buf.filled();
153                if filled.is_empty() {
154                    return Err(io::Error::new(
155                        io::ErrorKind::UnexpectedEof,
156                        format!("{}: Unexpected EOF while reading length", self.registry),
157                    )
158                    .into());
159                }
160
161                // We're looking for the frame header which tells us how long the response will be.
162                // The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
163                // have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.
164
165                read += filled.len();
166                if read < 4 {
167                    return Ok(Transition::Next(state));
168                }
169
170                let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
171                debug!("{}: Expected response length: {}", self.registry, expected);
172                buf.resize(expected, 0);
173                Ok(Transition::Next(RequestState::Reading {
174                    read,
175                    buf: mem::take(buf),
176                    expected,
177                }))
178            }
179            RequestState::Reading {
180                mut read,
181                buf,
182                expected,
183            } => {
184                let mut read_buf = ReadBuf::new(&mut buf[read..]);
185                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
186                    Poll::Ready(Ok(())) => {}
187                    Poll::Ready(Err(err)) => return Err(err.into()),
188                    Poll::Pending => return Ok(Transition::Pending(state)),
189                }
190
191                let filled = read_buf.filled();
192                if filled.is_empty() {
193                    return Err(io::Error::new(
194                        io::ErrorKind::UnexpectedEof,
195                        format!("{}: Unexpected EOF while reading", self.registry),
196                    )
197                    .into());
198                }
199
200                read += filled.len();
201                debug!(
202                    "{}: Read {} bytes, {} out of {} done",
203                    self.registry,
204                    filled.len(),
205                    read,
206                    expected
207                );
208
209                //
210
211                Ok(if read < *expected {
212                    // If we haven't received the entire response yet, stick to the `Reading` state.
213                    Transition::Next(state)
214                } else if let Some(next) = self.next.take() {
215                    // Otherwise, if we were just pushing through this request because it was already
216                    // in flight when we started a new one, ignore this response and move to the
217                    // next request (the one this `RequestFuture` is actually for).
218                    Transition::Next(next)
219                } else {
220                    // Otherwise, drain off the frame header and convert the rest to a `String`.
221                    buf.drain(..4);
222                    Transition::Done(String::from_utf8(mem::take(buf))?)
223                })
224            }
225        }
226    }
227}
228
229pub(crate) struct RequestFuture<'a, C: Connector> {
230    conn: &'a mut EppConnection<C>,
231}
232
233impl<'a, C: Connector> Future for RequestFuture<'a, C> {
234    type Output = Result<String, Error>;
235
236    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        let this = self.get_mut();
238        loop {
239            let state = this.conn.current.take().unwrap();
240            match this.conn.handle(state, cx) {
241                Ok(Transition::Next(next)) => {
242                    this.conn.current = Some(next);
243                    continue;
244                }
245                Ok(Transition::Pending(state)) => {
246                    this.conn.current = Some(state);
247                    return Poll::Pending;
248                }
249                Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
250                Err(err) => {
251                    // Assume the error means the connection can no longer be used
252                    this.conn.next = None;
253                    return Poll::Ready(Err(err));
254                }
255            }
256        }
257    }
258}
259
260// Transitions between `RequestState`s
261enum Transition {
262    Pending(RequestState),
263    Next(RequestState),
264    Done(String),
265}
266
/// Where the request currently driven over the connection stands.
///
/// A request moves `Writing` -> `ReadLength` -> `Reading`. Reading the greeting starts
/// directly at `ReadLength`.
#[derive(Debug)]
enum RequestState {
    /// Writing the framed request out to the peer.
    Writing {
        /// How many bytes of `buf` have been written so far.
        start: usize,
        /// The complete frame: 4-byte big-endian length header followed by the XML request.
        buf: Vec<u8>,
    },
    /// Reading the response's frame header (32-bit big-endian unsigned integer).
    ReadLength {
        /// How many bytes have been read into `buf` so far.
        read: usize,
        /// Scratch buffer the header is read into.
        buf: Vec<u8>,
    },
    /// Reading the remainder of the response frame.
    Reading {
        /// How many bytes have been read into `buf` so far, header included.
        read: usize,
        /// Frame buffer. It still starts with the 4-byte header, which must be cut off
        /// before the response is handed to the caller.
        buf: Vec<u8>,
        /// Total frame length announced by the header.
        expected: usize,
    },
}
296
297impl RequestState {
298    fn new(command: &str) -> Result<Self, Error> {
299        let len = command.len();
300
301        let buf_size = len + 4;
302        let mut buf: Vec<u8> = vec![0u8; buf_size];
303
304        let len = len + 4;
305        let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?);
306
307        buf[..4].clone_from_slice(&len_u32);
308        buf[4..].clone_from_slice(command.as_bytes());
309        Ok(Self::Writing { start: 0, buf })
310    }
311}
312
313pub(crate) async fn timeout<T, E: Into<Error>>(
314    timeout: Duration,
315    fut: impl Future<Output = Result<T, E>>,
316) -> Result<T, Error> {
317    match tokio::time::timeout(timeout, fut).await {
318        Ok(Ok(t)) => Ok(t),
319        Ok(Err(e)) => Err(e.into()),
320        Err(_) => Err(Error::Timeout),
321    }
322}
323
324#[async_trait]
325pub trait Connector {
326    type Connection: AsyncRead + AsyncWrite + Unpin;
327
328    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
329}