harn-vm 0.8.62

Async bytecode virtual machine for the Harn programming language
Documentation
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;

use crate::stdlib::macros::{harn_builtin, VmBuiltinDef};
use crate::value::{VmError, VmValue};
use crate::vm::Vm;

/// Fallback for the `timeout_ms` option, in milliseconds, when the caller does not supply one.
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
/// Fallback for the `max_response_bytes` option (1 MiB) when the caller does not supply one.
const DEFAULT_MAX_RESPONSE_BYTES: usize = 1_048_576;

/// Registers every builtin definition listed in `MODULE_BUILTINS` on `vm`.
pub(crate) fn register_net_builtins(vm: &mut Vm) {
    MODULE_BUILTINS.iter().for_each(|builtin| {
        vm.register_builtin_def(builtin);
    });
}

/// Builtin definitions exported by this module; `register_net_builtins` installs each of them.
// NOTE(review): `NET_UNIX_SOCKET_JSON_REQUEST_IMPL_DEF` is presumably generated by the
// `#[harn_builtin]` attribute on `net_unix_socket_json_request_impl` — confirm in the macro.
pub(crate) const MODULE_BUILTINS: &[&VmBuiltinDef] = &[&NET_UNIX_SOCKET_JSON_REQUEST_IMPL_DEF];

/// Resolved settings for a single Unix-socket JSON request.
#[derive(Clone, Copy)]
struct UnixSocketOptions {
    /// Timeout installed on the socket for both reads and writes.
    timeout: Duration,
    /// Largest response line, in bytes, that is accepted before the request is
    /// reported as `response_too_large`.
    max_response_bytes: usize,
}

/// Looks up `key` in an optional options dict and returns it as an integer.
///
/// Returns `default` when there is no dict, the key is absent, or
/// `VmValue::as_int` yields `None` for the stored value.
fn option_int(options: Option<&BTreeMap<String, VmValue>>, key: &str, default: i64) -> i64 {
    match options {
        Some(dict) => match dict.get(key).and_then(VmValue::as_int) {
            Some(number) => number,
            None => default,
        },
        None => default,
    }
}

/// Builds [`UnixSocketOptions`] from the optional third builtin argument.
///
/// Reads `timeout_ms` and `max_response_bytes` from the value when it is a
/// dict, falling back to the module defaults otherwise. Both settings are
/// clamped to a minimum of 1 before being converted to their final types.
fn unix_socket_options(value: Option<&VmValue>) -> UnixSocketOptions {
    let dict = value.and_then(VmValue::as_dict);
    // Shared lookup: fetch the integer option, then force it to be at least 1.
    let at_least_one = |key: &str, fallback: i64| option_int(dict, key, fallback).max(1);

    let timeout_ms = at_least_one("timeout_ms", DEFAULT_TIMEOUT_MS as i64) as u64;
    let max_response_bytes =
        at_least_one("max_response_bytes", DEFAULT_MAX_RESPONSE_BYTES as i64) as usize;

    UnixSocketOptions {
        timeout: Duration::from_millis(timeout_ms),
        max_response_bytes,
    }
}

/// Stores `value` under `key` as a `VmValue::String`, replacing any existing entry.
fn insert_string(dict: &mut BTreeMap<String, VmValue>, key: &str, value: impl Into<String>) {
    let text: String = value.into();
    let entry = VmValue::String(Arc::from(text));
    dict.insert(key.to_string(), entry);
}

/// Stores `value` under `key` as a `VmValue::Int`, replacing any existing entry.
fn insert_int(dict: &mut BTreeMap<String, VmValue>, key: &str, value: i64) {
    let entry = VmValue::Int(value);
    dict.insert(key.to_string(), entry);
}

/// Returns the milliseconds between `started_ms` and the current monotonic
/// clock reading, using a saturating subtraction.
fn elapsed_ms(started_ms: i64) -> i64 {
    let now_ms = crate::stdlib::clock::now_monotonic_ms();
    now_ms.saturating_sub(started_ms)
}

fn socket_result(
    started_ms: i64,
    path: &str,
    ok: bool,
    status: &str,
    error: Option<String>,
    raw_response: Option<String>,
    response: Option<VmValue>,
) -> VmValue {
    let mut out = BTreeMap::new();
    out.insert("ok".to_string(), VmValue::Bool(ok));
    insert_string(&mut out, "status", status);
    insert_string(&mut out, "path", path);
    insert_int(&mut out, "duration_ms", elapsed_ms(started_ms));
    if let Some(error) = error {
        insert_string(&mut out, "error", error);
    }
    if let Some(raw_response) = raw_response {
        insert_int(&mut out, "bytes_read", raw_response.len() as i64);
        insert_string(&mut out, "raw_response", raw_response);
    }
    if let Some(response) = response {
        out.insert("response".to_string(), response);
    }
    VmValue::Dict(Arc::new(out))
}

