// specul/lib.rs

1use std::io;
2
3use err_derive::Error;
4use packet::{Packet, PacketType};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7mod packet;
8
9#[derive(Debug, Error)]
10pub enum Error {
11    #[error(display = "{}", _0)]
12    Io(#[error(source)] io::Error),
13
14    #[error(display = "authentication failed")]
15    Authentication,
16
17    #[error(display = "command exceded maximum length")]
18    CommandLength,
19}
20
21pub type Result<T> = std::result::Result<T, Error>;
22
/// Client connection that exchanges packets over the stream held in `io`.
///
/// The fields are private, so construct one with [`Connection::builder`]
/// followed by [`Builder::build`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Connection<T> {
    /// Underlying transport that packets are written to and read from.
    io: T,
    /// Starting packet id; the id counter also restarts here when
    /// incrementing it would overflow `i32`.
    default_packet_id: i32,
    /// Id that will be assigned to the next outgoing packet.
    current_packet_id: i32,
    /// Longest command, in bytes, that `execute_command` will send.
    max_payload_size: usize,
}
30
31impl<T> Connection<T>
32where
33    T: Unpin + AsyncRead + AsyncWrite,
34{
35    pub fn builder() -> Builder {
36        Builder {
37            default_packet_id: 0,
38            max_payload_size: 4096 - 10,
39        }
40    }
41
42    pub async fn authenticate(&mut self, password: &str) -> Result<()> {
43        self.send(PacketType::Auth, password.to_string()).await?;
44
45        let packet = loop {
46            let packet = self.receive_packet().await;
47
48            if let Some(packet) = packet.ok() {
49                if packet.packet_type == PacketType::AuthResponse {
50                    break packet;
51                }
52            }
53        };
54
55        if packet.is_error() {
56            Err(Error::Authentication)
57        } else {
58            Ok(())
59        }
60    }
61
62    pub async fn execute_command(&mut self, command: &str) -> Result<Vec<String>> {
63        if command.len() > self.max_payload_size {
64            return Err(Error::CommandLength);
65        }
66
67        self.send(PacketType::ExecCommand, command.to_string())
68            .await?;
69
70        let response = self.recieve().await?;
71
72        Ok(response)
73    }
74
75    async fn send(&mut self, packet_type: PacketType, payload: String) -> Result<()> {
76        let packet = packet::Packet::new(self.new_packet_id(), packet_type, payload);
77        self.send_packet(packet).await
78    }
79
80    async fn recieve(&mut self) -> Result<Vec<String>> {
81        let mut responses = Vec::new();
82
83        loop {
84            let response = self.recieve_single_response().await?;
85            responses.push(response);
86
87            if let Some(last) = responses.last() {
88                if last.is_empty() {
89                    break;
90                }
91            }
92        }
93
94        Ok(responses)
95    }
96
97    async fn recieve_single_response(&mut self) -> Result<String> {
98        let packet = self.receive_packet().await?;
99
100        Ok(packet.payload.into())
101    }
102
103    async fn send_packet(&mut self, packet: Packet) -> Result<()> {
104        match packet.write_to_io(&mut self.io).await {
105            Ok(_) => Ok(()),
106            Err(err) => Err(Error::Io(err)),
107        }
108    }
109
110    async fn receive_packet(&mut self) -> Result<Packet> {
111        match Packet::read_from_io(&mut self.io).await {
112            Ok(packet) => Ok(packet),
113            Err(err) => Err(Error::Io(err)),
114        }
115    }
116
117    fn new_packet_id(&mut self) -> i32 {
118        let id = self.current_packet_id;
119
120        self.current_packet_id = self
121            .current_packet_id
122            .checked_add(1)
123            .unwrap_or(self.default_packet_id);
124
125        id
126    }
127}
128
/// Settings for a new [`Connection`]; start from [`Connection::builder`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Builder {
    /// Packet id the built connection starts counting from.
    default_packet_id: i32,
    /// Longest command, in bytes, the built connection will send.
    max_payload_size: usize,
}
134
135impl Builder {
136    pub fn default_packet_id(mut self, id: i32) -> Self {
137        self.default_packet_id = id;
138        self
139    }
140
141    pub fn max_payload_size(mut self, size: usize) -> Self {
142        self.max_payload_size = size;
143        self
144    }
145
146    pub fn build<T>(self, io: T) -> Connection<T>
147    where
148        T: Unpin + AsyncRead + AsyncWrite,
149    {
150        Connection {
151            io,
152            default_packet_id: self.default_packet_id,
153            current_packet_id: self.default_packet_id,
154            max_payload_size: self.max_payload_size,
155        }
156    }
157}