1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
//! Manages registry connections and reading/writing to them
//!
//! See also [RFC 5734](https://tools.ietf.org/html/rfc5734).
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{io, mem, str, u32};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tracing::{debug, info};
use crate::error::Error;
/// EPP Connection struct with some metadata for the connection
///
/// Owns the transport stream produced by `connector` plus the request/response state machine
/// used to frame commands onto it and read framed responses back.
pub(crate) struct EppConnection<C: Connector> {
/// Registry name; used as the prefix of every log/error message emitted for this connection.
pub(crate) registry: String,
/// Factory used to open the stream, both initially and on `reconnect`.
connector: C,
/// The live transport stream obtained from `connector.connect(..)`.
stream: C::Connection,
/// The first frame read from the peer after connecting (populated by `read_greeting`).
pub(crate) greeting: String,
/// Passed to `connector.connect(..)` and used to bound `shutdown`.
timeout: Duration,
/// A request that is currently in flight
///
/// Because the code here currently depends on only one request being in flight at a time,
/// this needs to be finished (written, and response read) before we start another one.
current: Option<RequestState>,
/// The next request to be sent
///
/// If we get a request while another request is in flight (because its future was dropped),
/// we will store it here until the current request is finished.
next: Option<RequestState>,
}
impl<C: Connector> EppConnection<C> {
    /// Size of the EPP data-unit header: a 32-bit big-endian total length.
    const FRAME_HEADER_LEN: usize = 4;

    /// Opens a connection via `connector` and reads the server greeting.
    ///
    /// `timeout` is passed to `connector.connect` and stored for later reconnects and
    /// `shutdown`; the greeting read itself is not wrapped in a timeout here.
    pub(crate) async fn new(
        connector: C,
        registry: String,
        timeout: Duration,
    ) -> Result<Self, Error> {
        let mut this = Self {
            registry,
            stream: connector.connect(timeout).await?,
            connector,
            greeting: String::new(),
            timeout,
            current: None,
            next: None,
        };

        this.read_greeting().await?;
        Ok(this)
    }

    /// Fresh state for reading a frame header from the start.
    fn read_length_state() -> RequestState {
        RequestState::ReadLength {
            read: 0,
            buf: vec![0; Self::FRAME_HEADER_LEN],
        }
    }

    /// Reads one frame from the peer and stores it in `self.greeting`.
    ///
    /// # Panics
    ///
    /// Panics if a request is already in flight.
    async fn read_greeting(&mut self) -> Result<(), Error> {
        assert!(self.current.is_none());
        self.current = Some(Self::read_length_state());
        self.greeting = RequestFuture { conn: self }.await?;
        Ok(())
    }

