1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
//! Manages registry connections and reading/writing to them
//!
//! See also [RFC 5734](https://tools.ietf.org/html/rfc5734).

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{io, mem, str, u32};

use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tracing::{debug, info};

use crate::error::Error;

/// EPP Connection struct with some metadata for the connection
///
/// Owns the stream produced by a [`Connector`] and the state of the request that is
/// currently being written to, or read from, that stream.
pub(crate) struct EppConnection<C: Connector> {
    // Name of the registry; used as a prefix in log and I/O error messages
    pub(crate) registry: String,
    // Used to establish the stream, both initially and when reconnecting
    connector: C,
    // The stream all requests are written to and all responses are read from
    stream: C::Connection,
    // Payload of the first frame read from the stream after (re)connecting
    pub(crate) greeting: String,
    // Passed to the connector when connecting; also bounds how long a shutdown may take
    timeout: Duration,
    // A request that is currently in flight
    //
    // Because the code here currently depends on only one request being in flight at a time,
    // this needs to be finished (written, and response read) before we start another one.
    current: Option<RequestState>,
    // The next request to be sent
    //
    // If we get a request while another request is in flight (because its future was dropped),
    // we will store it here until the current request is finished.
    next: Option<RequestState>,
}

impl<C: Connector> EppConnection<C> {
    pub(crate) async fn new(
        connector: C,
        registry: String,
        timeout: Duration,
    ) -> Result<Self, Error> {
        let mut this = Self {
            registry,
            stream: connector.connect(timeout).await?,
            connector,
            greeting: String::new(),
            timeout,
            current: None,
            next: None,
        };

        this.read_greeting().await?;
        Ok(this)
    }

    async fn read_greeting(&mut self) -> Result<(), Error> {
        assert!(self.current.is_none());
        self.current = Some(RequestState::ReadLength {
            read: 0,
            buf: vec![0; 256],
        });

        self.greeting = RequestFuture { conn: self }.await?;
        Ok(())
    }

    pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
        debug!("{}: reconnecting", self.registry);
        let _ = self.current.take();
        let _ = self.next.take();
        self.stream = self.connector.connect(self.timeout).await?;
        self.read_greeting().await?;
        Ok(())
    }

    /// Sends an EPP XML request to the registry and returns the response
    pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
        let new = RequestState::new(command)?;

        // If we have a request currently in flight, finish that first
        // If another request was queued up behind the one in flight, just replace it
        match self.current.is_some() {
            true => {
                debug!(
                    "{}: Queueing up request in order to finish in-flight request",
                    self.registry
                );
                self.next = Some(new);
            }
            false => self.current = Some(new),
        }

        Ok(RequestFuture { conn: self })
    }

    /// Closes the socket and shuts down the connection
    pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
        info!("{}: Closing connection", self.registry);
        timeout(self.timeout, self.stream.shutdown()).await?;
        Ok(())
    }

    fn handle(
        &mut self,
        mut state: RequestState,
        cx: &mut Context<'_>,
    ) -> Result<Transition, Error> {
        match &mut state {
            RequestState::Writing { mut start, buf } => {
                let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
                    Poll::Ready(Ok(wrote)) => wrote,
                    Poll::Ready(Err(err)) => return Err(err.into()),
                    Poll::Pending => return Ok(Transition::Pending(state)),
                };

                if wrote == 0 {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!("{}: Unexpected EOF while writing", self.registry),
                    )
                    .into());
                }

                start += wrote;
                debug!(
                    "{}: Wrote {} bytes, {} out of {} done",
                    self.registry,
                    wrote,
                    start,
                    buf.len()
                );

                // Transition to reading the response's frame header once
                // we've written the entire request
                if start < buf.len() {
                    return Ok(Transition::Next(state));
                }

                Ok(Transition::Next(RequestState::ReadLength {
                    read: 0,
                    buf: vec![0; 256],
                }))
            }
            RequestState::ReadLength { mut read, buf } => {
                let mut read_buf = ReadBuf::new(&mut buf[read..]);
                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
                    Poll::Ready(Ok(())) => {}
                    Poll::Ready(Err(err)) => return Err(err.into()),
                    Poll::Pending => return Ok(Transition::Pending(state)),
                };

                let filled = read_buf.filled();
                if filled.is_empty() {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!("{}: Unexpected EOF while reading length", self.registry),
                    )
                    .into());
                }

                // We're looking for the frame header which tells us how long the response will be.
                // The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
                // have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.

                read += filled.len();
                if read < 4 {
                    return Ok(Transition::Next(state));
                }

                let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
                debug!("{}: Expected response length: {}", self.registry, expected);
                buf.resize(expected, 0);
                Ok(Transition::Next(RequestState::Reading {
                    read,
                    buf: mem::take(buf),
                    expected,
                }))
            }
            RequestState::Reading {
                mut read,
                buf,
                expected,
            } => {
                let mut read_buf = ReadBuf::new(&mut buf[read..]);
                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
                    Poll::Ready(Ok(())) => {}
                    Poll::Ready(Err(err)) => return Err(err.into()),
                    Poll::Pending => return Ok(Transition::Pending(state)),
                }

                let filled = read_buf.filled();
                if filled.is_empty() {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!("{}: Unexpected EOF while reading", self.registry),
                    )
                    .into());
                }

                read += filled.len();
                debug!(
                    "{}: Read {} bytes, {} out of {} done",
                    self.registry,
                    filled.len(),
                    read,
                    expected
                );

                //

                Ok(if read < *expected {
                    // If we haven't received the entire response yet, stick to the `Reading` state.
                    Transition::Next(state)
                } else if let Some(next) = self.next.take() {
                    // Otherwise, if we were just pushing through this request because it was already
                    // in flight when we started a new one, ignore this response and move to the
                    // next request (the one this `RequestFuture` is actually for).
                    Transition::Next(next)
                } else {
                    // Otherwise, drain off the frame header and convert the rest to a `String`.
                    buf.drain(..4);
                    Transition::Done(String::from_utf8(mem::take(buf))?)
                })
            }
        }
    }
}

