open_protocol_client/
client.rs

1use crate::network::Network;
2use bytes::{BufMut, BytesMut};
3use flume::{bounded, Receiver, Sender};
4use open_protocol::messages::communication::MID0001rev7;
5use open_protocol::messages::keep_alive::MID9999rev1;
6use open_protocol::{Header, Message};
7use open_protocol::{decode, encode::{self, Encode, Encoder}};
8use std::collections::VecDeque;
9use std::io;
10use std::net::SocketAddr;
11use std::pin::Pin;
12use std::str::FromStr;
13use std::time::Duration;
14use thiserror;
15use tokio::net::TcpStream;
16use tokio::time::{Instant, Sleep};
17use tokio::{select, time};
18
19#[derive(Debug, thiserror::Error)]
20pub enum ConnectionError {
21    #[error("Requests are done")]
22    RequestsDone,
23    #[error("Decode error: {0}")]
24    DecodeError(#[from] decode::Error),
25    #[error("Encode error: {0}")]
26    EncodeError(#[from] encode::Error),
27    #[error("IO error: {0}")]
28    IoError(#[from] io::Error),
29}
30
31#[derive(Debug)]
32pub enum Event {
33    Incoming(Message),
34    Outgoing(Message),
35}
36
37pub struct EventLoop {
38    network: Option<Network>,
39    requests_rx: Receiver<Message>,
40    pub requests_tx: Sender<Message>,
41    pending: VecDeque<Message>, // should not be added to yet...
42    events: VecDeque<Event>,
43    write_buf: BytesMut,
44    keepalive_timeout: Option<Pin<Box<Sleep>>>,
45}
46
47impl EventLoop {
48    pub fn new(socket: TcpStream) -> Self {
49        let (requests_tx, requests_rx) = bounded(1000);
50        let pending = VecDeque::with_capacity(1000);
51        let events = VecDeque::with_capacity(1000);
52
53        Self {
54            requests_tx,
55            requests_rx,
56            pending,
57            events,
58            write_buf: BytesMut::with_capacity(10 * 1024),
59            network: Some(Network::new(socket)),
60            keepalive_timeout: None,
61        }
62    }
63
64    async fn select(&mut self) -> Result<Event, ConnectionError> {
65        // let network = self.network.as_mut().unwrap();
66        // let await_acks = self.state.await_acks;
67
68        // let network = self.network.as_mut().unwrap();
69
70        let inflight_full = false; //self.state.inflight >= self.state.max_outgoing_inflight;
71        let collision = false; //self.state.collision.is_some();
72
73        if let Some(event) = self.events.pop_front() {
74            return Ok(event);
75        }
76
77        select! {
78            o = next_request(
79                &mut self.pending,
80                &self.requests_rx,
81                Duration::ZERO
82            ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
83                Ok(request) => {
84                    self.handle_outgoing_packet(request)?;
85                    self.network.as_mut().unwrap().flush(&mut self.write_buf).await?;
86                    Ok(self.events.pop_front().unwrap())
87                }
88                Err(_) => Err(ConnectionError::RequestsDone),
89            },
90
91            o = self.network.as_mut().unwrap().read(&mut self.events) => {
92                o?;
93                // flush all the acks and return first incoming packet
94                self.network.as_mut().unwrap().flush(&mut self.write_buf).await?;
95                Ok(self.events.pop_front().unwrap())
96            },
97
98            _ = self.keepalive_timeout.as_mut().unwrap() => {
99                let timeout = self.keepalive_timeout.as_mut().unwrap();
100                timeout.as_mut().reset(Instant::now() + Duration::from_secs(5));
101
102                self.handle_outgoing_packet(Message::MID9999rev1(MID9999rev1 {}))?;
103                self.network.as_mut().unwrap().flush(&mut self.write_buf).await?;
104                Ok(self.events.pop_front().unwrap())
105            }
106        }
107    }
108
109    fn handle_outgoing_packet(&mut self, request: Message) -> Result<(), ConnectionError> {
110        let mut payload_encoder = Encoder::new();
111        request.encode_payload(&mut payload_encoder)?;
112
113        let (mid, revision) = request.mid_revision();
114        let header = Header {
115            mid,
116            revision: Some(revision),
117            length: (payload_encoder.len() as u16) + 20,
118            ..Default::default()
119        };
120        let mut header_encoder = Encoder::new();
121        header.encode(&mut header_encoder)?;
122
123        self.write_buf.extend(header_encoder.as_slice());
124        self.write_buf.extend(payload_encoder.as_slice());
125        self.write_buf.put_u8(0x0);
126
127        self.events.push_back(Event::Outgoing(request));
128        Ok(())
129    }
130
131    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
132        if self.keepalive_timeout.is_none() {
133            self.keepalive_timeout = Some(Box::pin(time::sleep(Duration::from_secs(5))));
134        }
135
136        match self.select().await {
137            Ok(v) => Ok(v),
138            Err(e) => {
139                // self.clean();
140                Err(e)
141            }
142        }
143    }
144}
145
146
147async fn next_request(
148    pending: &mut VecDeque<Message>,
149    rx: &Receiver<Message>,
150    pending_throttle: Duration,
151) -> Result<Message, ConnectionError> {
152    if !pending.is_empty() {
153        time::sleep(pending_throttle).await;
154        Ok(pending.pop_front().unwrap())
155    } else {
156        match rx.recv_async().await {
157            Ok(r) => Ok(r),
158            Err(_) => Err(ConnectionError::RequestsDone),
159        }
160    }
161}
162
163pub async fn connect() -> (Sender<Message>, EventLoop) {
164    let socket = TcpStream::connect(SocketAddr::from_str("127.0.0.1:4545").unwrap())
165        .await
166        .unwrap();
167
168    let event_loop = EventLoop::new(socket);
169    let sender = event_loop.requests_tx.clone();
170
171    sender.send_async(Message::MID0001rev1(MID0001rev7 { keep_alive: None })).await.unwrap();
172
173    (sender, event_loop)
174}