/// Maps an I/O error to the status string reported in the result dict.
///
/// `TimedOut` and `WouldBlock` both become `"timeout"`; every other error
/// kind becomes `"io_error"`.
#[cfg(unix)]
fn io_status(error: &std::io::Error) -> &'static str {
    use std::io::ErrorKind;

    if matches!(error.kind(), ErrorKind::TimedOut | ErrorKind::WouldBlock) {
        "timeout"
    } else {
        "io_error"
    }
}

#[cfg(unix)]
fn unix_socket_json_request(
    path: &str,
    request_json: &str,
    options: UnixSocketOptions,
    started_ms: i64,
) -> VmValue {
    use std::io::{BufRead, BufReader, Read, Write};
    use std::os::unix::net::UnixStream;

    let mut stream = match UnixStream::connect(path) {
        Ok(stream) => stream,
        Err(error) => {
            return socket_result(
                started_ms,
                path,
                false,
                "connect_error",
                Some(error.to_string()),
                None,
                None,
            );
        }
    };
    let _ = stream.set_read_timeout(Some(options.timeout));
    let _ = stream.set_write_timeout(Some(options.timeout));

    if let Err(error) = stream.write_all(request_json.as_bytes()) {
        return socket_result(
            started_ms,
            path,
            false,
            io_status(&error),
            Some(error.to_string()),
            None,
            None,
        );
    }
    if let Err(error) = stream.write_all(b"\n") {
        return socket_result(
            started_ms,
            path,
            false,
            io_status(&error),
            Some(error.to_string()),
            None,
            None,
        );
    }
    if let Err(error) = stream.flush() {
        return socket_result(
            started_ms,
            path,
            false,
            io_status(&error),
            Some(error.to_string()),
            None,
            None,
        );
    }

    let mut reader = BufReader::new(stream).take((options.max_response_bytes as u64) + 1);
    let mut response_bytes = Vec::new();
    match reader.read_until(b'\n', &mut response_bytes) {
        Ok(0) => {
            return socket_result(
                started_ms,
                path,
                false,
                "eof",
                Some("socket closed before a response line was received".to_string()),
                None,
                None,
            );
        }
        Ok(_) => {}
        Err(error) => {
            return socket_result(
                started_ms,
                path,
                false,
                io_status(&error),
                Some(error.to_string()),
                None,
                None,
            );
        }
    }
    if response_bytes.len() > options.max_response_bytes {
        return socket_result(
            started_ms,
            path,
            false,
            "response_too_large",
            Some(format!(
                "response exceeded max_response_bytes={}",
                options.max_response_bytes
            )),
            None,
            None,
        );
    }

    while matches!(response_bytes.last(), Some(b'\n' | b'\r')) {
        response_bytes.pop();
    }
    let raw_response = match String::from_utf8(response_bytes) {
        Ok(text) => text,
        Err(error) => {
            return socket_result(
                started_ms,
                path,
                false,
                "invalid_utf8",
                Some(error.to_string()),
                None,
                None,
            );
        }
    };
    let parsed = match serde_json::from_str::<serde_json::Value>(&raw_response) {
        Ok(value) => crate::schema::json_to_vm_value(&value),
        Err(error) => {
            return socket_result(
                started_ms,
                path,
                false,
                "invalid_json",
                Some(error.to_string()),
                Some(raw_response),
                None,
            );
        }
    };
    socket_result(
        started_ms,
        path,
        true,
        "ok",
        None,
        Some(raw_response),
        Some(parsed),
    )
}

/// Fallback for targets without Unix domain sockets.
///
/// Never touches the network; always returns a failed result dict with status
/// `"unsupported"`.
#[cfg(not(unix))]
fn unix_socket_json_request(
    path: &str,
    _request_json: &str,
    options: UnixSocketOptions,
    started_ms: i64,
) -> VmValue {
    // NOTE(review): presumably reads both fields only to keep them from being
    // flagged as unused on this target — confirm.
    let _ = (options.timeout, options.max_response_bytes);

    let message = "Unix domain sockets are not supported on this target".to_string();
    socket_result(
        started_ms,
        path,
        false,
        "unsupported",
        Some(message),
        None,
        None,
    )
}

