use cljrs_reader::Form;
use cljrs_reader::form::FormKind;
#[derive(Debug)]
pub struct LetChain {
pub start: usize,
pub len: usize,
pub ops: Vec<ChainOpKind>,
}
/// The kind of collection-building call recognised as one link of a `LetChain`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChainOpKind {
    /// A call whose head is `assoc` or `clojure.core/assoc`.
    Assoc,
    /// A call whose head is `conj` or `clojure.core/conj`.
    Conj,
}
/// Finds runs of `assoc`/`conj` bindings in a `let` binding vector where each
/// binding's collection argument is the symbol bound by the pair before it.
///
/// `bindings` is the flat `[name expr name expr ...]` vector of a `let`.
/// Chains are scanned left to right, each extended as far as it goes, so the
/// returned chains are non-overlapping and ordered by position. Every chain has
/// `len >= 2`. `start` and `len` are measured in binding pairs.
///
/// A trailing unpaired form (a malformed, odd-length binding vector) is
/// ignored rather than causing a panic.
pub fn detect_let_chains(bindings: &[Form]) -> Vec<LetChain> {
    // Fewer than two complete pairs can never form a chain of length >= 2.
    if bindings.len() < 4 {
        return Vec::new();
    }
    // `chunks_exact` drops a dangling odd element. Plain `chunks` would yield a
    // one-element chunk, and indexing `[1]` on it would panic.
    let pairs: Vec<(&Form, &Form)> = bindings
        .chunks_exact(2)
        .map(|pair| (&pair[0], &pair[1]))
        .collect();

    let mut chains = Vec::new();
    let mut i = 0;
    while i < pairs.len() {
        match try_start_chain(&pairs, i) {
            Some(chain) if chain.len >= 2 => {
                // Resume scanning after the chain so chains never overlap.
                i = chain.start + chain.len;
                chains.push(chain);
            }
            _ => i += 1,
        }
    }
    chains
}
/// Tries to grow a chain beginning at pair index `start`.
///
/// The first pair only has to be a recognised `assoc`/`conj` call bound to a
/// plain symbol. Each later pair must also be a recognised call whose
/// collection argument (second list element) is the symbol bound by the
/// previous pair, and must itself be bound to a plain symbol. Growth stops at
/// the first pair that breaks a rule. Returns `None` unless at least two pairs
/// were collected.
fn try_start_chain(pairs: &[(&Form, &Form)], start: usize) -> Option<LetChain> {
    let (head_name, head_expr) = pairs[start];
    let head_op = classify_call(head_expr)?;
    let mut linked_to = symbol_name(head_name)?;
    let mut ops = vec![head_op];

    for &(name, expr) in pairs.iter().skip(start + 1) {
        let Some(op) = classify_call(expr) else {
            break;
        };
        if !call_collection_arg_is(expr, linked_to) {
            break;
        }
        let Some(bound) = symbol_name(name) else {
            break;
        };
        linked_to = bound;
        ops.push(op);
    }

    (ops.len() >= 2).then(|| LetChain {
        start,
        len: ops.len(),
        ops,
    })
}
/// Classifies `expr` as a chainable collection call, if it is one.
///
/// Recognised shapes (bare or `clojure.core/`-qualified head):
/// * `(assoc coll k v & kvs)` — at least one key/value pair, and whole pairs only;
/// * `(conj coll x & xs)` — at least one element to add.
///
/// Anything else, including calls with too few arguments or an `assoc` with a
/// dangling key, yields `None`.
fn classify_call(expr: &Form) -> Option<ChainOpKind> {
    let forms = match &expr.kind {
        FormKind::List(forms) if !forms.is_empty() => forms,
        _ => return None,
    };
    let head = match &forms[0].kind {
        FormKind::Symbol(s) => s.as_str(),
        _ => return None,
    };
    // Total element count, including the head symbol.
    let form_count = forms.len();
    match head {
        // head + coll + 2k key/value args => even count, at least 4.
        // Odd arities throw at runtime in Clojure. Rejecting them keeps a
        // malformed call out of a chain, where (if calls are later merged)
        // its dangling key could pair with a neighbouring call's arguments.
        "assoc" | "clojure.core/assoc" if form_count >= 4 && form_count % 2 == 0 => {
            Some(ChainOpKind::Assoc)
        }
        "conj" | "clojure.core/conj" if form_count >= 3 => Some(ChainOpKind::Conj),
        _ => None,
    }
}
/// Returns `true` when `expr` is a list whose second element (the collection
/// argument of an `assoc`/`conj` call) is exactly the symbol `name`.
fn call_collection_arg_is(expr: &Form, name: &str) -> bool {
    let FormKind::List(items) = &expr.kind else {
        return false;
    };
    items
        .get(1)
        .is_some_and(|coll| matches!(&coll.kind, FormKind::Symbol(s) if s == name))
}
/// Returns the text of `form` if it is a plain symbol, otherwise `None`
/// (for example a destructuring vector or map).
fn symbol_name(form: &Form) -> Option<&str> {
    let FormKind::Symbol(sym) = &form.kind else {
        return None;
    };
    Some(sym.as_str())
}
/// Returns `true` if any form in `body` mentions the symbol `name`.
///
/// Uses the same textual search as the binding checks: quoted forms are
/// ignored, but shadowing inside the body is not modelled.
pub fn binding_used_in_body(name: &str, body: &[Form]) -> bool {
    for form in body {
        if form_references_symbol(form, name) {
            return true;
        }
    }
    false
}
/// Reports whether `name` is used by any `let` binding other than its own
/// chain link.
///
/// `bindings` is the flat `[pattern expr ...]` vector. `chain_start` and
/// `chain_len` are in binding pairs, as in `LetChain`.
///
/// * Outside the chain, every init expression is searched. Every
///   destructuring pattern (any non-symbol binding form) is searched as well,
///   because map-destructuring `:or` defaults may reference earlier locals.
/// * Inside the chain, the first pair is skipped. Each later pair is searched
///   only from its third list element onward, past the call head and the
///   collection argument, since the collection argument is the chain's own
///   threading rather than an extra use.
///
/// The search is textual and does not model shadowing, so it can over-report.
/// That errs on the side of treating the binding as used. A trailing unpaired
/// form is ignored.
pub fn binding_used_in_other_bindings(
    name: &str,
    bindings: &[Form],
    chain_start: usize,
    chain_len: usize,
) -> bool {
    let chain = chain_start..chain_start + chain_len;
    bindings.chunks_exact(2).enumerate().any(|(i, pair)| {
        let (pattern, init) = (&pair[0], &pair[1]);
        if chain.contains(&i) {
            if i == chain_start {
                return false;
            }
            return match &init.kind {
                FormKind::List(forms) => forms
                    .iter()
                    .skip(2)
                    .any(|arg| form_references_symbol(arg, name)),
                _ => false,
            };
        }
        // A plain-symbol pattern only binds a name. A destructuring pattern can
        // contain evaluated expressions (e.g. `:or` defaults), so search it too.
        let pattern_uses = !matches!(&pattern.kind, FormKind::Symbol(_))
            && form_references_symbol(pattern, name);
        pattern_uses || form_references_symbol(init, name)
    })
}
/// Returns `true` if the symbol `name` occurs anywhere inside `form`.
///
/// Quoted forms are treated as data and never match. Wrapper forms (syntax
/// quote, unquote, deref, var, metadata, tagged literals) and collection-like
/// forms (lists, vectors, sets, maps, anonymous fns, reader conditionals) are
/// searched recursively. Any other form kind is treated as containing no
/// reference.
fn form_references_symbol(form: &Form, name: &str) -> bool {
    match &form.kind {
        // Leaf: a symbol is a reference exactly when its text matches.
        FormKind::Symbol(sym) => sym == name,
        // Quoted data is not evaluated.
        FormKind::Quote(_) => false,

        // Single-child wrappers: look through to the wrapped form.
        FormKind::SyntaxQuote(child)
        | FormKind::Unquote(child)
        | FormKind::UnquoteSplice(child)
        | FormKind::Deref(child)
        | FormKind::Var(child) => form_references_symbol(child, name),
        FormKind::TaggedLiteral(_, child) => form_references_symbol(child, name),
        FormKind::Meta(meta, child) => {
            form_references_symbol(meta, name) || form_references_symbol(child, name)
        }

        // Multi-child forms: any child may hold the reference.
        FormKind::List(children)
        | FormKind::Vector(children)
        | FormKind::Set(children)
        | FormKind::Map(children) => children
            .iter()
            .any(|child| form_references_symbol(child, name)),
        FormKind::AnonFn(children) => children
            .iter()
            .any(|child| form_references_symbol(child, name)),
        FormKind::ReaderCond { clauses, .. } => clauses
            .iter()
            .any(|clause| form_references_symbol(clause, name)),

        _ => false,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use cljrs_reader::Parser;
fn parse_bindings(src: &str) -> Vec<Form> {
let mut parser = Parser::new(src.to_string(), "<test>".to_string());
let form = parser.parse_one().unwrap().unwrap();
match form.kind {
FormKind::List(forms) => match &forms[1].kind {
FormKind::Vector(v) => v.clone(),
_ => panic!("expected vector"),
},
_ => panic!("expected list"),
}
}
#[test]
fn test_detect_assoc_chain() {
let bindings =
parse_bindings("(let [a (assoc m :x 1) b (assoc a :y 2) c (assoc b :z 3)] c)");
let chains = detect_let_chains(&bindings);
assert_eq!(chains.len(), 1);
assert_eq!(chains[0].start, 0);
assert_eq!(chains[0].len, 3);
assert!(chains[0].ops.iter().all(|op| *op == ChainOpKind::Assoc));
}
#[test]
fn test_detect_conj_chain() {
let bindings = parse_bindings("(let [a (conj v 1) b (conj a 2) c (conj b 3)] c)");
let chains = detect_let_chains(&bindings);
assert_eq!(chains.len(), 1);
assert_eq!(chains[0].len, 3);
assert!(chains[0].ops.iter().all(|op| *op == ChainOpKind::Conj));
}
#[test]
fn test_no_chain_too_short() {
let bindings = parse_bindings("(let [a (assoc m :x 1)] a)");
let chains = detect_let_chains(&bindings);
assert!(chains.is_empty());
}
#[test]
fn test_no_chain_different_names() {
let bindings = parse_bindings("(let [a (assoc m :x 1) b (assoc m :y 2)] b)");
let chains = detect_let_chains(&bindings);
assert!(chains.is_empty());
}
#[test]
fn test_mixed_chain() {
let bindings = parse_bindings("(let [a (assoc m :x 1) b (conj a 2) c (assoc b :z 3)] c)");
let chains = detect_let_chains(&bindings);
assert_eq!(chains.len(), 1);
assert_eq!(chains[0].len, 3);
}
#[test]
fn test_body_reference_detected() {
let mut parser = Parser::new("(println a)".to_string(), "<test>".to_string());
let form = parser.parse_one().unwrap().unwrap();
assert!(binding_used_in_body("a", &[form]));
let mut parser = Parser::new("(println b)".to_string(), "<test>".to_string());
let form = parser.parse_one().unwrap().unwrap();
assert!(!binding_used_in_body("a", &[form]));
}
}