    /// Discards any in-flight or queued request, opens a fresh stream and reads its greeting.
    pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
        debug!("{}: reconnecting", self.registry);
        self.current = None;
        self.next = None;
        self.stream = self.connector.connect(self.timeout).await?;
        self.read_greeting().await?;
        Ok(())
    }

    /// Sends an EPP XML request to the registry and returns the response
    ///
    /// # Errors
    ///
    /// Fails immediately if the framed command length does not fit the 32-bit header.
    pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
        let new = RequestState::new(command)?;

        // If we have a request currently in flight, finish that first
        // If another request was queued up behind the one in flight, just replace it
        if self.current.is_some() {
            debug!(
                "{}: Queueing up request in order to finish in-flight request",
                self.registry
            );
            self.next = Some(new);
        } else {
            self.current = Some(new);
        }

        Ok(RequestFuture { conn: self })
    }

    /// Closes the socket and shuts down the connection
    pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
        info!("{}: Closing connection", self.registry);
        timeout(self.timeout, self.stream.shutdown()).await?;
        Ok(())
    }

    /// Advances `state` by at most one I/O operation.
    ///
    /// Returns `Pending` (with the unchanged state) when the stream is not ready, `Next` with
    /// the state to continue from, or `Done` with the decoded response body.
    ///
    /// # Errors
    ///
    /// I/O errors, unexpected EOF, an invalid frame length, or a non-UTF-8 response body.
    fn handle(
        &mut self,
        mut state: RequestState,
        cx: &mut Context<'_>,
    ) -> Result<Transition, Error> {
        // NOTE: fields are bound by reference (no `mut` in the patterns). Writing `mut start`
        // here would copy the counter out of `state`, and progress would be lost whenever
        // `state` is handed back for another round.
        match &mut state {
            RequestState::Writing { start, buf } => {
                let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[*start..]) {
                    Poll::Ready(Ok(wrote)) => wrote,
                    Poll::Ready(Err(err)) => return Err(err.into()),
                    Poll::Pending => return Ok(Transition::Pending(state)),
                };

                if wrote == 0 {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!("{}: Unexpected EOF while writing", self.registry),
                    )
                    .into());
                }

                *start += wrote;
                debug!(
                    "{}: Wrote {} bytes, {} out of {} done",
                    self.registry,
                    wrote,
                    *start,
                    buf.len()
                );

                // Transition to reading the response's frame header once
                // we've written the entire request
                if *start < buf.len() {
                    return Ok(Transition::Next(state));
                }

                Ok(Transition::Next(Self::read_length_state()))
            }
            RequestState::ReadLength { read, buf } => {
                // Only request the header bytes still missing, so no part of the frame body
                // (which may be shorter than any larger scratch buffer) is consumed here.
                let mut read_buf = ReadBuf::new(&mut buf[*read..Self::FRAME_HEADER_LEN]);
                match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
                    Poll::Ready(Ok(())) => {}
                    Poll::Ready(Err(err)) => return Err(err.into()),
                    Poll::Pending => return Ok(Transition::Pending(state)),
                };

                let filled = read_buf.filled().len();
                if filled == 0 {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!("{}: Unexpected EOF while reading length", self.registry),
                    )
                    .into());
                }

                // We're looking for the frame header which tells us how long the response will be.
                // The frame header is a 32-bit (4-byte) big-endian unsigned integer. If we don't
                // have 4 bytes yet, stay in the `ReadLength` state, otherwise we transition to `Reading`.
                *read += filled;
                if *read < Self::FRAME_HEADER_LEN {
                    return Ok(Transition::Next(state));
                }

                // Parse from the accumulated buffer, not from the last read's bytes: the header
                // may have arrived across several reads.
                let expected =
                    u32::from_be_bytes(buf[..Self::FRAME_HEADER_LEN].try_into()?) as usize;
                debug!("{}: Expected response length: {}", self.registry, expected);

                // The length counts the header itself, so anything shorter is malformed.
                if expected < Self::FRAME_HEADER_LEN {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!(
                            "{}: Invalid frame length {} (shorter than header)",
                            self.registry, expected
                        ),
                    )
                    .into());
                }

                buf.resize(expected, 0);
                Ok(Transition::Next(RequestState::Reading {
                    read: *read,
                    buf: mem::take(buf),
                    expected,
                }))
            }
            RequestState::Reading {
                read,
                buf,
                expected,
            } => {
                // A header-only frame has nothing left to read; polling with an empty buffer
                // would be indistinguishable from EOF, so skip straight to completion.
                if *read < *expected {
                    let mut read_buf = ReadBuf::new(&mut buf[*read..]);
                    match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
                        Poll::Ready(Ok(())) => {}
                        Poll::Ready(Err(err)) => return Err(err.into()),
                        Poll::Pending => return Ok(Transition::Pending(state)),
                    }

                    let filled = read_buf.filled().len();
                    if filled == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            format!("{}: Unexpected EOF while reading", self.registry),
                        )
                        .into());
                    }

                    *read += filled;
                    debug!(
                        "{}: Read {} bytes, {} out of {} done",
                        self.registry, filled, *read, *expected
                    );

                    // If we haven't received the entire response yet, stick to the `Reading` state.
                    if *read < *expected {
                        return Ok(Transition::Next(state));
                    }
                }

                if let Some(next) = self.next.take() {
                    // We were just pushing through this request because it was already in
                    // flight when a new one was started: ignore this response and move to the
                    // next request (the one this `RequestFuture` is actually for).
                    return Ok(Transition::Next(next));
                }

                // Drain off the frame header and convert the rest to a `String`.
                buf.drain(..Self::FRAME_HEADER_LEN);
                Ok(Transition::Done(String::from_utf8(mem::take(buf))?))
            }
        }
    }
}
/// Future that drives the connection's `current` request state to completion.
///
/// Created by `EppConnection::transact` and `EppConnection::read_greeting`. It resolves to the
/// response body (the frame with its 4-byte header removed), decoded as UTF-8.
pub(crate) struct RequestFuture<'a, C: Connector> {
// Exclusive borrow of the connection whose `current`/`next` states this future advances.
conn: &'a mut EppConnection<C>,
}
impl<'a, C: Connector> Future for RequestFuture<'a, C> {
    type Output = Result<String, Error>;

