1use crate::message::{
6 Message, MessageError, ReCmdMsg, ReCmdMsgPayload, ReCmdMsgType, HDR_LEN_ON_WIRE,
7};
8use bytes::BytesMut;
9use std::fmt;
10use std::io::{Read, Write};
11use std::net::{IpAddr, SocketAddr, TcpStream};
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use crate::config::Config;
15use crate::crypt::Crypt;
16
/// Client-side sender for a single direct command request.
///
/// Holds the server address, the request payload and the configuration that
/// `run` uses for timeouts and for the encryption key.
#[derive(Debug)]
pub struct Snd {
    /// IP address of the server to connect to.
    srv_ip: IpAddr,
    /// TCP port of the server.
    port: u16,
    /// Raw payload bytes placed in the request message.
    data: Vec<u8>,
    /// Configuration obtained from `Config::init()` when the sender is created.
    config: Config,
}
24
/// Error returned by [`Snd::run`].
///
/// There is a single catch-all variant: I/O errors and `MessageError`s are both
/// converted into it, and `run` also returns it when the response is not the
/// expected one.
#[derive(Debug)]
pub enum SndError {
    /// Any failure while connecting, sending, receiving or validating.
    TcpError,
}
29
// Marker impl so `SndError` can be used wherever a `std::error::Error` is expected;
// the trait's default methods are sufficient.
impl std::error::Error for SndError {}
31
32impl fmt::Display for SndError {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 match self {
35 SndError::TcpError => write!(f, "TCP error"),
36 }
37 }
38}
39
40impl From<std::io::Error> for SndError {
41 fn from(_e: std::io::Error) -> Self {
42 SndError::TcpError
43 }
44}
45
46impl From<MessageError> for SndError {
47 fn from(_e: MessageError) -> Self {
48 SndError::TcpError
49 }
50}
51
52impl Snd {
53 pub fn new(srv_ip: IpAddr, port: u16, data: Vec<u8>) -> Self {
54 Snd {
55 srv_ip,
56 port,
57 data,
58 config: Config::init(),
59 }
60 }
61
62 pub fn run(&self) -> Result<Vec<u8>, SndError> {
63 match TcpStream::connect_timeout(
64 &SocketAddr::new(self.srv_ip, self.port),
65 self.config.get_tcp_connect_to(),
66 ) {
67 Ok(mut stream) => {
68 stream.set_write_timeout(Some(self.config.get_tcp_write_to()))?;
69 stream.set_read_timeout(Some(self.config.get_tcp_resp_to()))?;
70
71 let ts: u64 = SystemTime::now()
72 .duration_since(UNIX_EPOCH)
73 .unwrap()
74 .as_secs();
75 let data_to_send = self.encrypt_serialize(ts)?;
76 stream.write_all(&data_to_send)?;
77 let mut data_res = Vec::new();
78 Snd::read_message(&mut stream, &mut data_res)?;
79 let msg_dec: ReCmdMsg = self.deserialize_decrypt(&data_res)?;
80 drop(stream);
81
82 if msg_dec.hdr.msg_type == ReCmdMsgType::DirectCmdRes {
83 if let ReCmdMsgPayload::DirectCmdRes {
84 ts: ts_dec,
85 m: m_dec,
86 ..
87 } = &msg_dec.payload
88 {
89 if ts == *ts_dec {
90 Ok(m_dec.to_vec())
91 } else {
92 Err(SndError::TcpError)
93 }
94 } else {
95 Err(SndError::TcpError)
96 }
97 } else {
98 Err(SndError::TcpError)
99 }
100 }
101 Err(_) => Err(SndError::TcpError),
102 }
103 }
104
105 fn read_message(stream: &mut TcpStream, buf: &mut Vec<u8>) -> Result<usize, std::io::Error> {
106 let mut hdrdata = [0u8; HDR_LEN_ON_WIRE];
107 stream.read_exact(&mut hdrdata)?;
108
109 match Message::parse_hdr(&hdrdata) {
110 Ok((_, (_, len, _))) => {
111 let len = len.try_into();
112
113 match len {
114 Ok(len) => {
115 let mut payloaddata: Vec<u8> = vec![0; len];
116 let npayload = stream.read(&mut payloaddata)?;
117
118 if npayload == len {
119 buf.append(&mut hdrdata.to_vec());
120 buf.append(&mut payloaddata);
121 Ok(HDR_LEN_ON_WIRE + npayload)
122 } else {
123 Err(std::io::Error::new(
124 std::io::ErrorKind::Other,
125 "Payload too short",
126 ))
127 }
128 }
129 _ => Err(std::io::Error::new(
130 std::io::ErrorKind::Other,
131 "Conversion error",
132 )),
133 }
134 }
135 _ => Err(std::io::Error::new(
136 std::io::ErrorKind::Other,
137 "Hdr decoding error",
138 )),
139 }
140 }
141
142 fn deserialize_decrypt(&self, data: &[u8]) -> Result<ReCmdMsg, MessageError> {
143 let cipher = Box::new(Crypt::new(self.config.get_key()));
144 let msg = Message::new(cipher);
145 let mut msg_enc = BytesMut::with_capacity(0);
146 msg_enc.extend_from_slice(data);
147 msg.deserialize_decrypt(&msg_enc)
148 }
149
150 fn encrypt_serialize(&self, ts: u64) -> Result<Vec<u8>, std::io::Error> {
151 let cipher = Box::new(Crypt::new(self.config.get_key()));
152 let msg = Message::new(cipher);
153 if let Ok(msg_enc) = msg.encrypt_serialize(ReCmdMsgType::DirectCmdReq, &self.data, ts) {
154 Ok(msg_enc.to_vec())
155 } else {
156 Err(std::io::Error::new(
157 std::io::ErrorKind::Other,
158 "Encrypt error",
159 ))
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::Snd;
    use crate::message::{ReCmdMsg, ReCmdMsgPayload, ReCmdMsgType};
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Round-trips a request through `encrypt_serialize` and
    /// `deserialize_decrypt` and checks that message type and timestamp survive.
    #[test]
    fn encrypt_serialize_deserialize_decrypt() {
        let cmd_str = "echo test";
        let ts: u64 = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        // No connection is made; address and port only satisfy the constructor.
        let snd = Snd::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            6666u16,
            cmd_str.as_bytes().to_vec(),
        );

        let data = snd.encrypt_serialize(ts).unwrap();
        let msg_dec: ReCmdMsg = snd.deserialize_decrypt(&data).unwrap();

        assert_eq!(msg_dec.hdr.msg_type, ReCmdMsgType::DirectCmdReq);
        // Fail loudly on an unexpected payload variant instead of silently
        // skipping the timestamp check.
        match &msg_dec.payload {
            ReCmdMsgPayload::DirectCmdReq { ts: ts_dec, .. } => assert_eq!(ts, *ts_dec),
            _ => panic!("decoded payload is not a DirectCmdReq"),
        }
    }
}