1use crate::message::{
6    Message, MessageError, ReCmdMsg, ReCmdMsgPayload, ReCmdMsgType, HDR_LEN_ON_WIRE,
7};
8use bytes::BytesMut;
9use std::fmt;
10use std::io::{Read, Write};
11use std::net::{IpAddr, SocketAddr, TcpStream};
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use crate::config::Config;
15use crate::crypt::Crypt;
16
/// Client side of a single command exchange over TCP.
///
/// Holds the server address, the raw bytes to send, and the configuration
/// that supplies the timeouts and the encryption key.
#[derive(Debug)]
pub struct Snd {
    /// IP address of the server to connect to.
    srv_ip: IpAddr,
    /// TCP port of the server to connect to.
    port: u16,
    /// Raw bytes transmitted as the body of the `DirectCmdReq` message.
    data: Vec<u8>,
    /// Configuration obtained from `Config::init()`; read for the TCP
    /// timeouts and the cipher key.
    config: Config,
}
24
/// Error returned by the sender.
///
/// Every failure (socket I/O, message encoding, unexpected response) is
/// reported through the single `TcpError` variant.
#[derive(Debug)]
pub enum SndError {
    /// The exchange with the server did not complete successfully.
    TcpError,
}

impl std::error::Error for SndError {}

impl fmt::Display for SndError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            SndError::TcpError => "TCP error",
        };
        f.write_str(text)
    }
}

impl From<std::io::Error> for SndError {
    /// Maps any I/O error to `TcpError`; the underlying error is discarded.
    fn from(_: std::io::Error) -> Self {
        Self::TcpError
    }
}
45
46impl From<MessageError> for SndError {
47    fn from(_e: MessageError) -> Self {
48        SndError::TcpError
49    }
50}
51
52impl Snd {
53    pub fn new(srv_ip: IpAddr, port: u16, data: Vec<u8>) -> Self {
54        Snd {
55            srv_ip,
56            port,
57            data,
58            config: Config::init(),
59        }
60    }
61
62    pub fn run(&self) -> Result<Vec<u8>, SndError> {
63        match TcpStream::connect_timeout(
64            &SocketAddr::new(self.srv_ip, self.port),
65            self.config.get_tcp_connect_to(),
66        ) {
67            Ok(mut stream) => {
68                stream.set_write_timeout(Some(self.config.get_tcp_write_to()))?;
69                stream.set_read_timeout(Some(self.config.get_tcp_resp_to()))?;
70
71                let ts: u64 = SystemTime::now()
72                    .duration_since(UNIX_EPOCH)
73                    .unwrap()
74                    .as_secs();
75                let data_to_send = self.encrypt_serialize(ts)?;
76                stream.write_all(&data_to_send)?;
77                let mut data_res = Vec::new();
78                Snd::read_message(&mut stream, &mut data_res)?;
79                let msg_dec: ReCmdMsg = self.deserialize_decrypt(&data_res)?;
80                drop(stream);
81
82                if msg_dec.hdr.msg_type == ReCmdMsgType::DirectCmdRes {
83                    if let ReCmdMsgPayload::DirectCmdRes {
84                        ts: ts_dec,
85                        m: m_dec,
86                        ..
87                    } = &msg_dec.payload
88                    {
89                        if ts == *ts_dec {
90                            Ok(m_dec.to_vec())
91                        } else {
92                            Err(SndError::TcpError)
93                        }
94                    } else {
95                        Err(SndError::TcpError)
96                    }
97                } else {
98                    Err(SndError::TcpError)
99                }
100            }
101            Err(_) => Err(SndError::TcpError),
102        }
103    }
104
105    fn read_message(stream: &mut TcpStream, buf: &mut Vec<u8>) -> Result<usize, std::io::Error> {
106        let mut hdrdata = [0u8; HDR_LEN_ON_WIRE];
107        stream.read_exact(&mut hdrdata)?;
108
109        match Message::parse_hdr(&hdrdata) {
110            Ok((_, (_, len, _))) => {
111                let len = len.try_into();
112
113                match len {
114                    Ok(len) => {
115                        let mut payloaddata: Vec<u8> = vec![0; len];
116                        let npayload = stream.read(&mut payloaddata)?;
117
118                        if npayload == len {
119                            buf.append(&mut hdrdata.to_vec());
120                            buf.append(&mut payloaddata);
121                            Ok(HDR_LEN_ON_WIRE + npayload)
122                        } else {
123                            Err(std::io::Error::new(
124                                std::io::ErrorKind::Other,
125                                "Payload too short",
126                            ))
127                        }
128                    }
129                    _ => Err(std::io::Error::new(
130                        std::io::ErrorKind::Other,
131                        "Conversion error",
132                    )),
133                }
134            }
135            _ => Err(std::io::Error::new(
136                std::io::ErrorKind::Other,
137                "Hdr decoding error",
138            )),
139        }
140    }
141
142    fn deserialize_decrypt(&self, data: &[u8]) -> Result<ReCmdMsg, MessageError> {
143        let cipher = Box::new(Crypt::new(self.config.get_key()));
144        let msg = Message::new(cipher);
145        let mut msg_enc = BytesMut::with_capacity(0);
146        msg_enc.extend_from_slice(data);
147        msg.deserialize_decrypt(&msg_enc)
148    }
149
150    fn encrypt_serialize(&self, ts: u64) -> Result<Vec<u8>, std::io::Error> {
151        let cipher = Box::new(Crypt::new(self.config.get_key()));
152        let msg = Message::new(cipher);
153        if let Ok(msg_enc) = msg.encrypt_serialize(ReCmdMsgType::DirectCmdReq, &self.data, ts) {
154            Ok(msg_enc.to_vec())
155        } else {
156            Err(std::io::Error::new(
157                std::io::ErrorKind::Other,
158                "Encrypt error",
159            ))
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::Snd;
    use crate::message::{ReCmdMsg, ReCmdMsgPayload, ReCmdMsgType};
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Round-trips a request through `encrypt_serialize` and
    /// `deserialize_decrypt` and checks that the message type and the
    /// timestamp survive unchanged.
    #[test]
    fn encrypt_serialize_deserialize_decrypt() {
        let cmd_str = "echo test";
        let ts: u64 = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let snd = Snd::new(
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            6666u16,
            cmd_str.as_bytes().to_vec(),
        );

        let data = snd.encrypt_serialize(ts).unwrap();
        let msg_dec: ReCmdMsg = snd.deserialize_decrypt(&data).unwrap();

        assert_eq!(msg_dec.hdr.msg_type, ReCmdMsgType::DirectCmdReq);
        // Fail loudly on an unexpected payload variant instead of silently
        // skipping the timestamp check.
        match &msg_dec.payload {
            ReCmdMsgPayload::DirectCmdReq { ts: ts_dec, .. } => assert_eq!(ts, *ts_dec),
            _ => panic!("decoded payload is not DirectCmdReq"),
        }
    }
}