ctf_tcp_utils/
lib.rs

1use std::io::{Read, Write};
2use std::net::TcpStream;
3use thiserror::Error;
4
5#[derive(Error, Debug)]
6pub enum CtfTcpHandlerError {
7    #[error("Unable to connect to remote")]
8    ConnectionError,
9    #[error("Unable to set timeout")]
10    ConfigurationError,
11    #[error("Read error")]
12    ReadError,
13}
14
/// Thin wrapper around a connected [`TcpStream`] offering text-oriented
/// read and write helpers.
pub struct TcpHandler {
    /// Connected socket used for every read and write.
    stream: TcpStream,
}
18
19impl TcpHandler {
20    /// Create a new `TcpHandler` wit a default 1s read timeout.
21    ///
22    /// # Errors
23    ///
24    /// May fail if server:port is unavaible or read timeout cannot be set.
25    pub fn new(server_url: &str, port: u16) -> Result<Self, CtfTcpHandlerError> {
26        let stream = {
27            let connection_uri = format!("{server_url}:{port}");
28            TcpStream::connect(connection_uri).map_err(|_| CtfTcpHandlerError::ConnectionError)
29        }?;
30        stream
31            .set_read_timeout(Some(std::time::Duration::from_millis(1000)))
32            .map_err(|_| CtfTcpHandlerError::ConfigurationError)?;
33
34        Ok(Self { stream })
35    }
36
37    /// Create a new `TcpHandler` wit a default 1s read timeout.
38    ///
39    /// # Errors
40    ///
41    /// May fail if server:port is unavaible or read timeout cannot be set.
42    pub fn new_with_timeout(
43        server_url: &str,
44        port: u16,
45        timeout: u64,
46    ) -> Result<Self, CtfTcpHandlerError> {
47        let stream = {
48            let connection_uri = format!("{server_url}:{port}");
49            TcpStream::connect(connection_uri).map_err(|_| CtfTcpHandlerError::ConnectionError)
50        }?;
51        stream
52            .set_read_timeout(Some(std::time::Duration::from_millis(timeout)))
53            .map_err(|_| CtfTcpHandlerError::ConfigurationError)?;
54
55        Ok(Self { stream })
56    }
57
58    /// Read TCP stream until read timeout is reached. Always produces a String using UTF-8 lossy conversion.
59    pub fn read_to_string(&mut self) -> String {
60        let mut res = String::new();
61        let mut buf = vec![0; 4096];
62
63        loop {
64            let size = self.stream.read(&mut buf).unwrap_or(0);
65            if size == 0 {
66                break;
67            }
68            let my_str = std::str::from_utf8(&buf[..size]).unwrap_or_default();
69            res = format!("{res}{my_str}");
70        }
71        res
72    }
73
74    /// Read TCP stream until read timeout is reached. Always produces a String using UTF-8 lossy conversion.
75    pub fn write_answer(&mut self, answer: &str) {
76        let data = format!("{answer}\n");
77        let _size = self.stream.write(data.as_bytes());
78    }
79}
80
/// Heap-allocated responder routine: receives a message as `&str` and returns
/// `Some(answer)` to reply, or `None` when there is nothing to answer.
type BoxedFunction = Box<dyn Fn(&str) -> Option<String>>;
82
83/// `CtfLoopResponder` is a Builder pattern like to build a loop responder.
84///
85/// The main function connect to the server and run the same routine on every incoming message.
86pub struct CtfLoopResponder<'a> {
87    url: Option<&'a str>,
88    port: Option<u16>,
89    timeout: Option<u64>,
90    responder_func: Option<BoxedFunction>,
91}
92
93impl<'a> Default for CtfLoopResponder<'a> {
94    fn default() -> Self {
95        Self::new()
96    }
97}
98
99impl<'a> CtfLoopResponder<'a> {
100    #[must_use]
101    /// Build a new empty `CtfLoopResponder`
102    pub fn new() -> Self {
103        Self {
104            url: None,
105            port: None,
106            timeout: None,
107            responder_func: None,
108        }
109    }
110
111    /// Build a `CtfLoopResponder` for localhost on given port.
112    #[must_use]
113    pub fn localhost(port: u16) -> Self {
114        Self::new().url("localhost").port(port)
115    }
116
117    #[must_use]
118    /// Set url
119    pub fn url(self, url: &'a str) -> Self {
120        Self {
121            url: Some(url),
122            ..self
123        }
124    }
125
126    #[must_use]
127    /// Set port
128    pub fn port(self, port: u16) -> Self {
129        Self {
130            port: Some(port),
131            ..self
132        }
133    }
134
135    #[must_use]
136    /// Set timeout
137    pub fn timeout(self, timeout: u64) -> Self {
138        Self {
139            timeout: Some(timeout),
140            ..self
141        }
142    }
143
144    /// Set the responder routine runned on each server's message.
145    #[must_use]
146    pub fn responder_func(self, responder_func: impl Fn(&str) -> Option<String> + 'static) -> Self {
147        Self {
148            responder_func: Some(Box::new(responder_func)),
149            ..self
150        }
151    }
152
153    /// Connect to the server and use the struct routine to answer each incoming message.
154    ///
155    /// # Errors
156    ///
157    /// The function will fail if either url, port or responder routine is not defined.
158    /// It may also fails if TCP connection fail.
159    pub fn connect_and_work(&self) -> Result<String, CtfTcpHandlerError> {
160        let url = self.url.ok_or(CtfTcpHandlerError::ConfigurationError)?;
161        let port = self.port.ok_or(CtfTcpHandlerError::ConfigurationError)?;
162        let responder = self
163            .responder_func
164            .as_ref()
165            .ok_or(CtfTcpHandlerError::ConfigurationError)?;
166        let mut tcp_handler = self
167            .timeout
168            .map_or_else(
169                || TcpHandler::new(url, port),
170                |timeout| TcpHandler::new_with_timeout(url, port, timeout),
171            )
172            .map_err(|_| CtfTcpHandlerError::ConnectionError)?;
173
174        let mut input = loop {
175            let input = tcp_handler.read_to_string();
176            log::debug!("Received:\n{input}");
177            if let Some(answer) = responder(&input) {
178                log::debug!("Answered: {answer}");
179                tcp_handler.write_answer(&answer);
180            } else {
181                break input;
182            }
183        };
184        input.push_str(&tcp_handler.read_to_string());
185        Ok(input)
186    }
187}