1use std::future::Future;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::Duration;
7use std::{io, mem, str, u32};
8
9use async_trait::async_trait;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
11use tracing::{debug, info};
12
13use crate::error::Error;
14
/// A connection to an EPP registry that sends framed commands and reads
/// framed responses by driving a small request state machine.
pub(crate) struct EppConnection<C: Connector> {
    /// Registry name; prefixed to log and error messages.
    pub registry: String,
    /// Used to open `stream`, both initially and on reconnect.
    connector: C,
    /// Transport that requests are written to and responses read from.
    stream: C::Connection,
    /// Greeting received from the server right after connecting, with the
    /// 4-byte frame header removed.
    pub greeting: String,
    /// Passed to `Connector::connect` and used to bound `shutdown`.
    timeout: Duration,
    /// State of the exchange currently being driven, if any.
    current: Option<RequestState>,
    /// Request queued behind `current`. `transact` parks a new request here
    /// when the previous exchange is still in flight (its `RequestFuture` was
    /// dropped before completing); it is started once that exchange's
    /// response has been read, and that response is discarded.
    next: Option<RequestState>,
}
33
34impl<C: Connector> EppConnection<C> {
35 pub(crate) async fn new(
36 connector: C,
37 registry: String,
38 timeout: Duration,
39 ) -> Result<Self, Error> {
40 let mut this = Self {
41 registry,
42 stream: connector.connect(timeout).await?,
43 connector,
44 greeting: String::new(),
45 timeout,
46 current: None,
47 next: None,
48 };
49
50 this.read_greeting().await?;
51 Ok(this)
52 }
53
54 async fn read_greeting(&mut self) -> Result<(), Error> {
55 assert!(self.current.is_none());
56 self.current = Some(RequestState::ReadLength {
57 read: 0,
58 buf: vec![0; 256],
59 });
60
61 self.greeting = RequestFuture { conn: self }.await?;
62 Ok(())
63 }
64
65 pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
66 debug!("{}: reconnecting", self.registry);
67 let _ = self.current.take();
68 let _ = self.next.take();
69 self.stream = self.connector.connect(self.timeout).await?;
70 self.read_greeting().await?;
71 Ok(())
72 }
73
74 pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
76 let new = RequestState::new(command)?;
77
78 match self.current.is_some() {
81 true => {
82 debug!(
83 "{}: Queueing up request in order to finish in-flight request",
84 self.registry
85 );
86 self.next = Some(new);
87 }
88 false => self.current = Some(new),
89 }
90
91 Ok(RequestFuture { conn: self })
92 }
93
94 pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
96 info!("{}: Closing connection", self.registry);
97 timeout(self.timeout, self.stream.shutdown()).await?;
98 Ok(())
99 }
100
101 fn handle(
102 &mut self,
103 mut state: RequestState,
104 cx: &mut Context<'_>,
105 ) -> Result<Transition, Error> {
106 match &mut state {
107 RequestState::Writing { mut start, buf } => {
108 let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
109 Poll::Ready(Ok(wrote)) => wrote,
110 Poll::Ready(Err(err)) => return Err(err.into()),
111 Poll::Pending => return Ok(Transition::Pending(state)),
112 };
113
114 if wrote == 0 {
115 return Err(io::Error::new(
116 io::ErrorKind::UnexpectedEof,
117 format!("{}: Unexpected EOF while writing", self.registry),
118 )
119 .into());
120 }
121
122 start += wrote;
123 debug!(
124 "{}: Wrote {} bytes, {} out of {} done",
125 self.registry,
126 wrote,
127 start,
128 buf.len()
129 );
130
131 if start < buf.len() {
134 return Ok(Transition::Next(state));
135 }
136
137 Ok(Transition::Next(RequestState::ReadLength {
138 read: 0,
139 buf: vec![0; 256],
140 }))
141 }
142 RequestState::ReadLength { mut read, buf } => {
143 let mut read_buf = ReadBuf::new(&mut buf[read..]);
144 match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
145 Poll::Ready(Ok(())) => {}
146 Poll::Ready(Err(err)) => return Err(err.into()),
147 Poll::Pending => return Ok(Transition::Pending(state)),
148 };
149
150 let filled = read_buf.filled();
151 if filled.is_empty() {
152 return Err(io::Error::new(
153 io::ErrorKind::UnexpectedEof,
154 format!("{}: Unexpected EOF while reading length", self.registry),
155 )
156 .into());
157 }
158
159 read += filled.len();
164 if read < 4 {
165 return Ok(Transition::Next(state));
166 }
167
168 let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
169 debug!("{}: Expected response length: {}", self.registry, expected);
170 buf.resize(expected, 0);
171 Ok(Transition::Next(RequestState::Reading {
172 read,
173 buf: mem::take(buf),
174 expected,
175 }))
176 }
177 RequestState::Reading {
178 mut read,
179 buf,
180 expected,
181 } => {
182 let mut read_buf = ReadBuf::new(&mut buf[read..]);
183 match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
184 Poll::Ready(Ok(())) => {}
185 Poll::Ready(Err(err)) => return Err(err.into()),
186 Poll::Pending => return Ok(Transition::Pending(state)),
187 }
188
189 let filled = read_buf.filled();
190 if filled.is_empty() {
191 return Err(io::Error::new(
192 io::ErrorKind::UnexpectedEof,
193 format!("{}: Unexpected EOF while reading", self.registry),
194 )
195 .into());
196 }
197
198 read += filled.len();
199 debug!(
200 "{}: Read {} bytes, {} out of {} done",
201 self.registry,
202 filled.len(),
203 read,
204 expected
205 );
206
207 Ok(if read < *expected {
210 Transition::Next(state)
212 } else if let Some(next) = self.next.take() {
213 Transition::Next(next)
217 } else {
218 buf.drain(..4);
220 Transition::Done(String::from_utf8(mem::take(buf))?)
221 })
222 }
223 }
224 }
225}
226
/// Future returned by `EppConnection::transact` (and awaited internally for
/// the greeting) that drives the connection's `current` state until a
/// complete response has been read.
pub(crate) struct RequestFuture<'a, C: Connector> {
    /// Exclusive borrow of the connection; while this future exists no other
    /// request can be started on it.
    conn: &'a mut EppConnection<C>,
}
230
231impl<'a, C: Connector> Future for RequestFuture<'a, C> {
232 type Output = Result<String, Error>;
233
234 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
235 let this = self.get_mut();
236 loop {
237 let state = this.conn.current.take().unwrap();
238 match this.conn.handle(state, cx) {
239 Ok(Transition::Next(next)) => {
240 this.conn.current = Some(next);
241 continue;
242 }
243 Ok(Transition::Pending(state)) => {
244 this.conn.current = Some(state);
245 return Poll::Pending;
246 }
247 Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
248 Err(err) => {
249 this.conn.next = None;
251 return Poll::Ready(Err(err));
252 }
253 }
254 }
255 }
256}
257
/// Outcome of one `EppConnection::handle` step.
enum Transition {
    /// The stream returned `Poll::Pending`; the state is stored back and the
    /// future yields.
    Pending(RequestState),
    /// Progress was made; continue immediately with this state.
    Next(RequestState),
    /// A complete response was read; holds its payload with the 4-byte frame
    /// header removed.
    Done(String),
}
264
/// Progress of a single request/response exchange.
#[derive(Debug)]
enum RequestState {
    /// Writing the framed request to the stream.
    Writing {
        /// Offset in `buf` from which the next write starts.
        start: usize,
        /// The request: a 4-byte big-endian total length (counting the header
        /// itself) followed by the command bytes.
        buf: Vec<u8>,
    },
    /// Reading until the 4-byte response length header is available.
    ReadLength {
        /// Offset in `buf` where the next read lands.
        read: usize,
        /// Initial read buffer (allocated with 256 bytes in this file); a read
        /// may return more than just the header.
        buf: Vec<u8>,
    },
    /// Reading the remainder of the response frame.
    Reading {
        /// Offset in `buf` where the next read lands; counts header bytes too.
        read: usize,
        /// Buffer resized to `expected`; begins with the 4-byte header, which
        /// is stripped once the frame is complete.
        buf: Vec<u8>,
        /// Total frame length announced by the header, header included.
        expected: usize,
    },
}
294
295impl RequestState {
296 fn new(command: &str) -> Result<Self, Error> {
297 let len = command.len();
298
299 let buf_size = len + 4;
300 let mut buf: Vec<u8> = vec![0u8; buf_size];
301
302 let len = len + 4;
303 let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?);
304
305 buf[..4].clone_from_slice(&len_u32);
306 buf[4..].clone_from_slice(command.as_bytes());
307 Ok(Self::Writing { start: 0, buf })
308 }
309}
310
311pub(crate) async fn timeout<T, E: Into<Error>>(
312 timeout: Duration,
313 fut: impl Future<Output = Result<T, E>>,
314) -> Result<T, Error> {
315 match tokio::time::timeout(timeout, fut).await {
316 Ok(Ok(t)) => Ok(t),
317 Ok(Err(e)) => Err(e.into()),
318 Err(_) => Err(Error::Timeout),
319 }
320}
321
/// Opens the transport used by an `EppConnection`.
#[async_trait]
pub trait Connector {
    /// The stream requests are written to and responses read from. It must be
    /// `Unpin` because the connection polls it through `Pin::new`.
    type Connection: AsyncRead + AsyncWrite + Unpin;

    /// Opens a new connection. `timeout` is the duration configured on the
    /// `EppConnection`; how it is applied is up to the implementation.
    async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
}