/// Future that drives the request stored in an `EppConnection` to completion
///
/// Created by `EppConnection::transact()` (and internally to read the greeting). It resolves
/// to the payload of the response frame, or to the error that interrupted the exchange.
pub(crate) struct RequestFuture<'a, C: Connector> {
    // The connection whose `current` request state is advanced on every poll
    conn: &'a mut EppConnection<C>,
}

impl<'a, C: Connector> Future for RequestFuture<'a, C> {
    type Output = Result<String, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        loop {
            let state = this.conn.current.take().unwrap();
            match this.conn.handle(state, cx) {
                Ok(Transition::Next(next)) => {
                    this.conn.current = Some(next);
                    continue;
                }
                Ok(Transition::Pending(state)) => {
                    this.conn.current = Some(state);
                    return Poll::Pending;
                }
                Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
                Err(err) => {
                    // Assume the error means the connection can no longer be used
                    this.conn.next = None;
                    return Poll::Ready(Err(err));
                }
            }
        }
    }
}

// Transitions between `RequestState`s
//
// Produced by `EppConnection::handle()` and acted upon by `RequestFuture::poll()`.
enum Transition {
    // The stream was not ready; the state is stored back and the future returns `Poll::Pending`
    Pending(RequestState),
    // The state to handle next; it is stored and handled again within the same poll
    Next(RequestState),
    // A full response was read; holds its payload with the frame header removed
    Done(String),
}

// The phases a single request/response exchange goes through, in order:
// `Writing` -> `ReadLength` -> `Reading`. Reading the greeting starts at `ReadLength`.
#[derive(Debug)]
enum RequestState {
    // Writing the request command out to the peer
    Writing {
        // The number of bytes of `buf` we've already written
        start: usize,
        // The full request frame: the 4-byte frame header followed by the XML request
        buf: Vec<u8>,
    },
    // Reading the frame header (32-bit big-endian unsigned integer)
    ReadLength {
        // The number of bytes we've already read
        read: usize,
        // The buffer we're using to read into
        buf: Vec<u8>,
    },
    // Reading the entire frame
    Reading {
        // The number of bytes we've already read, counted from the start of the frame
        // (so the bytes of the frame header are included)
        read: usize,
        // The buffer we're using to read into
        //
        // This will still have the frame header in it, needs to be cut off before
        // yielding the response to the caller.
        buf: Vec<u8>,
        // The expected length of the response according to the frame header
        //
        // Compared against `read`, so this counts the frame header as well.
        expected: usize,
    },
}

impl RequestState {
    /// Frames `command` for sending and returns the initial `Writing` state
    ///
    /// The frame consists of a 4-byte big-endian length, which counts the header itself
    /// plus the command, followed by the bytes of `command`.
    ///
    /// # Errors
    ///
    /// Returns an error if the total frame length does not fit in a `u32`.
    fn new(command: &str) -> Result<Self, Error> {
        let payload = command.as_bytes();
        let frame_len = payload.len() + 4;
        let header: [u8; 4] = u32::to_be_bytes(frame_len.try_into()?);

        let mut buf: Vec<u8> = Vec::with_capacity(frame_len);
        buf.extend_from_slice(&header);
        buf.extend_from_slice(payload);
        Ok(Self::Writing { start: 0, buf })
    }
}

pub(crate) async fn timeout<T, E: Into<Error>>(
    timeout: Duration,
    fut: impl Future<Output = Result<T, E>>,
) -> Result<T, Error> {
    match tokio::time::timeout(timeout, fut).await {
        Ok(Ok(t)) => Ok(t),
        Ok(Err(e)) => Err(e.into()),
        Err(_) => Err(Error::Timeout),
    }
}

/// Establishes the stream an `EppConnection` talks to the registry over
///
/// `EppConnection` calls `connect()` when it is created and again on every reconnect.
#[async_trait]
pub trait Connector {
    /// The stream type produced by `connect()`
    type Connection: AsyncRead + AsyncWrite + Unpin;

    /// Establishes a new stream to the registry
    ///
    /// `timeout` is the timeout the `EppConnection` was configured with.
    /// NOTE(review): presumably implementations use it to bound the connection
    /// attempt — confirm against the implementors.
    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
}