    /// Repeatedly advances the connection's current state until it must wait on I/O,
    /// yields a response, or fails.
    ///
    /// # Panics
    ///
    /// Panics if polled when no request is in flight (e.g. after it already completed).
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        loop {
            let state = this
                .conn
                .current
                .take()
                .expect("RequestFuture polled with no request in flight (already completed?)");
            match this.conn.handle(state, cx) {
                Ok(Transition::Next(next)) => this.conn.current = Some(next),
                Ok(Transition::Pending(state)) => {
                    // Park the state; the stream registered the waker for us.
                    this.conn.current = Some(state);
                    return Poll::Pending;
                }
                Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
                Err(err) => {
                    // Assume the error means the connection can no longer be used, so drop
                    // anything queued behind the failed request.
                    this.conn.next = None;
                    return Poll::Ready(Err(err));
                }
            }
        }
    }
}
/// Result of advancing a `RequestState` by one step.
enum Transition {
    /// The response frame is complete; carries its decoded body.
    Done(String),
    /// Progress was made or a new state was entered; advance again right away.
    Next(RequestState),
    /// The stream is not ready; keep this state and resume it on the next poll.
    Pending(RequestState),
}
/// Stage of a single request/response exchange on the wire.
#[derive(Debug)]
enum RequestState {
/// Writing the request command out to the peer
Writing {
/// The amount of bytes we've already written
start: usize,
/// The full framed request (4-byte length header followed by the XML)
buf: Vec<u8>,
},
/// Reading the frame header (32-bit big-endian unsigned integer)
ReadLength {
/// The amount of bytes we've already read
read: usize,
/// The buffer we're using to read into
buf: Vec<u8>,
},
/// Reading the entire frame
Reading {
/// The amount of bytes we've already read (header bytes included)
read: usize,
/// The buffer we're using to read into
///
/// This will still have the frame header in it, needs to be cut off before
/// yielding the response to the caller.
buf: Vec<u8>,
/// The expected length of the response according to the frame header
expected: usize,
},
}
impl RequestState {
    /// Frames `command` for the wire and returns the initial `Writing` state for it.
    ///
    /// The frame is a 4-byte big-endian length, which counts the length field itself,
    /// followed by the command's bytes.
    ///
    /// # Errors
    ///
    /// Fails if the framed length does not fit in a `u32`.
    fn new(command: &str) -> Result<Self, Error> {
        let payload = command.as_bytes();
        let total = payload.len() + 4;
        let header = u32::to_be_bytes(total.try_into()?);

        let mut buf = Vec::with_capacity(total);
        buf.extend_from_slice(&header);
        buf.extend_from_slice(payload);

        Ok(Self::Writing { start: 0, buf })
    }
}
pub(crate) async fn timeout<T, E: Into<Error>>(
timeout: Duration,
fut: impl Future<Output = Result<T, E>>,
) -> Result<T, Error> {
match tokio::time::timeout(timeout, fut).await {
Ok(Ok(t)) => Ok(t),
Ok(Err(e)) => Err(e.into()),
Err(_) => Err(Error::Timeout),
}
}
/// Opens transport streams for an EPP connection.
///
/// `connect` is called once when the connection is created and again on every reconnect.
#[async_trait]
pub trait Connector {
    /// The bidirectional byte stream that EPP frames are written to and read from.
    type Connection: Unpin + AsyncRead + AsyncWrite;

    /// Opens a new stream; `timeout` presumably bounds the connection attempt (implementor-defined).
    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
}