use crate::copyterm::{self, TermBuf};
use crate::machine::Machine;
use crate::render::RenderedSolution;
use std::io::{self, Write};
/// Borrowed view of one query result, handed to an encoder's `write_envelope`.
///
/// Every reference lives for `'a`; `Envelope::from_machine` ties that lifetime
/// to the `Machine` the data is borrowed from.
pub struct Envelope<'a> {
/// Number of solutions reported. `from_machine` sets it to `m.solutions.len()`;
/// the BSON encoder emits it as the int32 field `count`.
pub count: usize,
/// Caller-supplied flag, emitted by the BSON encoder as the boolean field
/// `exhausted` (presumably "no further solutions exist" — confirm with callers).
pub exhausted: bool,
/// The solutions to encode, in output order.
pub solutions: &'a [RenderedSolution],
/// Optional text emitted by the BSON encoder as the string field `output`;
/// `from_machine` fills it from `Machine::captured_output`.
pub program_output: Option<&'a str>,
/// Optional list of names emitted by the BSON encoder as the array field
/// `atoms`; `from_machine` always leaves this as `None`.
pub atoms: Option<&'a [String]>,
}
impl<'a> Envelope<'a> {
pub fn from_machine(m: &'a Machine, exhausted: bool) -> Self {
Self {
count: m.solutions.len(),
exhausted,
solutions: &m.solutions,
program_output: m.captured_output(),
atoms: None,
}
}
}
/// Error reported to a client through an encoder's `write_error`.
///
/// Both variants carry a human-readable message; the text and BSON encoders in
/// this file emit only that message and do not distinguish the variant.
pub enum WireError {
/// Error carrying a message, named for parse failures.
Parse(String),
/// Error carrying a message, named for runtime failures.
Runtime(String),
}
/// Descriptor for one output encoder: a name plus the functions implementing it.
///
/// NOTE(review): the struct is `#[repr(C)]` but its fields are Rust-only types
/// (`&'static str`, Rust-ABI function pointers taking `&mut dyn Write`), so the
/// attribute presumably only pins field order — confirm the intended consumers.
#[repr(C)]
pub struct EncoderDesc {
/// Name matched by `EncoderDesc::find` (e.g. `"text"`, `"bson"`).
pub name: &'static str,
/// Writes a complete result envelope to the sink.
pub write_envelope: fn(&mut dyn Write, &Machine, &Envelope) -> io::Result<()>,
/// Writes an error report to the sink.
pub write_error: fn(&mut dyn Write, &WireError) -> io::Result<()>,
/// Reports whether the encoder can stream; the text encoder returns `true`
/// and the BSON encoder returns `false`.
pub can_stream: fn() -> bool,
}
impl EncoderDesc {
    /// Looks up an encoder descriptor by `name` in a raw table of descriptor
    /// pointers.
    ///
    /// Entries are scanned in order and the first descriptor whose `name`
    /// equals `name` is returned. A null `caps`, a zero `len`, and null entries
    /// inside the table are tolerated: they are treated as "nothing here"
    /// instead of being dereferenced.
    ///
    /// # Safety
    ///
    /// * If `caps` is non-null it must point to `len` consecutive, initialized
    ///   `*const EncoderDesc` values that are readable for the duration of the
    ///   call.
    /// * Every non-null entry must point to a valid `EncoderDesc` that lives
    ///   for the rest of the program, because the result is `&'static`.
    pub unsafe fn find(
        caps: *const *const EncoderDesc,
        len: usize,
        name: &str,
    ) -> Option<&'static EncoderDesc> {
        // `slice::from_raw_parts` requires a non-null pointer even for len 0.
        if caps.is_null() || len == 0 {
            return None;
        }
        // SAFETY: `caps` is non-null and the caller guarantees it addresses
        // `len` readable pointer-sized entries.
        let table = unsafe { std::slice::from_raw_parts(caps, len) };
        table
            .iter()
            // SAFETY: `as_ref` filters out null entries; the caller guarantees
            // every non-null entry points to a live `'static` descriptor.
            .filter_map(|&entry| unsafe { entry.as_ref() })
            .find(|desc| desc.name == name)
    }
}
/// Renders an envelope in the plain-text format.
///
/// * No solutions at all: writes `false.`.
/// * A solution without bindings: writes `true.`.
/// * Otherwise: one `Name = text` line per binding.
///
/// The machine argument is unused by this encoder. Any I/O error from the sink
/// is returned immediately.
fn text_write_envelope(w: &mut dyn Write, _m: &Machine, e: &Envelope) -> io::Result<()> {
    match e.solutions {
        [] => w.write_all(b"false.\n"),
        answers => {
            for answer in answers {
                if answer.bindings.is_empty() {
                    w.write_all(b"true.\n")?;
                } else {
                    for binding in &answer.bindings {
                        writeln!(w, "{} = {}", binding.name, binding.text)?;
                    }
                }
            }
            Ok(())
        }
    }
}
/// Writes an error as a single `error: <message>` line.
///
/// Both `WireError` variants are rendered the same way; only the carried
/// message is emitted.
fn text_write_error(w: &mut dyn Write, err: &WireError) -> io::Result<()> {
    let (WireError::Parse(detail) | WireError::Runtime(detail)) = err;
    writeln!(w, "error: {detail}")
}
/// `can_stream` hook for the text encoder: always reports `true`.
const fn text_can_stream() -> bool {
true
}
/// Descriptor for the plain-text encoder, registered under the name `"text"`.
///
/// `no_mangle` keeps the symbol name `PLG_ENC_TEXT` unmangled in the binary.
#[unsafe(no_mangle)]
pub static PLG_ENC_TEXT: EncoderDesc = EncoderDesc {
name: "text",
write_envelope: text_write_envelope,
write_error: text_write_error,
can_stream: text_can_stream,
};
/// Frames a `TermBuf` as bytes for embedding in a BSON binary value.
///
/// Layout, all integers little-endian:
/// * 1 byte: format version `0x01`
/// * 4 bytes: number of cells, as `u32`
/// * the root word
/// * every cell, in order
fn serialize_termbuf(tb: &TermBuf) -> Vec<u8> {
    let cell_count = tb.cells.len();
    let mut framed = Vec::with_capacity(13 + cell_count * 8);
    framed.push(0x01);
    framed.extend_from_slice(&(cell_count as u32).to_le_bytes());
    framed.extend_from_slice(&tb.root.to_le_bytes());
    framed.extend(tb.cells.iter().flat_map(|cell| cell.to_le_bytes()));
    framed
}
// BSON element type tags used by the encoder and the request parser below.
/// UTF-8 string: int32 byte length (counting the trailing NUL), bytes, NUL.
const T_STRING: u8 = 0x02;
/// Embedded document: int32 total length, elements, NUL terminator.
const T_DOCUMENT: u8 = 0x03;
/// Array: encoded like a document whose keys are decimal indices.
const T_ARRAY: u8 = 0x04;
/// Binary data: int32 payload length, one subtype byte, payload.
const T_BINARY: u8 = 0x05;
/// Boolean: a single byte, `0x00` or `0x01`.
const T_BOOL: u8 = 0x08;
/// 32-bit signed integer, little-endian.
const T_INT32: u8 = 0x10;
/// 64-bit signed integer, little-endian.
const T_INT64: u8 = 0x12;
/// Appends `s` to `buf` followed by a single NUL terminator byte.
///
/// The bytes of `s` are copied verbatim; an interior NUL in `s` is not
/// detected or escaped.
fn bson_cstring(buf: &mut Vec<u8>, s: &str) {
    buf.extend(s.bytes().chain(std::iter::once(0x00)));
}
/// Opens a BSON document by reserving four zero bytes for its length prefix.
///
/// Returns the offset of the placeholder, which must later be handed to
/// `bson_doc_end` so the real length can be patched in.
fn bson_doc_begin(buf: &mut Vec<u8>) -> usize {
    let header_at = buf.len();
    buf.resize(header_at + 4, 0);
    header_at
}
/// Closes the BSON document opened at `start`.
///
/// Appends the NUL terminator, then overwrites the four placeholder bytes at
/// `start` with the document's total length (prefix and terminator included)
/// as a little-endian int32.
///
/// # Panics
///
/// Panics if the document is larger than `i32::MAX` bytes.
fn bson_doc_end(buf: &mut Vec<u8>, start: usize) {
    buf.push(0x00);
    let total = buf.len() - start;
    let prefix = i32::try_from(total).expect("bson doc < 2GB").to_le_bytes();
    buf[start..start + 4].copy_from_slice(&prefix);
}
/// Appends an `atoms` array element to `buf`.
///
/// Each name becomes a BSON string element keyed by its decimal position in
/// `names`, as BSON arrays require.
///
/// # Panics
///
/// Panics if a name (plus its terminator) does not fit in an int32 length.
fn bson_atoms_array(buf: &mut Vec<u8>, names: &[String]) {
    buf.push(T_ARRAY);
    bson_cstring(buf, "atoms");
    let array_start = bson_doc_begin(buf);
    for (position, atom) in names.iter().enumerate() {
        buf.push(T_STRING);
        bson_cstring(buf, &position.to_string());
        // The string length prefix counts the trailing NUL.
        let prefixed_len = i32::try_from(atom.len() + 1).expect("atom name < 2GB");
        buf.extend_from_slice(&prefixed_len.to_le_bytes());
        // String payload plus terminator is exactly a cstring.
        bson_cstring(buf, atom);
    }
    bson_doc_end(buf, array_start);
}
pub fn write_atom_map_bson<W: Write>(w: &mut W, m: &Machine) -> io::Result<()> {
let names: Vec<String> = (0..m.atoms.len())
.map(|i| {
m.atoms
.try_resolve(i as u32)
.unwrap_or_default()
.to_string()
})
.collect();
let mut buf = Vec::new();
let doc = bson_doc_begin(&mut buf);
buf.push(T_INT32);
bson_cstring(&mut buf, "count");
buf.extend_from_slice(&(names.len().min(i32::MAX as usize) as i32).to_le_bytes());
bson_atoms_array(&mut buf, &names);
bson_doc_end(&mut buf, doc);
w.write_all(&buf)
}
/// Writes the machine's atom table to `w` as text, one `id<TAB>name` line per
/// atom id.
///
/// Ids that `try_resolve` cannot resolve are written with an empty name. The
/// first I/O error aborts the listing and is returned.
pub fn write_atom_map_text<W: Write>(w: &mut W, m: &Machine) -> io::Result<()> {
    (0..m.atoms.len()).try_for_each(|id| {
        let label = m.atoms.try_resolve(id as u32).unwrap_or_default();
        writeln!(w, "{id}\t{label}")
    })
}
/// Encodes an envelope as a single BSON document and writes it to `w`.
///
/// Fields, in order:
/// * `count` — int32, saturated at `i32::MAX`
/// * `exhausted` — boolean
/// * `output` — string, only when `program_output` is set
/// * `atoms` — array of strings, only when `atoms` is set
/// * `solutions` — array of documents; each maps a binding name to a binary
///   value (subtype `0x00`) holding the serialized term buffer of that binding
///
/// The document is assembled in memory and written with one `write_all`.
///
/// # Panics
///
/// Panics if a string, a term buffer, or a (sub)document exceeds the int32
/// length limit.
fn bson_write_envelope(w: &mut dyn Write, m: &Machine, e: &Envelope) -> io::Result<()> {
    let mut doc_bytes = Vec::new();
    let root = bson_doc_begin(&mut doc_bytes);

    doc_bytes.push(T_INT32);
    bson_cstring(&mut doc_bytes, "count");
    let count = i32::try_from(e.count).unwrap_or(i32::MAX);
    doc_bytes.extend_from_slice(&count.to_le_bytes());

    doc_bytes.push(T_BOOL);
    bson_cstring(&mut doc_bytes, "exhausted");
    doc_bytes.push(u8::from(e.exhausted));

    if let Some(captured) = e.program_output {
        doc_bytes.push(T_STRING);
        bson_cstring(&mut doc_bytes, "output");
        // The string length prefix counts the trailing NUL.
        let prefixed_len = i32::try_from(captured.len() + 1).expect("output string < 2GB");
        doc_bytes.extend_from_slice(&prefixed_len.to_le_bytes());
        bson_cstring(&mut doc_bytes, captured);
    }

    if let Some(atom_names) = e.atoms {
        bson_atoms_array(&mut doc_bytes, atom_names);
    }

    doc_bytes.push(T_ARRAY);
    bson_cstring(&mut doc_bytes, "solutions");
    let solutions_start = bson_doc_begin(&mut doc_bytes);
    for (index, solution) in e.solutions.iter().enumerate() {
        doc_bytes.push(T_DOCUMENT);
        bson_cstring(&mut doc_bytes, &index.to_string());
        let solution_start = bson_doc_begin(&mut doc_bytes);
        for binding in &solution.bindings {
            let term = copyterm::copy_to_buf(m, binding.word);
            let blob = serialize_termbuf(&term);
            doc_bytes.push(T_BINARY);
            bson_cstring(&mut doc_bytes, &binding.name);
            let blob_len = i32::try_from(blob.len()).expect("termbuf < 2GB");
            doc_bytes.extend_from_slice(&blob_len.to_le_bytes());
            // Binary subtype byte: 0x00.
            doc_bytes.push(0x00);
            doc_bytes.extend_from_slice(&blob);
        }
        bson_doc_end(&mut doc_bytes, solution_start);
    }
    bson_doc_end(&mut doc_bytes, solutions_start);

    bson_doc_end(&mut doc_bytes, root);
    w.write_all(&doc_bytes)
}
fn bson_write_error(w: &mut dyn Write, err: &WireError) -> io::Result<()> {
let msg = match err {
WireError::Parse(m) | WireError::Runtime(m) => m,
};
let mut buf = Vec::new();
let doc = bson_doc_begin(&mut buf);
buf.push(T_STRING);
bson_cstring(&mut buf, "error");
let len = i32::try_from(msg.len() + 1).expect("error message < 2GB");
buf.extend_from_slice(&len.to_le_bytes());
buf.extend_from_slice(msg.as_bytes());
buf.push(0x00);
bson_doc_end(&mut buf, doc);
w.write_all(&buf)
}
/// `can_stream` hook for the BSON encoder: always reports `false`.
const fn bson_can_stream() -> bool {
false
}
/// Descriptor for the BSON encoder, registered under the name `"bson"`.
///
/// `no_mangle` keeps the symbol name `PLG_ENC_BSON` unmangled in the binary.
#[unsafe(no_mangle)]
pub static PLG_ENC_BSON: EncoderDesc = EncoderDesc {
name: "bson",
write_envelope: bson_write_envelope,
write_error: bson_write_error,
can_stream: bson_can_stream,
};
/// Request fields extracted by `parse_bson_request`.
#[derive(Debug)]
pub struct ParsedRequest {
/// Value of the required string field `query`.
pub query: String,
/// Value of the optional integer field `limit`; negative values are stored
/// as `0`, and `None` means the field was absent.
pub limit: Option<usize>,
}
pub fn parse_bson_request(buf: &[u8]) -> Result<ParsedRequest, String> {
if buf.len() < 5 {
return Err("bson request too short".to_string());
}
let total = i32::from_le_bytes(buf[0..4].try_into().unwrap()) as usize;
if total < 5 || total > buf.len() {
return Err(format!(
"bson request length mismatch: declared {total}, have {}",
buf.len()
));
}
let body = &buf[..total];
let mut off = 4; let end = total - 1; let mut query = None;
let mut limit = None;
while off < end {
let ty = body[off];
off += 1;
let (key, after_key) = read_cstring(body, off)?;
off = after_key;
match (ty, key.as_str()) {
(T_STRING, "query") => {
let (s, next) = read_string(body, off)?;
query = Some(s);
off = next;
}
(T_INT32, "limit") => {
let n = read_i32(body, off)?;
limit = Some(n.max(0) as usize);
off += 4;
}
(T_INT64, "limit") => {
let n = read_i64(body, off)?;
limit = Some(n.max(0) as usize);
off += 8;
}
_ => {
off = skip_value(body, off, ty)?;
}
}
}
let query = query.ok_or_else(|| "bson request missing required 'query' string".to_string())?;
Ok(ParsedRequest { query, limit })
}
/// Reads a NUL-terminated UTF-8 key starting at `off`.
///
/// Returns the decoded key and the offset of the byte after its terminator.
///
/// # Errors
///
/// Fails when `off` lies beyond the buffer, when no NUL byte follows `off`,
/// or when the key bytes are not valid UTF-8. It never panics on bad offsets.
fn read_cstring(buf: &[u8], off: usize) -> Result<(String, usize), String> {
    let unterminated = || "bson key not null-terminated".to_string();
    // `get` instead of indexing: an offset past the end is an error, not a panic.
    let tail = buf.get(off..).ok_or_else(unterminated)?;
    let nul = tail
        .iter()
        .position(|&b| b == 0)
        .ok_or_else(unterminated)?;
    let key = std::str::from_utf8(&tail[..nul]).map_err(|_| "bson key not utf-8".to_string())?;
    Ok((key.to_owned(), off + nul + 1))
}
fn read_string(buf: &[u8], off: usize) -> Result<(String, usize), String> {
let n = read_i32(buf, off)? as usize;
if n == 0 || off + 4 + n > buf.len() {
return Err("bson string length out of range".to_string());
}
let s = std::str::from_utf8(&buf[off + 4..off + 4 + n - 1])
.map_err(|_| "bson string not utf-8".to_string())?
.to_string();
Ok((s, off + 4 + n))
}
/// Reads a little-endian `i32` from the four bytes starting at `off`.
///
/// # Errors
///
/// Returns `"bson int32 truncated"` when fewer than four bytes are available
/// at `off`, including when `off` lies beyond the end of `buf`.
fn read_i32(buf: &[u8], off: usize) -> Result<i32, String> {
    buf.get(off..)
        .and_then(|tail| tail.get(..4))
        .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .ok_or_else(|| "bson int32 truncated".to_string())
}
/// Reads a little-endian `i64` from the eight bytes starting at `off`.
///
/// # Errors
///
/// Returns `"bson int64 truncated"` when fewer than eight bytes are available
/// at `off`, including when `off` lies beyond the end of `buf`.
fn read_i64(buf: &[u8], off: usize) -> Result<i64, String> {
    buf.get(off..)
        .and_then(|tail| tail.get(..8))
        .map(|b| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(b);
            i64::from_le_bytes(raw)
        })
        .ok_or_else(|| "bson int64 truncated".to_string())
}
fn skip_value(buf: &[u8], off: usize, ty: u8) -> Result<usize, String> {
match ty {
0x01 => Ok(off + 8), T_STRING => Ok(read_string(buf, off)?.1), T_DOCUMENT | T_ARRAY => {
let n = read_i32(buf, off)? as usize;
Ok(off + n)
}
T_BINARY => {
let n = read_i32(buf, off)? as usize;
Ok(off + 4 + 1 + n)
}
T_BOOL => Ok(off + 1),
0x0A => Ok(off), T_INT32 => Ok(off + 4),
T_INT64 => Ok(off + 8),
_ => Err(format!("bson: cannot skip unknown element type {ty:#x}")),
}
}
// Unit tests for the text encoder, the BSON encoder, the term-buffer framing,
// and the BSON request parser.
#[cfg(test)]
mod tests {
use super::*;
use crate::cell::{TAG_STR, make, make_atom, make_int, pack_functor, payload, tag_of};
use plg_shared::StringInterner;
use plg_shared::atom::ATOM_NIL;
// Builds a fresh machine from a new interner and an empty vector.
fn machine() -> Box<Machine> {
Machine::new(StringInterner::new(), Vec::new())
}
// Runs an encoder call against an in-memory buffer and returns the bytes
// written; panics if the call reports an I/O error.
fn bytes(f: impl FnOnce(&mut Vec<u8>) -> io::Result<()>) -> Vec<u8> {
let mut buf = Vec::new();
f(&mut buf).unwrap();
buf
}
// Builds one solution per (name, text) pair, each holding a single binding
// whose term word is `make_atom(0)`.
fn env_with(bindings: Vec<(&str, &str)>) -> Vec<RenderedSolution> {
bindings
.into_iter()
.map(|(n, t)| RenderedSolution {
bindings: vec![crate::render::Binding {
name: n.to_string(),
text: t.to_string(),
word: make_atom(0),
}],
})
.collect()
}
// An envelope without solutions renders as "false." in the text format.
#[test]
fn text_empty_is_false() {
let e = Envelope {
count: 0,
exhausted: true,
solutions: &[],
program_output: None,
atoms: None,
};
assert_eq!(
String::from_utf8(bytes(|w| (PLG_ENC_TEXT.write_envelope)(w, &machine(), &e))).unwrap(),
"false.\n"
);
}
// A binding renders as "Name = text"; a solution without bindings renders
// as "true.".
#[test]
fn text_renders_bindings_and_true() {
let sols = env_with(vec![("X", "auth")]);
let empty_sols = vec![RenderedSolution { bindings: vec![] }];
let e1 = Envelope {
count: 1,
exhausted: false,
solutions: &sols,
program_output: None,
atoms: None,
};
assert_eq!(
String::from_utf8(bytes(|w| (PLG_ENC_TEXT.write_envelope)(w, &machine(), &e1)))
.unwrap(),
"X = auth\n"
);
let e2 = Envelope {
count: 1,
exhausted: true,
solutions: &empty_sols,
program_output: None,
atoms: None,
};
assert_eq!(
String::from_utf8(bytes(|w| (PLG_ENC_TEXT.write_envelope)(w, &machine(), &e2)))
.unwrap(),
"true.\n"
);
}
// Descriptor names are fixed; text reports streaming support, BSON does not.
#[test]
fn descriptors_named_and_streaming() {
assert_eq!(PLG_ENC_TEXT.name, "text");
assert_eq!(PLG_ENC_BSON.name, "bson");
assert!((PLG_ENC_TEXT.can_stream)());
assert!(!(PLG_ENC_BSON.can_stream)());
}
// `find` returns the matching descriptor and `None` for an unknown name.
#[test]
fn find_locates_advertised_encoders() {
let caps: [*const EncoderDesc; 2] = [&PLG_ENC_TEXT, &PLG_ENC_BSON];
assert_eq!(
unsafe { EncoderDesc::find(caps.as_ptr(), 2, "text") }
.unwrap()
.name,
"text"
);
assert_eq!(
unsafe { EncoderDesc::find(caps.as_ptr(), 2, "bson") }
.unwrap()
.name,
"bson"
);
assert!(unsafe { EncoderDesc::find(caps.as_ptr(), 2, "json") }.is_none());
}
// An encoder that exists but is absent from the table is not found.
#[test]
fn find_omitted_encoder_is_none() {
let caps: [*const EncoderDesc; 1] = [&PLG_ENC_TEXT];
assert!(unsafe { EncoderDesc::find(caps.as_ptr(), 1, "bson") }.is_none());
}
// Reads the little-endian int32 length prefix of a BSON document.
fn bson_doc_len(buf: &[u8]) -> i32 {
i32::from_le_bytes(buf[0..4].try_into().unwrap())
}
// Checks BSON framing only: the length prefix equals the buffer length and
// the last byte is the NUL terminator.
fn assert_valid_bson_doc(buf: &[u8]) {
assert_eq!(
bson_doc_len(buf) as usize,
buf.len(),
"bson doc self-delimits"
);
assert_eq!(
*buf.last().unwrap(),
0x00,
"bson doc ends in null terminator"
);
}
// An empty envelope still yields a well-framed document carrying the
// `count`, `exhausted` and `solutions` keys.
#[test]
fn bson_empty_envelope_self_delimits() {
let m = machine();
let e = Envelope {
count: 0,
exhausted: true,
solutions: &[],
program_output: None,
atoms: None,
};
let buf = bytes(|w| (PLG_ENC_BSON.write_envelope)(w, &m, &e));
assert_valid_bson_doc(&buf);
assert!(contains_key(&buf, b"count"));
assert!(contains_key(&buf, b"exhausted"));
assert!(contains_key(&buf, b"solutions"));
}
// An error report is a well-framed document carrying the `error` key.
#[test]
fn bson_error_document_valid() {
let buf = bytes(|w| (PLG_ENC_BSON.write_error)(w, &WireError::Runtime("boom".into())));
assert_valid_bson_doc(&buf);
assert!(contains_key(&buf, b"error"));
}
// Substring search for `key` followed by NUL anywhere in `buf`; it does not
// parse the document, so it can also match inside values.
fn contains_key(buf: &[u8], key: &[u8]) -> bool {
let mut needle = key.to_vec();
needle.push(0x00);
buf.windows(needle.len()).any(|w| w == needle.as_slice())
}
// Test-side inverse of `serialize_termbuf`: version byte, u32 cell count,
// u64 root, then the cells, all little-endian.
fn deserialize_termbuf(data: &[u8]) -> TermBuf {
assert_eq!(data[0], 0x01, "format version");
let n = u32::from_le_bytes(data[1..5].try_into().unwrap()) as usize;
let root = u64::from_le_bytes(data[5..13].try_into().unwrap());
let mut cells = Vec::with_capacity(n);
for i in 0..n {
let off = 13 + i * 8;
cells.push(u64::from_le_bytes(data[off..off + 8].try_into().unwrap()));
}
TermBuf { cells, root }
}
// Round-trips two terms through the framing: a bare atom (no cells, root
// preserved) and a structure whose single argument is a variable bound to
// the structure itself.
#[test]
fn termbuf_framing_roundtrips_scalar_and_cycle() {
let m = machine();
let a = make_atom(7);
let tb = copyterm::copy_to_buf(&m, a);
assert!(tb.cells.is_empty());
let rt = deserialize_termbuf(&serialize_termbuf(&tb));
assert_eq!(rt.root, a);
let mut m = machine();
let x = m.new_var();
let s = {
let i = m.heap.len();
m.heap.push(pack_functor(3, 1));
m.heap.push(x);
make(TAG_STR, i as u64)
};
// Bind the variable to the structure that contains it, forming a cycle.
m.bind(payload(x) as usize, s);
let tb = copyterm::copy_to_buf(&m, s);
let rt = deserialize_termbuf(&serialize_termbuf(&tb));
let restored = copyterm::restore_from_buf(&mut m, &rt);
assert_eq!(tag_of(restored), TAG_STR);
let ri = payload(restored) as usize;
assert_eq!(
m.deref(m.heap[ri + 1]),
restored,
"f(X) arg is the term itself"
);
}
// Assembles a BSON document from (type tag, key, pre-encoded value) triples.
fn req_doc(fields: &[(u8, &str, &[u8])]) -> Vec<u8> {
let mut buf = Vec::new();
let start = bson_doc_begin(&mut buf);
for (ty, key, val) in fields {
buf.push(*ty);
bson_cstring(&mut buf, key);
buf.extend_from_slice(val);
}
bson_doc_end(&mut buf, start);
buf
}
// Encodes an int32 value body.
fn bson_int32(n: i32) -> Vec<u8> {
n.to_le_bytes().to_vec()
}
// Encodes a string value body: length prefix (counting the NUL), bytes, NUL.
fn bson_str(s: &str) -> Vec<u8> {
let mut v = (s.len() as i32 + 1).to_le_bytes().to_vec();
v.extend_from_slice(s.as_bytes());
v.push(0x00);
v
}
// Both recognized fields are extracted from a well-formed request.
#[test]
fn parses_query_and_int32_limit() {
let doc = req_doc(&[
(T_STRING, "query", &bson_str("p(X)")),
(T_INT32, "limit", &bson_int32(5)),
]);
let r = parse_bson_request(&doc).unwrap();
assert_eq!(r.query, "p(X)");
assert_eq!(r.limit, Some(5));
}
// Fields other than `query` and `limit` are skipped.
#[test]
fn ignores_unknown_fields() {
let doc = req_doc(&[
(T_STRING, "caller", &bson_str("x")),
(T_STRING, "query", &bson_str("ok")),
]);
assert_eq!(parse_bson_request(&doc).unwrap().query, "ok");
}
// A request without `query` is rejected with a descriptive message.
#[test]
fn missing_query_is_an_error() {
let doc = req_doc(&[(T_INT32, "limit", &bson_int32(3))]);
assert!(
parse_bson_request(&doc)
.unwrap_err()
.contains("missing required 'query'")
);
}
// References `make_int` and `ATOM_NIL` so their imports count as used.
#[test]
fn _keep_imports_used() {
let _ = make_int(1);
let _ = ATOM_NIL;
}
}