Skip to main content

aver_rt/
tcp.rs

1use crate::TcpConnection;
2use std::cell::RefCell;
3use std::collections::HashMap;
4use std::io::{BufRead, BufReader, Read, Write};
5use std::net::{TcpStream, ToSocketAddrs};
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::Duration;
8
/// Upper bound on how long a single TCP connect attempt may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Read and write timeout installed on every socket once it is connected.
const IO_TIMEOUT: Duration = Duration::from_secs(30);
/// Largest payload, in bytes (10 MiB), this module will buffer from a peer.
const BODY_LIMIT: usize = 10 << 20;
/// Maximum number of open connections held in the current thread's table.
const MAX_CONNECTIONS: usize = 256;

/// Source of connection ids. It is a process-wide atomic, so ids stay unique even
/// though every thread keeps its own connection table.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);

thread_local! {
    /// Open connections owned by the current thread, keyed by connection id.
    /// A connection opened on one thread is not visible from any other thread.
    static CONNECTIONS: RefCell<HashMap<String, BufReader<TcpStream>>> = RefCell::default();
}
20
21pub fn connect(host: &str, port: i64) -> Result<TcpConnection, String> {
22    let count = CONNECTIONS.with(|map| map.borrow().len());
23    if count >= MAX_CONNECTIONS {
24        return Err(format!(
25            "Tcp.connect: connection limit reached ({} max)",
26            MAX_CONNECTIONS
27        ));
28    }
29
30    let socket_addr = resolve(&format!("{}:{}", host, port))?;
31    let stream =
32        TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).map_err(|e| e.to_string())?;
33    stream.set_read_timeout(Some(IO_TIMEOUT)).ok();
34    stream.set_write_timeout(Some(IO_TIMEOUT)).ok();
35
36    let id = format!("tcp-{}", NEXT_ID.fetch_add(1, Ordering::Relaxed));
37    CONNECTIONS.with(|map| {
38        map.borrow_mut().insert(id.clone(), BufReader::new(stream));
39    });
40
41    Ok(TcpConnection {
42        id,
43        host: host.to_string(),
44        port,
45    })
46}
47
48pub fn write_line(conn: &TcpConnection, line: &str) -> Result<(), String> {
49    CONNECTIONS.with(|map| {
50        let mut borrow = map.borrow_mut();
51        match borrow.get_mut(&conn.id) {
52            None => Err(format!("Tcp.writeLine: unknown connection '{}'", conn.id)),
53            Some(reader) => {
54                let msg = format!("{}\r\n", line);
55                reader
56                    .get_mut()
57                    .write_all(msg.as_bytes())
58                    .map_err(|e| e.to_string())
59            }
60        }
61    })
62}
63
64pub fn read_line(conn: &TcpConnection) -> Result<String, String> {
65    CONNECTIONS.with(|map| {
66        let mut borrow = map.borrow_mut();
67        match borrow.get_mut(&conn.id) {
68            None => Err(format!("Tcp.readLine: unknown connection '{}'", conn.id)),
69            Some(reader) => {
70                let mut line = String::new();
71                reader.read_line(&mut line).map_err(|e| e.to_string())?;
72                if line.ends_with('\n') {
73                    line.pop();
74                    if line.ends_with('\r') {
75                        line.pop();
76                    }
77                }
78                Ok(line)
79            }
80        }
81    })
82}
83
84pub fn close(conn: &TcpConnection) -> Result<(), String> {
85    let removed = CONNECTIONS.with(|map| map.borrow_mut().remove(&conn.id));
86    match removed {
87        Some(_) => Ok(()),
88        None => Err(format!("Tcp.close: unknown connection '{}'", conn.id)),
89    }
90}
91
92pub fn send(host: &str, port: i64, message: &str) -> Result<String, String> {
93    let socket_addr = resolve(&format!("{}:{}", host, port))?;
94    let mut stream =
95        TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).map_err(|e| e.to_string())?;
96    stream.set_read_timeout(Some(IO_TIMEOUT)).ok();
97    stream.set_write_timeout(Some(IO_TIMEOUT)).ok();
98    stream
99        .write_all(message.as_bytes())
100        .map_err(|e| e.to_string())?;
101    stream.shutdown(std::net::Shutdown::Write).ok();
102
103    let mut buf = Vec::new();
104    Read::by_ref(&mut stream)
105        .take(BODY_LIMIT as u64 + 1)
106        .read_to_end(&mut buf)
107        .map_err(|e| e.to_string())?;
108    if buf.len() > BODY_LIMIT {
109        return Err("Tcp.send: response exceeds 10 MB limit".to_string());
110    }
111    Ok(String::from_utf8_lossy(&buf).into_owned())
112}
113
114pub fn ping(host: &str, port: i64) -> Result<(), String> {
115    let socket_addr = resolve(&format!("{}:{}", host, port))?;
116    TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).map_err(|e| e.to_string())?;
117    Ok(())
118}
119
/// Resolves a `host:port` string to a single socket address.
///
/// Only the first address yielded by `ToSocketAddrs` is returned. `connect`, `send` and
/// `ping` connect to that one address and do not fall back to any later ones.
///
/// # Errors
///
/// Fails if the string cannot be resolved (including a malformed or out-of-range
/// port), or if resolution succeeds but yields no addresses.
fn resolve(addr: &str) -> Result<std::net::SocketAddr, String> {
    let mut candidates = match addr.to_socket_addrs() {
        Ok(iter) => iter,
        Err(e) => return Err(format!("Tcp: DNS resolution failed for {}: {}", addr, e)),
    };
    match candidates.next() {
        Some(first) => Ok(first),
        None => Err(format!("Tcp: no address found for {}", addr)),
    }
}