// aver-rt 0.4.4
//
// Shared Rust runtime pieces for Aver-generated programs.
// Documentation
use crate::{AverStr, TcpConnection};
use std::cell::RefCell;
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

/// Upper bound on how long establishing a TCP connection may take
/// (passed to `TcpStream::connect_timeout`).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Read and write timeout applied to streams after they are connected.
const IO_TIMEOUT: Duration = Duration::from_secs(30);
/// Maximum number of response bytes `send` accepts (10 MiB).
const BODY_LIMIT: usize = 10 * 1024 * 1024;
/// Maximum number of simultaneously registered connections per thread;
/// `connect` refuses to open more once the registry holds this many.
const MAX_CONNECTIONS: usize = 256;

/// Process-wide counter used to mint unique `tcp-N` connection ids.
static NEXT_ID: AtomicU64 = AtomicU64::new(1);

thread_local! {
    /// Per-thread registry of open connections, keyed by connection id.
    /// Because it is thread-local, a connection is only reachable from the
    /// thread that opened it.
    static CONNECTIONS: RefCell<HashMap<String, BufReader<TcpStream>>> =
        RefCell::new(HashMap::new());
}

/// Opens a persistent TCP connection to `host:port` and registers it in the
/// thread-local connection registry.
///
/// Returns a `TcpConnection` handle carrying the freshly minted `tcp-N` id
/// together with the requested host and port.
///
/// # Errors
///
/// Returns an error string when the registry already holds
/// `MAX_CONNECTIONS` entries, when the address cannot be resolved, or when
/// the connection attempt fails or exceeds `CONNECT_TIMEOUT`.
pub fn connect(host: &str, port: i64) -> Result<TcpConnection, String> {
    // Refuse early, before any DNS or network work, if the registry is full.
    let at_capacity = CONNECTIONS.with(|registry| registry.borrow().len() >= MAX_CONNECTIONS);
    if at_capacity {
        return Err(format!(
            "Tcp.connect: connection limit reached ({} max)",
            MAX_CONNECTIONS
        ));
    }

    let target = format!("{}:{}", host, port);
    let socket_addr = resolve(&target)?;
    let stream = match TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT) {
        Ok(stream) => stream,
        Err(e) => return Err(e.to_string()),
    };
    // Timeouts are best-effort: a failure to set them does not abort the connect.
    let _ = stream.set_read_timeout(Some(IO_TIMEOUT));
    let _ = stream.set_write_timeout(Some(IO_TIMEOUT));

    let sequence = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    let id = format!("tcp-{}", sequence);
    CONNECTIONS.with(|registry| {
        registry
            .borrow_mut()
            .insert(id.clone(), BufReader::new(stream));
    });

    Ok(TcpConnection {
        id: AverStr::from(id),
        host: AverStr::from(host),
        port,
    })
}

/// Writes `line` followed by `\r\n` to the registered connection `conn`.
///
/// The bytes are written directly to the underlying stream of the buffered
/// reader stored in the registry.
///
/// # Errors
///
/// Returns an error string when `conn.id` is not present in this thread's
/// registry, or when the write fails.
pub fn write_line(conn: &TcpConnection, line: &str) -> Result<(), String> {
    CONNECTIONS.with(|registry| {
        let mut connections = registry.borrow_mut();
        let key: &str = &conn.id;
        let reader = connections
            .get_mut(key)
            .ok_or_else(|| format!("Tcp.writeLine: unknown connection '{}'", conn.id))?;

        // Assemble the full frame first so it goes out in a single write_all.
        let mut payload = String::with_capacity(line.len() + 2);
        payload.push_str(line);
        payload.push_str("\r\n");

        reader
            .get_mut()
            .write_all(payload.as_bytes())
            .map_err(|e| e.to_string())
    })
}