// Builtin entry point for `__net_unix_socket_json_request(path, request, options?)`.
//
// Renders the first argument to a string and uses it as the socket path; throws
// if that string is empty. The second argument (nil when absent) is serialized
// to JSON and sent as the request; the third, when present, supplies options.
// Socket-level failures are reported inside the returned dict, not thrown.
#[harn_builtin(
    sig = "__net_unix_socket_json_request(path: string, request: any, options?: dict) -> dict",
    category = "net"
)]
fn net_unix_socket_json_request_impl(
    args: &[VmValue],
    _out: &mut String,
) -> Result<VmValue, VmError> {
    // Start the clock first so argument handling is included in `duration_ms`.
    let started_ms = crate::stdlib::clock::now_monotonic_ms();

    let path = match args.first() {
        Some(value) => value.display(),
        None => String::new(),
    };
    if path.is_empty() {
        let message = VmValue::String(Arc::from(
            "__net_unix_socket_json_request: path is required",
        ));
        return Err(VmError::Thrown(message));
    }

    let options = unix_socket_options(args.get(2));
    let request = args.get(1).unwrap_or(&VmValue::Nil);
    let request_json = super::json::vm_value_to_json(request);

    let result = unix_socket_json_request(&path, &request_json, options, started_ms);
    Ok(result)
}

// Integration-style tests that talk to a real Unix listener bound inside a
// temporary directory; compiled only for test builds on Unix targets.
#[cfg(all(test, unix))]
mod tests {
    use std::io::{BufRead, BufReader, Write};

    use super::*;

    /// A well-formed JSON request line receives a well-formed JSON response
    /// line, and the result dict reports `ok` with the parsed response.
    #[test]
    fn unix_socket_json_request_round_trips_json_line() {
        use std::os::unix::net::UnixListener;

        let dir = tempfile::tempdir().expect("tempdir");
        let socket_path = dir.path().join("daemon.sock");
        let listener = UnixListener::bind(&socket_path).expect("bind unix listener");
        // Server side: accept one connection, check the request, reply with one JSON line.
        let server = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut request = String::new();
            // Read through a cloned handle so `stream` stays available for writing.
            BufReader::new(stream.try_clone().expect("clone stream"))
                .read_line(&mut request)
                .expect("read request");
            let parsed: serde_json::Value = serde_json::from_str(request.trim()).expect("json");
            assert_eq!(parsed["method"], "ping");
            stream
                .write_all(br#"{"ok":true,"method":"pong"}"#)
                .expect("write response");
            stream.write_all(b"\n").expect("write newline");
        });

        let result = unix_socket_json_request(
            &socket_path.display().to_string(),
            r#"{"method":"ping"}"#,
            UnixSocketOptions {
                timeout: Duration::from_millis(500),
                max_response_bytes: 1024,
            },
            crate::stdlib::clock::now_monotonic_ms(),
        );
        // Joining surfaces any assertion failure raised inside the server thread.
        server.join().expect("server thread");

        let dict = result.as_dict().expect("result dict");
        assert!(matches!(dict.get("ok"), Some(VmValue::Bool(true))));
        assert_eq!(
            dict.get("status").map(VmValue::display).as_deref(),
            Some("ok")
        );
        let response = dict
            .get("response")
            .and_then(VmValue::as_dict)
            .expect("response dict");
        assert_eq!(
            response.get("method").map(VmValue::display).as_deref(),
            Some("pong")
        );
    }

    /// A response line that is not JSON yields status `invalid_json` while
    /// still exposing the raw text (without its trailing newline).
    #[test]
    fn unix_socket_json_request_reports_invalid_json() {
        use std::os::unix::net::UnixListener;

        let dir = tempfile::tempdir().expect("tempdir");
        let socket_path = dir.path().join("daemon.sock");
        let listener = UnixListener::bind(&socket_path).expect("bind unix listener");
        // Server side: reply immediately with a non-JSON line, without reading the request.
        let server = std::thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            stream.write_all(b"not-json\n").expect("write response");
        });

        let result = unix_socket_json_request(
            &socket_path.display().to_string(),
            r#"{"method":"ping"}"#,
            UnixSocketOptions {
                timeout: Duration::from_millis(500),
                max_response_bytes: 1024,
            },
            crate::stdlib::clock::now_monotonic_ms(),
        );
        server.join().expect("server thread");

        let dict = result.as_dict().expect("result dict");
        assert!(matches!(dict.get("ok"), Some(VmValue::Bool(false))));
        assert_eq!(
            dict.get("status").map(VmValue::display).as_deref(),
            Some("invalid_json")
        );
        assert_eq!(
            dict.get("raw_response").map(VmValue::display).as_deref(),
            Some("not-json")
        );
    }
}