use std::sync::Arc;
use crate::register_fns;
use cljrs_gc::GcPtr;
use cljrs_value::value::SetValue;
use cljrs_value::{Arity, MapValue, PersistentHashSet, Value, ValueError, ValueResult};
/// Registers the native `clojure.set` functions into namespace `ns` of
/// `globals`.
///
/// Each entry below is `(clojure-name, arity, rust-fn)`. Presumably
/// `Arity::Variadic { min }` is a minimum argument count and
/// `Arity::Fixed(n)` an exact one — NOTE(review): confirm against
/// `register_fns!`. The Rust functions index `args` directly, so they rely on
/// these arities being enforced before they are called.
pub fn register(globals: &Arc<cljrs_eval::GlobalEnv>, ns: &str) {
register_fns!(
globals,
ns,
[
("union", Arity::Variadic { min: 0 }, union),
("intersection", Arity::Variadic { min: 1 }, intersection),
("difference", Arity::Variadic { min: 1 }, difference),
("subset?", Arity::Fixed(2), subset_q),
("superset?", Arity::Fixed(2), superset_q),
// `select` is registered as a sentinel: its body only returns an error
// stating it should not be called directly.
("select", Arity::Fixed(2), select),
("map-invert", Arity::Fixed(1), map_invert),
]
);
}
/// Borrows the set inside `v`.
///
/// # Errors
/// Returns `ValueError::WrongType` (expected `"set"`, got the value's type
/// name) when `v` is not a `Value::Set`.
fn get_set(v: &Value) -> ValueResult<&SetValue> {
    if let Value::Set(set) = v {
        return Ok(set);
    }
    Err(ValueError::WrongType {
        expected: "set",
        got: v.type_name().to_string(),
    })
}
/// `clojure.set/union`: a set containing every element of every argument.
///
/// * No arguments → a new, empty hash set.
/// * One argument → that same set, returned unchanged. This matches
///   Clojure, where `(union s)` is `s`, and keeps whatever `SetValue`
///   variant the caller passed instead of converting it to `Hash`.
/// * Two or more → a new hash set holding all elements.
///
/// # Errors
/// Returns `ValueError::WrongType` if any argument is not a set.
fn union(args: &[Value]) -> ValueResult<Value> {
    // A lone set is its own union. Validate it, then hand it back rather
    // than copying every element into a fresh hash set.
    if let [only] = args {
        get_set(only)?;
        return Ok(only.clone());
    }
    let mut result = PersistentHashSet::empty();
    for arg in args {
        for v in get_set(arg)?.iter() {
            result.conj_mut(v.clone());
        }
    }
    Ok(Value::Set(SetValue::Hash(GcPtr::new(result))))
}
/// `clojure.set/intersection`: the elements present in every argument set.
///
/// * One argument → that same set, returned unchanged. This matches
///   Clojure, where `(intersection s)` is `s`, and keeps whatever
///   `SetValue` variant the caller passed.
/// * Two or more → a new hash set.
/// * No arguments → a new, empty hash set. The registered arity (`min: 1`)
///   rules this out for Clojure callers; the branch guards direct Rust
///   callers.
///
/// # Errors
/// Returns `ValueError::WrongType` if any argument is not a set. Every
/// argument is checked, even after the running result becomes empty.
fn intersection(args: &[Value]) -> ValueResult<Value> {
    let (first, rest) = match args.split_first() {
        Some(split) => split,
        None => {
            return Ok(Value::Set(SetValue::Hash(GcPtr::new(
                PersistentHashSet::empty(),
            ))))
        }
    };
    let first_set = get_set(first)?;
    if rest.is_empty() {
        // Nothing to intersect with: return the set itself instead of an
        // O(n) copy into a (possibly different-variant) hash set.
        return Ok(first.clone());
    }
    let mut result = PersistentHashSet::from_iter(first_set.iter().cloned());
    for arg in rest {
        let s = get_set(arg)?;
        // Keep only the survivors that are also in `s`. Scanning `result`
        // (which only shrinks) bounds the work by the first set's size.
        let mut next = PersistentHashSet::empty();
        for v in result.iter() {
            if s.contains(v) {
                next.conj_mut(v.clone());
            }
        }
        result = next;
    }
    Ok(Value::Set(SetValue::Hash(GcPtr::new(result))))
}
/// `clojure.set/difference`: the elements of the first set that appear in
/// none of the remaining sets.
///
/// * One argument → that same set, returned unchanged. This matches
///   Clojure, where `(difference s)` is `s`, and keeps whatever `SetValue`
///   variant the caller passed.
/// * Two or more → a new hash set.
/// * No arguments → a new, empty hash set. The registered arity (`min: 1`)
///   rules this out for Clojure callers; the branch guards direct Rust
///   callers.
///
/// # Errors
/// Returns `ValueError::WrongType` if any argument is not a set.
fn difference(args: &[Value]) -> ValueResult<Value> {
    let (first, rest) = match args.split_first() {
        Some(split) => split,
        None => {
            return Ok(Value::Set(SetValue::Hash(GcPtr::new(
                PersistentHashSet::empty(),
            ))))
        }
    };
    let first_set = get_set(first)?;
    if rest.is_empty() {
        // Nothing to subtract: return the set itself instead of an O(n) copy
        // into a (possibly different-variant) hash set.
        return Ok(first.clone());
    }
    let mut result = PersistentHashSet::from_iter(first_set.iter().cloned());
    for arg in rest {
        let s = get_set(arg)?;
        // Drop every survivor that also appears in `s`.
        let mut next = PersistentHashSet::empty();
        for v in result.iter() {
            if !s.contains(v) {
                next.conj_mut(v.clone());
            }
        }
        result = next;
    }
    Ok(Value::Set(SetValue::Hash(GcPtr::new(result))))
}
/// `clojure.set/subset?`: `true` when every element of the first set is also
/// in the second.
///
/// # Errors
/// Returns `ValueError::WrongType` if either argument is not a set. The
/// first argument is checked first.
fn subset_q(args: &[Value]) -> ValueResult<Value> {
    let (candidate, container) = (get_set(&args[0])?, get_set(&args[1])?);
    let has_missing = candidate.iter().any(|elem| !container.contains(elem));
    Ok(Value::Bool(!has_missing))
}
/// `clojure.set/superset?`: `true` when every element of the second set is
/// also in the first.
///
/// # Errors
/// Returns `ValueError::WrongType` if either argument is not a set. The
/// first argument is checked first.
fn superset_q(args: &[Value]) -> ValueResult<Value> {
    let (container, candidate) = (get_set(&args[0])?, get_set(&args[1])?);
    let has_missing = candidate.iter().any(|elem| !container.contains(elem));
    Ok(Value::Bool(!has_missing))
}
/// Placeholder registered as `clojure.set/select`.
///
/// It always fails, with a message saying the sentinel should not be called
/// directly. Presumably the evaluator intercepts `select` itself (it needs to
/// invoke a predicate) — NOTE(review): confirm where that happens.
fn select(_args: &[Value]) -> ValueResult<Value> {
    let got = "clojure.set/select sentinel should not be called directly".to_string();
    Err(ValueError::WrongType {
        expected: "intercepted",
        got,
    })
}
/// `clojure.set/map-invert`: builds a new map with each value as the key and
/// its key as the value.
///
/// When several keys share a value, the key visited last by `for_each` wins,
/// assuming `assoc` overwrites an existing key.
///
/// # Errors
/// Returns `ValueError::WrongType` (expected `"map"`) when the argument is
/// not a `Value::Map`.
fn map_invert(args: &[Value]) -> ValueResult<Value> {
    let input = &args[0];
    if let Value::Map(source) = input {
        let mut inverted = MapValue::empty();
        source.for_each(|key, val| {
            inverted = inverted.assoc(val.clone(), key.clone());
        });
        Ok(Value::Map(inverted))
    } else {
        Err(ValueError::WrongType {
            expected: "map",
            got: input.type_name().to_string(),
        })
    }
}