1use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9use std::{io, mem, str, u32};
10
11use async_trait::async_trait;
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
13use tracing::{debug, info};
14
15use crate::error::Error;
16
17pub(crate) struct EppConnection<C: Connector> {
19 pub(crate) registry: String,
20 connector: C,
21 stream: C::Connection,
22 pub(crate) greeting: String,
23 timeout: Duration,
24 current: Option<RequestState>,
29 next: Option<RequestState>,
34}
35
36impl<C: Connector> EppConnection<C> {
37 pub(crate) async fn new(
38 connector: C,
39 registry: String,
40 timeout: Duration,
41 ) -> Result<Self, Error> {
42 let mut this = Self {
43 registry,
44 stream: connector.connect(timeout).await?,
45 connector,
46 greeting: String::new(),
47 timeout,
48 current: None,
49 next: None,
50 };
51
52 this.read_greeting().await?;
53 Ok(this)
54 }
55
56 async fn read_greeting(&mut self) -> Result<(), Error> {
57 assert!(self.current.is_none());
58 self.current = Some(RequestState::ReadLength {
59 read: 0,
60 buf: vec![0; 256],
61 });
62
63 self.greeting = RequestFuture { conn: self }.await?;
64 Ok(())
65 }
66
67 pub(crate) async fn reconnect(&mut self) -> Result<(), Error> {
68 debug!("{}: reconnecting", self.registry);
69 let _ = self.current.take();
70 let _ = self.next.take();
71 self.stream = self.connector.connect(self.timeout).await?;
72 self.read_greeting().await?;
73 Ok(())
74 }
75
76 pub(crate) fn transact<'a>(&'a mut self, command: &str) -> Result<RequestFuture<'a, C>, Error> {
78 let new = RequestState::new(command)?;
79
80 match self.current.is_some() {
83 true => {
84 debug!(
85 "{}: Queueing up request in order to finish in-flight request",
86 self.registry
87 );
88 self.next = Some(new);
89 }
90 false => self.current = Some(new),
91 }
92
93 Ok(RequestFuture { conn: self })
94 }
95
96 pub(crate) async fn shutdown(&mut self) -> Result<(), Error> {
98 info!("{}: Closing connection", self.registry);
99 timeout(self.timeout, self.stream.shutdown()).await?;
100 Ok(())
101 }
102
103 fn handle(
104 &mut self,
105 mut state: RequestState,
106 cx: &mut Context<'_>,
107 ) -> Result<Transition, Error> {
108 match &mut state {
109 RequestState::Writing { mut start, buf } => {
110 let wrote = match Pin::new(&mut self.stream).poll_write(cx, &buf[start..]) {
111 Poll::Ready(Ok(wrote)) => wrote,
112 Poll::Ready(Err(err)) => return Err(err.into()),
113 Poll::Pending => return Ok(Transition::Pending(state)),
114 };
115
116 if wrote == 0 {
117 return Err(io::Error::new(
118 io::ErrorKind::UnexpectedEof,
119 format!("{}: Unexpected EOF while writing", self.registry),
120 )
121 .into());
122 }
123
124 start += wrote;
125 debug!(
126 "{}: Wrote {} bytes, {} out of {} done",
127 self.registry,
128 wrote,
129 start,
130 buf.len()
131 );
132
133 if start < buf.len() {
136 return Ok(Transition::Next(state));
137 }
138
139 Ok(Transition::Next(RequestState::ReadLength {
140 read: 0,
141 buf: vec![0; 256],
142 }))
143 }
144 RequestState::ReadLength { mut read, buf } => {
145 let mut read_buf = ReadBuf::new(&mut buf[read..]);
146 match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
147 Poll::Ready(Ok(())) => {}
148 Poll::Ready(Err(err)) => return Err(err.into()),
149 Poll::Pending => return Ok(Transition::Pending(state)),
150 };
151
152 let filled = read_buf.filled();
153 if filled.is_empty() {
154 return Err(io::Error::new(
155 io::ErrorKind::UnexpectedEof,
156 format!("{}: Unexpected EOF while reading length", self.registry),
157 )
158 .into());
159 }
160
161 read += filled.len();
166 if read < 4 {
167 return Ok(Transition::Next(state));
168 }
169
170 let expected = u32::from_be_bytes(filled[..4].try_into()?) as usize;
171 debug!("{}: Expected response length: {}", self.registry, expected);
172 buf.resize(expected, 0);
173 Ok(Transition::Next(RequestState::Reading {
174 read,
175 buf: mem::take(buf),
176 expected,
177 }))
178 }
179 RequestState::Reading {
180 mut read,
181 buf,
182 expected,
183 } => {
184 let mut read_buf = ReadBuf::new(&mut buf[read..]);
185 match Pin::new(&mut self.stream).poll_read(cx, &mut read_buf) {
186 Poll::Ready(Ok(())) => {}
187 Poll::Ready(Err(err)) => return Err(err.into()),
188 Poll::Pending => return Ok(Transition::Pending(state)),
189 }
190
191 let filled = read_buf.filled();
192 if filled.is_empty() {
193 return Err(io::Error::new(
194 io::ErrorKind::UnexpectedEof,
195 format!("{}: Unexpected EOF while reading", self.registry),
196 )
197 .into());
198 }
199
200 read += filled.len();
201 debug!(
202 "{}: Read {} bytes, {} out of {} done",
203 self.registry,
204 filled.len(),
205 read,
206 expected
207 );
208
209 Ok(if read < *expected {
212 Transition::Next(state)
214 } else if let Some(next) = self.next.take() {
215 Transition::Next(next)
219 } else {
220 buf.drain(..4);
222 Transition::Done(String::from_utf8(mem::take(buf))?)
223 })
224 }
225 }
226 }
227}
228
229pub(crate) struct RequestFuture<'a, C: Connector> {
230 conn: &'a mut EppConnection<C>,
231}
232
233impl<'a, C: Connector> Future for RequestFuture<'a, C> {
234 type Output = Result<String, Error>;
235
236 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 let this = self.get_mut();
238 loop {
239 let state = this.conn.current.take().unwrap();
240 match this.conn.handle(state, cx) {
241 Ok(Transition::Next(next)) => {
242 this.conn.current = Some(next);
243 continue;
244 }
245 Ok(Transition::Pending(state)) => {
246 this.conn.current = Some(state);
247 return Poll::Pending;
248 }
249 Ok(Transition::Done(rsp)) => return Poll::Ready(Ok(rsp)),
250 Err(err) => {
251 this.conn.next = None;
253 return Poll::Ready(Err(err));
254 }
255 }
256 }
257 }
258}
259
260enum Transition {
262 Pending(RequestState),
263 Next(RequestState),
264 Done(String),
265}
266
/// Progress of a single request/response exchange on the stream.
#[derive(Debug)]
enum RequestState {
    /// Sending the framed request.
    Writing {
        /// Offset of the first byte in `buf` that has not been written yet.
        start: usize,
        /// Length prefix followed by the request payload.
        buf: Vec<u8>,
    },
    /// Collecting the 4-byte length prefix of the response.
    ReadLength {
        /// Number of bytes received into `buf` so far.
        read: usize,
        /// Receive buffer; it may already hold bytes beyond the prefix.
        buf: Vec<u8>,
    },
    /// Collecting the rest of the response.
    Reading {
        /// Number of bytes received into `buf` so far, prefix included.
        read: usize,
        /// Buffer sized to the whole response, prefix included.
        buf: Vec<u8>,
        /// Total response length announced by the prefix (prefix included).
        expected: usize,
    },
}
296
297impl RequestState {
298 fn new(command: &str) -> Result<Self, Error> {
299 let len = command.len();
300
301 let buf_size = len + 4;
302 let mut buf: Vec<u8> = vec![0u8; buf_size];
303
304 let len = len + 4;
305 let len_u32: [u8; 4] = u32::to_be_bytes(len.try_into()?);
306
307 buf[..4].clone_from_slice(&len_u32);
308 buf[4..].clone_from_slice(command.as_bytes());
309 Ok(Self::Writing { start: 0, buf })
310 }
311}
312
313pub(crate) async fn timeout<T, E: Into<Error>>(
314 timeout: Duration,
315 fut: impl Future<Output = Result<T, E>>,
316) -> Result<T, Error> {
317 match tokio::time::timeout(timeout, fut).await {
318 Ok(Ok(t)) => Ok(t),
319 Ok(Err(e)) => Err(e.into()),
320 Err(_) => Err(Error::Timeout),
321 }
322}
323
324#[async_trait]
325pub trait Connector {
326 type Connection: AsyncRead + AsyncWrite + Unpin;
327
328 async fn connect(&self, timeout: Duration) -> Result<Self::Connection, Error>;
329}