1use crate::TcpConnection;
2use std::cell::RefCell;
3use std::collections::HashMap;
4use std::io::{BufRead, BufReader, Read, Write};
5use std::net::{TcpStream, ToSocketAddrs};
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::Duration;
8
/// Upper bound on how long a TCP handshake may take before it is abandoned (5 s).
const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Read and write timeout applied to every socket this module opens (30 s).
const IO_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Maximum number of bytes accepted from a peer in one reply: 10 MiB.
const BODY_LIMIT: usize = 10 << 20;
/// Cap on the number of simultaneously registered connections in `CONNECTIONS`.
const MAX_CONNECTIONS: usize = 256;

/// Source of unique connection ids; the first id handed out is `tcp-1`.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);

thread_local! {
    /// Open connections keyed by their `tcp-N` id.
    ///
    /// The table is thread-local, so a handle is only resolvable on the thread
    /// that created it, and `MAX_CONNECTIONS` is enforced per thread.
    static CONNECTIONS: RefCell<HashMap<String, BufReader<TcpStream>>> =
        RefCell::new(HashMap::default());
}
20
21pub fn connect(host: &str, port: i64) -> Result<TcpConnection, String> {
22 let count = CONNECTIONS.with(|map| map.borrow().len());
23 if count >= MAX_CONNECTIONS {
24 return Err(format!(
25 "Tcp.connect: connection limit reached ({} max)",
26 MAX_CONNECTIONS
27 ));
28 }
29
30 let socket_addr = resolve(&format!("{}:{}", host, port))?;
31 let stream =
32 TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).map_err(|e| e.to_string())?;
33 stream.set_read_timeout(Some(IO_TIMEOUT)).ok();
34 stream.set_write_timeout(Some(IO_TIMEOUT)).ok();
35
36 let id = format!("tcp-{}", NEXT_ID.fetch_add(1, Ordering::Relaxed));
37 CONNECTIONS.with(|map| {
38 map.borrow_mut().insert(id.clone(), BufReader::new(stream));
39 });
40
41 Ok(TcpConnection {
42 id,
43 host: host.to_string(),
44 port,
45 })
46}
47
48pub fn write_line(conn: &TcpConnection, line: &str) -> Result<(), String> {
49 CONNECTIONS.with(|map| {
50 let mut borrow = map.borrow_mut();
51 match borrow.get_mut(&conn.id) {
52 None => Err(format!("Tcp.writeLine: unknown connection '{}'", conn.id)),
53 Some(reader) => {
54 let msg = format!("{}\r\n", line);
55 reader
56 .get_mut()
57 .write_all(msg.as_bytes())
58 .map_err(|e| e.to_string())
59 }
60 }
61 })
62}
63
64pub fn read_line(conn: &TcpConnection) -> Result<String, String> {
65 CONNECTIONS.with(|map| {
66 let mut borrow = map.borrow_mut();
67 match borrow.get_mut(&conn.id) {
68 None => Err(format!("Tcp.readLine: unknown connection '{}'", conn.id)),
69 Some(reader) => {
70 let mut line = String::new();
71 reader.read_line(&mut line).map_err(|e| e.to_string())?;
72 if line.ends_with('\n') {
73 line.pop();
74 if line.ends_with('\r') {
75 line.pop();
76 }
77 }
78 Ok(line)
79 }
80 }
81 })
82}
83
84pub fn close(conn: &TcpConnection) -> Result<(), String> {
85 let removed = CONNECTIONS.with(|map| map.borrow_mut().remove(&conn.id));
86 match removed {
87 Some(_) => Ok(()),
88 None => Err(format!("Tcp.close: unknown connection '{}'", conn.id)),
89 }
90}
91
92pub fn send(host: &str, port: i64, message: &str) -> Result<String, String> {
93 let socket_addr = resolve(&format!("{}:{}", host, port))?;
94 let mut stream =
95 TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).map_err(|e| e.to_string())?;
96 stream.set_read_timeout(Some(IO_TIMEOUT)).ok();
97 stream.set_write_timeout(Some(IO_TIMEOUT)).ok();
98 stream
99 .write_all(message.as_bytes())
100 .map_err(|e| e.to_string())?;
101 stream.shutdown(std::net::Shutdown::Write).ok();
102
103 let mut buf = Vec::new();
104 Read::by_ref(&mut stream)
105 .take(BODY_LIMIT as u64 + 1)
106 .read_to_end(&mut buf)
107 .map_err(|e| e.to_string())?;
108 if buf.len() > BODY_LIMIT {
109 return Err("Tcp.send: response exceeds 10 MB limit".to_string());
110 }
111 Ok(String::from_utf8_lossy(&buf).into_owned())
112}
113
114pub fn ping(host: &str, port: i64) -> Result<(), String> {
115 let socket_addr = resolve(&format!("{}:{}", host, port))?;
116 TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT).map_err(|e| e.to_string())?;
117 Ok(())
118}
119
/// Resolves a `host:port` string and returns the first socket address that
/// `ToSocketAddrs` yields. Any further addresses are ignored.
///
/// # Errors
///
/// Fails if the address cannot be parsed or looked up, or if the lookup
/// produces no addresses at all.
fn resolve(addr: &str) -> Result<std::net::SocketAddr, String> {
    let mut found = match addr.to_socket_addrs() {
        Ok(iter) => iter,
        Err(e) => return Err(format!("Tcp: DNS resolution failed for {}: {}", addr, e)),
    };
    match found.next() {
        Some(first) => Ok(first),
        None => Err(format!("Tcp: no address found for {}", addr)),
    }
}