use crate::ast::Literal;
use crate::diagnostics::{Error, Result};
use crate::eval::value::{PrimitiveProcedure, PrimitiveImpl, Value, ThreadSafeEnvironment};
use std::sync::Arc;
pub fn bind_bytevector_operations(env: &Arc<ThreadSafeEnvironment>) {
env.define("make-bytevector".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "make-bytevector".to_string(),
arity_min: 1,
arity_max: Some(2),
implementation: PrimitiveImpl::RustFn(primitive_make_bytevector),
effects: vec![],
})));
env.define("bytevector".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector".to_string(),
arity_min: 0,
arity_max: None,
implementation: PrimitiveImpl::RustFn(primitive_bytevector),
effects: vec![],
})));
env.define("bytevector-copy".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector-copy".to_string(),
arity_min: 1,
arity_max: Some(3),
implementation: PrimitiveImpl::RustFn(primitive_bytevector_copy),
effects: vec![],
})));
env.define("bytevector?".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector?".to_string(),
arity_min: 1,
arity_max: Some(1),
implementation: PrimitiveImpl::RustFn(primitive_bytevector_p),
effects: vec![],
})));
env.define("bytevector-length".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector-length".to_string(),
arity_min: 1,
arity_max: Some(1),
implementation: PrimitiveImpl::RustFn(primitive_bytevector_length),
effects: vec![],
})));
env.define("bytevector-u8-ref".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector-u8-ref".to_string(),
arity_min: 2,
arity_max: Some(2),
implementation: PrimitiveImpl::RustFn(primitive_bytevector_u8_ref),
effects: vec![],
})));
env.define("bytevector-u8-set!".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector-u8-set!".to_string(),
arity_min: 3,
arity_max: Some(3),
implementation: PrimitiveImpl::RustFn(primitive_bytevector_u8_set),
effects: vec![],
})));
env.define("bytevector->list".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector->list".to_string(),
arity_min: 1,
arity_max: Some(3),
implementation: PrimitiveImpl::RustFn(primitive_bytevector_to_list),
effects: vec![],
})));
env.define("list->bytevector".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "list->bytevector".to_string(),
arity_min: 1,
arity_max: Some(1),
implementation: PrimitiveImpl::RustFn(primitive_list_to_bytevector),
effects: vec![],
})));
env.define("string->utf8".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "string->utf8".to_string(),
arity_min: 1,
arity_max: Some(3),
implementation: PrimitiveImpl::RustFn(primitive_string_to_utf8),
effects: vec![],
})));
env.define("utf8->string".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "utf8->string".to_string(),
arity_min: 1,
arity_max: Some(3),
implementation: PrimitiveImpl::RustFn(primitive_utf8_to_string),
effects: vec![],
})));
env.define("bytevector=?".to_string(), Value::Primitive(Arc::new(PrimitiveProcedure {
name: "bytevector=?".to_string(),
arity_min: 0,
arity_max: None,
implementation: PrimitiveImpl::RustFn(primitive_bytevector_equal),
effects: vec![],
})));
}
fn extract_bytevector(value: &Value, operation: &str) -> Result<Vec<u8>> {
match value {
Value::Literal(Literal::Bytevector(bv)) => Ok(bv.clone()),
_ => Err(Box::new(Error::runtime_error(
format!("{operation} requires a bytevector argument"),
None,
))),
}
}
fn extract_bytevector_mut(value: &Value, operation: &str) -> Result<Vec<u8>> {
match value {
Value::Literal(Literal::Bytevector(bv)) => Ok(bv.clone()),
_ => Err(Box::new(Error::runtime_error(
format!("{operation} requires a bytevector argument"),
None,
))),
}
}
fn extract_integer(value: &Value, operation: &str) -> Result<i64> {
match value {
Value::Literal(literal) => {
if let Some(i) = literal.to_i64() {
Ok(i)
} else {
Err(Box::new(Error::runtime_error(
format!("{operation} requires an integer argument"),
None,
)))
}
}
_ => Err(Box::new(Error::runtime_error(
format!("{operation} requires an integer argument"),
None,
))),
}
}
fn extract_non_negative_integer(value: &Value, operation: &str) -> Result<usize> {
let int = extract_integer(value, operation)?;
if int < 0 {
return Err(Box::new(Error::runtime_error(
format!("{operation} requires a non-negative integer"),
None,
)));
}
Ok(int as usize)
}
fn extract_byte(value: &Value, operation: &str) -> Result<u8> {
let int = extract_integer(value, operation)?;
if !(0..=255).contains(&int) {
return Err(Box::new(Error::runtime_error(
format!("{operation} requires a byte value (0-255)"),
None,
)));
}
Ok(int as u8)
}
fn extract_string(value: &Value, operation: &str) -> Result<String> {
match value {
Value::Literal(Literal::String(s)) => Ok(s.clone()),
_ => Err(Box::new(Error::runtime_error(
format!("{operation} requires a string argument"),
None,
))),
}
}
pub fn primitive_make_bytevector(args: &[Value]) -> Result<Value> {
match args.len() {
1 => {
let k = extract_non_negative_integer(&args[0], "make-bytevector")?;
Ok(Value::bytevector(vec![0; k]))
}
2 => {
let k = extract_non_negative_integer(&args[0], "make-bytevector")?;
let byte = extract_byte(&args[1], "make-bytevector")?;
Ok(Value::bytevector(vec![byte; k]))
}
_ => Err(Box::new(Error::runtime_error(
format!("make-bytevector expects 1 or 2 arguments, got {}", args.len()),
None,
)))
}
}
pub fn primitive_bytevector(args: &[Value]) -> Result<Value> {
let mut bytes = Vec::new();
for (i, arg) in args.iter().enumerate() {
let byte = extract_byte(arg, &format!("bytevector (argument {})", i + 1))?;
bytes.push(byte);
}
Ok(Value::bytevector(bytes))
}
pub fn primitive_bytevector_copy(args: &[Value]) -> Result<Value> {
match args.len() {
1 => {
let bv = extract_bytevector(&args[0], "bytevector-copy")?;
Ok(Value::bytevector(bv))
}
2 => {
let bv = extract_bytevector(&args[0], "bytevector-copy")?;
let start = extract_non_negative_integer(&args[1], "bytevector-copy")?;
if start > bv.len() {
return Err(Box::new(Error::runtime_error(
"bytevector-copy: start index out of bounds".to_string(),
None,
)));
}
Ok(Value::bytevector(bv[start..].to_vec()))
}
3 => {
let bv = extract_bytevector(&args[0], "bytevector-copy")?;
let start = extract_non_negative_integer(&args[1], "bytevector-copy")?;
let end = extract_non_negative_integer(&args[2], "bytevector-copy")?;
if start > end || end > bv.len() {
return Err(Box::new(Error::runtime_error(
"bytevector-copy: invalid start/end indices".to_string(),
None,
)));
}
Ok(Value::bytevector(bv[start..end].to_vec()))
}
_ => Err(Box::new(Error::runtime_error(
format!("bytevector-copy expects 1 to 3 arguments, got {}", args.len()),
None,
)))
}
}
pub fn primitive_bytevector_p(args: &[Value]) -> Result<Value> {
if args.len() != 1 {
return Err(Box::new(Error::runtime_error(
format!("bytevector? expects 1 argument, got {}", args.len()),
None,
)));
}
let result = matches!(args[0], Value::Literal(Literal::Bytevector(_)));
Ok(Value::boolean(result))
}
pub fn primitive_bytevector_length(args: &[Value]) -> Result<Value> {
if args.len() != 1 {
return Err(Box::new(Error::runtime_error(
format!("bytevector-length expects 1 argument, got {}", args.len()),
None,
)));
}
let bv = extract_bytevector(&args[0], "bytevector-length")?;
Ok(Value::integer(bv.len() as i64))
}
pub fn primitive_bytevector_u8_ref(args: &[Value]) -> Result<Value> {
if args.len() != 2 {
return Err(Box::new(Error::runtime_error(
format!("bytevector-u8-ref expects 2 arguments, got {}", args.len()),
None,
)));
}
let bv = extract_bytevector(&args[0], "bytevector-u8-ref")?;
let k = extract_non_negative_integer(&args[1], "bytevector-u8-ref")?;
if k >= bv.len() {
return Err(Box::new(Error::runtime_error(
"bytevector-u8-ref: index out of bounds".to_string(),
None,
)));
}
Ok(Value::integer(bv[k] as i64))
}
pub fn primitive_bytevector_u8_set(args: &[Value]) -> Result<Value> {
if args.len() != 3 {
return Err(Box::new(Error::runtime_error(
format!("bytevector-u8-set! expects 3 arguments, got {}", args.len()),
None,
)));
}
let mut bv = extract_bytevector_mut(&args[0], "bytevector-u8-set!")?;
let k = extract_non_negative_integer(&args[1], "bytevector-u8-set!")?;
let byte = extract_byte(&args[2], "bytevector-u8-set!")?;
if k >= bv.len() {
return Err(Box::new(Error::runtime_error(
"bytevector-u8-set!: index out of bounds".to_string(),
None,
)));
}
bv[k] = byte;
Ok(Value::Unspecified)
}
pub fn primitive_bytevector_to_list(args: &[Value]) -> Result<Value> {
match args.len() {
1 => {
let bv = extract_bytevector(&args[0], "bytevector->list")?;
let values: Vec<Value> = bv.iter().map(|&b| Value::integer(b as i64)).collect();
Ok(Value::list(values))
}
2 => {
let bv = extract_bytevector(&args[0], "bytevector->list")?;
let start = extract_non_negative_integer(&args[1], "bytevector->list")?;
if start > bv.len() {
return Err(Box::new(Error::runtime_error(
"bytevector->list: start index out of bounds".to_string(),
None,
)));
}
let values: Vec<Value> = bv[start..].iter().map(|&b| Value::integer(b as i64)).collect();
Ok(Value::list(values))
}
3 => {
let bv = extract_bytevector(&args[0], "bytevector->list")?;
let start = extract_non_negative_integer(&args[1], "bytevector->list")?;
let end = extract_non_negative_integer(&args[2], "bytevector->list")?;
if start > end || end > bv.len() {
return Err(Box::new(Error::runtime_error(
"bytevector->list: invalid start/end indices".to_string(),
None,
)));
}
let values: Vec<Value> = bv[start..end].iter().map(|&b| Value::integer(b as i64)).collect();
Ok(Value::list(values))
}
_ => Err(Box::new(Error::runtime_error(
format!("bytevector->list expects 1 to 3 arguments, got {}", args.len()),
None,
)))
}
}
pub fn primitive_list_to_bytevector(args: &[Value]) -> Result<Value> {
if args.len() != 1 {
return Err(Box::new(Error::runtime_error(
format!("list->bytevector expects 1 argument, got {}", args.len()),
None,
)));
}
let list = args[0].as_list().ok_or_else(|| Error::runtime_error(
"list->bytevector requires a list argument".to_string(),
None,
))?;
let mut bytes = Vec::new();
for (i, value) in list.iter().enumerate() {
let byte = extract_byte(value, &format!("list->bytevector (element {i})"))?;
bytes.push(byte);
}
Ok(Value::bytevector(bytes))
}
pub fn primitive_string_to_utf8(args: &[Value]) -> Result<Value> {
match args.len() {
1 => {
let s = extract_string(&args[0], "string->utf8")?;
Ok(Value::bytevector(s.into_bytes()))
}
2 => {
let s = extract_string(&args[0], "string->utf8")?;
let start = extract_non_negative_integer(&args[1], "string->utf8")?;
let chars: Vec<char> = s.chars().collect();
if start > chars.len() {
return Err(Box::new(Error::runtime_error(
"string->utf8: start index out of bounds".to_string(),
None,
)));
}
let substring: String = chars[start..].iter().collect();
Ok(Value::bytevector(substring.into_bytes()))
}
3 => {
let s = extract_string(&args[0], "string->utf8")?;
let start = extract_non_negative_integer(&args[1], "string->utf8")?;
let end = extract_non_negative_integer(&args[2], "string->utf8")?;
let chars: Vec<char> = s.chars().collect();
if start > end || end > chars.len() {
return Err(Box::new(Error::runtime_error(
"string->utf8: invalid start/end indices".to_string(),
None,
)));
}
let substring: String = chars[start..end].iter().collect();
Ok(Value::bytevector(substring.into_bytes()))
}
_ => Err(Box::new(Error::runtime_error(
format!("string->utf8 expects 1 to 3 arguments, got {}", args.len()),
None,
)))
}
}
pub fn primitive_utf8_to_string(args: &[Value]) -> Result<Value> {
match args.len() {
1 => {
let bv = extract_bytevector(&args[0], "utf8->string")?;
let s = String::from_utf8(bv).map_err(|_| Error::runtime_error(
"utf8->string: invalid UTF-8 sequence".to_string(),
None,
))?;
Ok(Value::string(s))
}
2 => {
let bv = extract_bytevector(&args[0], "utf8->string")?;
let start = extract_non_negative_integer(&args[1], "utf8->string")?;
if start > bv.len() {
return Err(Box::new(Error::runtime_error(
"utf8->string: start index out of bounds".to_string(),
None,
)));
}
let s = String::from_utf8(bv[start..].to_vec()).map_err(|_| Error::runtime_error(
"utf8->string: invalid UTF-8 sequence".to_string(),
None,
))?;
Ok(Value::string(s))
}
3 => {
let bv = extract_bytevector(&args[0], "utf8->string")?;
let start = extract_non_negative_integer(&args[1], "utf8->string")?;
let end = extract_non_negative_integer(&args[2], "utf8->string")?;
if start > end || end > bv.len() {
return Err(Box::new(Error::runtime_error(
"utf8->string: invalid start/end indices".to_string(),
None,
)));
}
let s = String::from_utf8(bv[start..end].to_vec()).map_err(|_| Error::runtime_error(
"utf8->string: invalid UTF-8 sequence".to_string(),
None,
))?;
Ok(Value::string(s))
}
_ => Err(Box::new(Error::runtime_error(
format!("utf8->string expects 1 to 3 arguments, got {}", args.len()),
None,
)))
}
}
pub fn primitive_bytevector_equal(args: &[Value]) -> Result<Value> {
if args.is_empty() {
return Ok(Value::boolean(true));
}
let mut bytevectors = Vec::new();
for (i, arg) in args.iter().enumerate() {
let bv = extract_bytevector(arg, &format!("bytevector=? (argument {})", i + 1))?;
bytevectors.push(bv);
}
let first = &bytevectors[0];
for bv in &bytevectors[1..] {
if bv != first {
return Ok(Value::boolean(false));
}
}
Ok(Value::boolean(true))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_make_bytevector() {
let result = primitive_make_bytevector(&[Value::integer(5)]).unwrap();
assert_eq!(result, Value::bytevector(vec![0, 0, 0, 0, 0]));
let result = primitive_make_bytevector(&[Value::integer(3), Value::integer(42)]).unwrap();
assert_eq!(result, Value::bytevector(vec![42, 42, 42]));
}
#[test]
fn test_bytevector() {
let result = primitive_bytevector(&[]).unwrap();
assert_eq!(result, Value::bytevector(vec![]));
let result = primitive_bytevector(&[
Value::integer(65),
Value::integer(66),
Value::integer(67)
]).unwrap();
assert_eq!(result, Value::bytevector(vec![65, 66, 67]));
}
#[test]
fn test_bytevector_predicate() {
let result = primitive_bytevector_p(&[Value::bytevector(vec![1, 2, 3])]).unwrap();
assert_eq!(result, Value::boolean(true));
let result = primitive_bytevector_p(&[Value::integer(42)]).unwrap();
assert_eq!(result, Value::boolean(false));
}
#[test]
fn test_bytevector_length() {
let result = primitive_bytevector_length(&[Value::bytevector(vec![1, 2, 3, 4, 5])]).unwrap();
assert_eq!(result, Value::integer(5));
}
#[test]
fn test_bytevector_u8_ref() {
let bv = Value::bytevector(vec![10, 20, 30]);
let result = primitive_bytevector_u8_ref(&[bv, Value::integer(1)]).unwrap();
assert_eq!(result, Value::integer(20));
}
#[test]
fn test_bytevector_to_list() {
let bv = Value::bytevector(vec![65, 66, 67]);
let result = primitive_bytevector_to_list(&[bv]).unwrap();
let expected = Value::list(vec![Value::integer(65), Value::integer(66), Value::integer(67)]);
assert_eq!(result, expected);
}
#[test]
fn test_list_to_bytevector() {
let list = Value::list(vec![Value::integer(65), Value::integer(66), Value::integer(67)]);
let result = primitive_list_to_bytevector(&[list]).unwrap();
assert_eq!(result, Value::bytevector(vec![65, 66, 67]));
}
#[test]
fn test_string_to_utf8() {
let result = primitive_string_to_utf8(&[Value::string("ABC")]).unwrap();
assert_eq!(result, Value::bytevector(vec![65, 66, 67]));
}
#[test]
fn test_utf8_to_string() {
let bv = Value::bytevector(vec![65, 66, 67]);
let result = primitive_utf8_to_string(&[bv]).unwrap();
assert_eq!(result, Value::string("ABC"));
}
#[test]
fn test_bytevector_equal() {
let bv1 = Value::bytevector(vec![1, 2, 3]);
let bv2 = Value::bytevector(vec![1, 2, 3]);
let bv3 = Value::bytevector(vec![1, 2, 4]);
let result = primitive_bytevector_equal(&[bv1.clone(), bv2]).unwrap();
assert_eq!(result, Value::boolean(true));
let result = primitive_bytevector_equal(&[bv1, bv3]).unwrap();
assert_eq!(result, Value::boolean(false));
let result = primitive_bytevector_equal(&[]).unwrap();
assert_eq!(result, Value::boolean(true));
}
}