1use std::io;
2
3use err_derive::Error;
4use packet::{Packet, PacketType};
5use tokio::io::{AsyncRead, AsyncWrite};
6
7mod packet;
8
9#[derive(Debug, Error)]
10pub enum Error {
11 #[error(display = "{}", _0)]
12 Io(#[error(source)] io::Error),
13
14 #[error(display = "authentication failed")]
15 Authentication,
16
17 #[error(display = "command exceded maximum length")]
18 CommandLength,
19}
20
21pub type Result<T> = std::result::Result<T, Error>;
22
/// A connection that exchanges packets over the wrapped I/O object `T`.
///
/// Create one with `Connection::builder()` followed by `Builder::build`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct Connection<T> {
    /// Stream that packets are written to and read from.
    io: T,
    /// Value the packet-id counter restarts from instead of overflowing.
    default_packet_id: i32,
    /// Id that will be assigned to the next outgoing packet.
    current_packet_id: i32,
    /// Largest command length, in bytes, accepted by `execute_command`.
    max_payload_size: usize,
}
30
31impl<T> Connection<T>
32where
33 T: Unpin + AsyncRead + AsyncWrite,
34{
35 pub fn builder() -> Builder {
36 Builder {
37 default_packet_id: 0,
38 max_payload_size: 4096 - 10,
39 }
40 }
41
42 pub async fn authenticate(&mut self, password: &str) -> Result<()> {
43 self.send(PacketType::Auth, password.to_string()).await?;
44
45 let packet = loop {
46 let packet = self.receive_packet().await;
47
48 if let Some(packet) = packet.ok() {
49 if packet.packet_type == PacketType::AuthResponse {
50 break packet;
51 }
52 }
53 };
54
55 if packet.is_error() {
56 Err(Error::Authentication)
57 } else {
58 Ok(())
59 }
60 }
61
62 pub async fn execute_command(&mut self, command: &str) -> Result<Vec<String>> {
63 if command.len() > self.max_payload_size {
64 return Err(Error::CommandLength);
65 }
66
67 self.send(PacketType::ExecCommand, command.to_string())
68 .await?;
69
70 let response = self.recieve().await?;
71
72 Ok(response)
73 }
74
75 async fn send(&mut self, packet_type: PacketType, payload: String) -> Result<()> {
76 let packet = packet::Packet::new(self.new_packet_id(), packet_type, payload);
77 self.send_packet(packet).await
78 }
79
80 async fn recieve(&mut self) -> Result<Vec<String>> {
81 let mut responses = Vec::new();
82
83 loop {
84 let response = self.recieve_single_response().await?;
85 responses.push(response);
86
87 if let Some(last) = responses.last() {
88 if last.is_empty() {
89 break;
90 }
91 }
92 }
93
94 Ok(responses)
95 }
96
97 async fn recieve_single_response(&mut self) -> Result<String> {
98 let packet = self.receive_packet().await?;
99
100 Ok(packet.payload.into())
101 }
102
103 async fn send_packet(&mut self, packet: Packet) -> Result<()> {
104 match packet.write_to_io(&mut self.io).await {
105 Ok(_) => Ok(()),
106 Err(err) => Err(Error::Io(err)),
107 }
108 }
109
110 async fn receive_packet(&mut self) -> Result<Packet> {
111 match Packet::read_from_io(&mut self.io).await {
112 Ok(packet) => Ok(packet),
113 Err(err) => Err(Error::Io(err)),
114 }
115 }
116
117 fn new_packet_id(&mut self) -> i32 {
118 let id = self.current_packet_id;
119
120 self.current_packet_id = self
121 .current_packet_id
122 .checked_add(1)
123 .unwrap_or(self.default_packet_id);
124
125 id
126 }
127}
128
/// Configuration used to create a [`Connection`].
///
/// Obtain one from `Connection::builder()`, adjust it with the setter methods,
/// then call `build` with the I/O object.
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Builder {
    /// Id of the first packet sent, and the value the id counter restarts from
    /// instead of overflowing.
    default_packet_id: i32,
    /// Maximum command length, in bytes, that the connection will send.
    max_payload_size: usize,
}
134
135impl Builder {
136 pub fn default_packet_id(mut self, id: i32) -> Self {
137 self.default_packet_id = id;
138 self
139 }
140
141 pub fn max_payload_size(mut self, size: usize) -> Self {
142 self.max_payload_size = size;
143 self
144 }
145
146 pub fn build<T>(self, io: T) -> Connection<T>
147 where
148 T: Unpin + AsyncRead + AsyncWrite,
149 {
150 Connection {
151 io,
152 default_packet_id: self.default_packet_id,
153 current_packet_id: self.default_packet_id,
154 max_payload_size: self.max_payload_size,
155 }
156 }
157}