pub fn read_line(conn: &TcpConnection) -> Result<String, String> {
    CONNECTIONS.with(|map| {
        let mut borrow = map.borrow_mut();
        let id: &str = &conn.id;
        match borrow.get_mut(id) {
            None => Err(format!("Tcp.readLine: unknown connection '{}'", conn.id)),
            Some(reader) => {
                let mut line = String::new();
                reader.read_line(&mut line).map_err(|e| e.to_string())?;
                if line.ends_with('\n') {
                    line.pop();
                    if line.ends_with('\r') {
                        line.pop();
                    }
                }
                Ok(line)
            }
        }
    })
}

/// Removes `conn` from this thread's connection registry.
///
/// The stored buffered stream is dropped as part of the removal, which
/// closes the socket.
///
/// # Errors
///
/// Returns an error string when `conn.id` is not present in the registry
/// (for example, when it was already closed).
pub fn close(conn: &TcpConnection) -> Result<(), String> {
    let key: &str = &conn.id;
    CONNECTIONS
        .with(|registry| registry.borrow_mut().remove(key))
        .map(drop)
        .ok_or_else(|| format!("Tcp.close: unknown connection '{}'", conn.id))
}

/// One-shot request/response: connects to `host:port`, writes `message`,
/// half-closes the write side, and returns everything the peer sends back
/// until it closes the connection.
///
/// The response is decoded as UTF-8 with invalid sequences replaced.
///
/// # Errors
///
/// Returns an error string when the address cannot be resolved, when
/// connecting, writing or reading fails, or when the response is larger
/// than `BODY_LIMIT` bytes.
pub fn send(host: &str, port: i64, message: &str) -> Result<String, String> {
    let target = format!("{}:{}", host, port);
    let socket_addr = resolve(&target)?;

    let mut stream = match TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT) {
        Ok(stream) => stream,
        Err(e) => return Err(e.to_string()),
    };
    // Timeouts are best-effort; failing to set them is not fatal.
    let _ = stream.set_read_timeout(Some(IO_TIMEOUT));
    let _ = stream.set_write_timeout(Some(IO_TIMEOUT));

    if let Err(e) = stream.write_all(message.as_bytes()) {
        return Err(e.to_string());
    }
    // Signal end-of-request to the peer; a failure here is ignored.
    let _ = stream.shutdown(std::net::Shutdown::Write);

    // Read at most one byte past the limit so an oversized response is
    // detected without buffering it in full.
    let mut response = Vec::new();
    let mut limited = Read::take(&mut stream, BODY_LIMIT as u64 + 1);
    if let Err(e) = limited.read_to_end(&mut response) {
        return Err(e.to_string());
    }
    if response.len() > BODY_LIMIT {
        return Err("Tcp.send: response exceeds 10 MB limit".to_string());
    }

    let text = String::from_utf8_lossy(&response).into_owned();
    Ok(text)
}

/// Checks that a TCP connection to `host:port` can be established.
///
/// The connection is dropped immediately after it is opened; no data is
/// sent or received.
///
/// # Errors
///
/// Returns an error string when the address cannot be resolved or the
/// connection attempt fails or exceeds `CONNECT_TIMEOUT`.
pub fn ping(host: &str, port: i64) -> Result<(), String> {
    let target = format!("{}:{}", host, port);
    let socket_addr = resolve(&target)?;
    match TcpStream::connect_timeout(&socket_addr, CONNECT_TIMEOUT) {
        Ok(_probe) => Ok(()),
        Err(e) => Err(e.to_string()),
    }
}

/// Resolves a `host:port` string to a socket address, returning the first
/// address the resolver yields.
///
/// # Errors
///
/// Returns an error string when the lookup fails (including malformed
/// input) or when it succeeds but yields no addresses.
fn resolve(addr: &str) -> Result<std::net::SocketAddr, String> {
    let mut candidates = match addr.to_socket_addrs() {
        Ok(found) => found,
        Err(e) => {
            return Err(format!(
                "Tcp: DNS resolution failed for {}: {}",
                addr, e
            ))
        }
    };

    match candidates.next() {
        Some(first) => Ok(first),
        None => Err(format!("Tcp: no address found for {}", addr)),
    }
}