pub mod scoring;
use crate::debug_info;
use crate::logic::grammar::Grammar;
use crate::logic::partial::Synthesizer;
use crate::logic::typing::tree::{TypedAST, TypedNode};
use crate::logic::typing::{gather_terminals_typed, Context, Type};
use crate::regex::Regex as DerivativeRegex;
use std::collections::{BinaryHeap, HashSet, VecDeque};
/// Limits that bound how much work `search_complete` may do.
#[derive(Debug, Clone, Copy)]
pub struct SearchConfig {
    /// Maximum number of extension steps applied on top of the original input.
    pub max_depth: usize,
    /// Forwarded to the synthesizer's `extend_all_with_regex_candidates` for
    /// every token (presumably the number of concrete examples generated per
    /// token regex; confirm against the synthesizer).
    pub max_token_examples: usize,
    /// Budget on the `states_explored` counter; once it is reached the search
    /// stops expanding states.
    pub max_states: usize,
    /// Maximum number of children kept per expanded state; `0` disables the cap.
    pub max_children_per_state: usize,
}

impl Default for SearchConfig {
    /// Stock limits: depth 10, one example per token, 96 states and at most
    /// 12 children per expanded state.
    fn default() -> Self {
        let (max_depth, max_token_examples) = (10, 1);
        let (max_states, max_children_per_state) = (96, 12);
        SearchConfig {
            max_depth,
            max_token_examples,
            max_states,
            max_children_per_state,
        }
    }
}
/// Outcome of a completion search.
#[derive(Debug)]
pub enum SearchResult {
/// A complete tree was found.
Success {
/// Text of the tree in which the complete root was found.
complete_input: String,
/// The first root of that tree that reports itself complete.
ast: TypedNode,
/// Token regexes chosen at each extension step, in order.
completion_path: Vec<DerivativeRegex>,
/// Number of extension steps taken from the original input.
depth: usize,
},
/// The search ended without finding a complete tree.
Exhausted {
/// The depth limit the search ran with.
max_depth: usize,
/// Value of the explored-state counter when the search stopped.
states_explored: usize,
/// Texts of the states recorded as visited, in discovery order
/// (the first entry is the text of the initial tree).
visited_states: Vec<String>,
},
/// The input could not be typed as a partial tree.
Invalid {
/// Human-readable description of the failure.
message: String,
},
}
/// A node of the search space: a partial tree plus how it was reached.
#[derive(Clone)]
struct SearchState {
/// Typed partial tree for the text built so far.
tree: TypedAST,
/// Number of extension steps applied to reach this state.
depth: usize,
/// Token regexes applied at each step, in order.
path: Vec<DerivativeRegex>,
}
/// Frontier entry: a search state together with its priority score.
///
/// The frontier is a `BinaryHeap`, so the entry with the greatest `score`
/// is popped first.
#[derive(Clone)]
struct ScoredState {
/// Priority of `state`; only this field takes part in comparisons.
score: f64,
/// The state to explore.
state: SearchState,
}
// Comparison traits for `ScoredState`.
//
// All four impls are derived from one total order on `score`
// (`f64::total_cmp`, the IEEE 754 totalOrder predicate) so that `Eq`,
// `PartialOrd` and `Ord` agree with each other, as `BinaryHeap` requires.
// Under that order NaN values are ordinary elements (a positive NaN sorts
// above every number, a negative NaN below), and `-0.0 < +0.0`.

impl PartialEq for ScoredState {
    /// Two entries are equal when their scores have identical bit patterns,
    /// which is exactly when `total_cmp` reports `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.score.to_bits() == other.score.to_bits()
    }
}

impl Eq for ScoredState {}

impl PartialOrd for ScoredState {
    /// Always `Some`: delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ScoredState {
    /// Orders entries by `score` alone using a genuine total order.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.score.total_cmp(&other.score)
    }
}
/// Tries to extend `input` into a text whose typed tree has a complete root.
///
/// The input is first typed as a partial tree; if that fails the result is
/// `SearchResult::Invalid`. A greedy single-path attempt
/// (`try_greedy_complete`) runs next and its success is returned directly.
/// Otherwise a best-first search is performed: states are popped from a
/// max-heap ordered by score, checked for completion, and expanded through
/// `build_children`.
///
/// # Parameters
/// * `grammar` - grammar handed (cloned) to the synthesizer.
/// * `input` - the partial text to complete.
/// * `config` - depth, state and branching limits.
/// * `ctx` - typing context forwarded to every synthesizer call.
///
/// # Returns
/// * `Success` with the completed text, the complete root, the chosen token
///   path and the number of steps, as soon as a complete state is found.
/// * `Exhausted` when the frontier runs empty without a completion.
/// * `Invalid` when `input` cannot be typed as a partial tree.
pub fn search_complete(
    grammar: &Grammar,
    input: &str,
    config: &SearchConfig,
    ctx: &Context,
) -> SearchResult {
    let mut synth = Synthesizer::new(grammar.clone(), input);
    let base_tree = match synth.partial_typed_ctx(ctx) {
        Ok(ast) => ast,
        Err(e) => {
            return SearchResult::Invalid {
                message: format!("Input is not partially valid: {}", e),
            }
        }
    };

    // Cheap attempt first: follow only the locally best extension at each step.
    if let Some(success) = try_greedy_complete(
        &mut synth,
        base_tree.clone(),
        ctx,
        config.max_depth,
        config.max_token_examples,
    ) {
        return success;
    }

    // Bookkeeping shared with `build_children`: the dedup set, the ordered
    // list of visited texts reported on exhaustion, and the state budget.
    let base_text = base_tree.text();
    let mut visited: HashSet<String> = HashSet::new();
    visited.insert(base_text.clone());
    let mut visited_states: VecDeque<String> = VecDeque::new();
    visited_states.push_back(base_text);
    let mut states_explored = 0usize;

    let initial_state = SearchState {
        tree: base_tree,
        depth: 0,
        path: Vec::new(),
    };
    let initial_score = scoring::calculate_score(&initial_state.tree, 0, config.max_depth).overall;
    let mut frontier: BinaryHeap<ScoredState> = BinaryHeap::new();
    frontier.push(ScoredState {
        score: initial_score,
        state: initial_state,
    });

    while let Some(ScoredState { score, state }) = frontier.pop() {
        // The stored score was computed from this state's own tree and depth,
        // so it is logged directly rather than recomputed.
        debug_info!(
            "search",
            "Exploring state: depth={} input='{}' score={}",
            state.depth,
            state.tree.text(),
            score
        );
        if let Some(complete_node) = find_valid_completion(&state.tree) {
            let reconstructed = state.tree.text();
            debug_info!(
                "search",
                "Completion found: depth={} input='{}'",
                state.depth,
                reconstructed
            );
            return SearchResult::Success {
                complete_input: reconstructed,
                ast: complete_node,
                completion_path: state.path,
                depth: state.depth,
            };
        }

        // States at the depth limit are never expanded. Once the state budget
        // is spent nothing is expanded either, but the loop keeps draining the
        // frontier so that states which were already generated still get the
        // (cheap) completion check above instead of being silently dropped.
        if state.depth >= config.max_depth || states_explored >= config.max_states {
            continue;
        }

        let children = build_children(
            &mut synth,
            &state,
            ctx,
            config.max_depth,
            config.max_token_examples,
            config.max_children_per_state,
            &mut visited,
            &mut visited_states,
            &mut states_explored,
        );
        frontier.extend(
            children
                .into_iter()
                .map(|(child, child_score)| ScoredState {
                    score: child_score,
                    state: child,
                }),
        );
    }

    SearchResult::Exhausted {
        max_depth: config.max_depth,
        states_explored,
        visited_states: visited_states.into_iter().collect(),
    }
}
/// Greedy, single-path completion attempt.
///
/// Starting from `tree`, each step asks the synthesizer for the possible next
/// tokens, extends the current text with every candidate, and
/// * returns `Success` immediately if any candidate has a complete root;
/// * otherwise moves on to the single best candidate.
///
/// Candidates without roots, or whose text equals the current text, are
/// skipped. "Best" compares, in order, the grounded-root count, the score's
/// `open_slots` component and the overall score, preferring larger values;
/// on an exact tie the earlier candidate is kept.
///
/// Returns `None` when no completion is found within `max_depth` steps, when
/// the synthesizer offers no tokens, or when no candidate is usable.
fn try_greedy_complete(
    synth: &mut Synthesizer,
    mut tree: TypedAST,
    ctx: &Context,
    max_depth: usize,
    max_token_examples: usize,
) -> Option<SearchResult> {
    let mut path = Vec::new();
    for depth in 0..=max_depth {
        if let Some(complete_node) = find_valid_completion(&tree) {
            return Some(SearchResult::Success {
                complete_input: tree.text(),
                ast: complete_node,
                completion_path: path,
                depth,
            });
        }
        if depth == max_depth {
            break;
        }

        // Built once per step: it is both the synthesizer input and the
        // reference every candidate is compared against.
        let current_text = tree.text();
        synth.set_input(current_text.clone());
        let tokens = synth.completions_ctx(ctx);
        if tokens.is_empty() {
            break;
        }

        let mut local_terms = Vec::new();
        for root in &tree.roots {
            local_terms.extend(gather_terminals_typed(root));
        }

        // (tree, token, grounded roots, open-slots component, overall score)
        let mut best_next: Option<(TypedAST, DerivativeRegex, usize, f64, f64)> = None;
        for token in tokens.iter() {
            let candidates = synth.extend_all_with_regex_candidates(
                token,
                ctx,
                &local_terms,
                max_token_examples,
            );
            for (next_tree, _ext) in candidates.into_iter() {
                if !has_well_typed_root(&next_tree) {
                    continue;
                }
                let next_text = next_tree.text();
                // An extension that leaves the text unchanged makes no progress.
                if next_text == current_text {
                    continue;
                }
                if let Some(complete_node) = find_valid_completion(&next_tree) {
                    let mut completion_path = path.clone();
                    completion_path.push(token.clone());
                    return Some(SearchResult::Success {
                        complete_input: next_text,
                        ast: complete_node,
                        completion_path,
                        depth: depth + 1,
                    });
                }

                let state_score = scoring::calculate_score(&next_tree, depth + 1, max_depth);
                let score = state_score.overall;
                let open_slots = state_score.open_slots;
                let grounded = grounded_root_count(&next_tree);

                // The incumbent survives unless the candidate is strictly
                // better in the lexicographic (grounded, open_slots, score)
                // comparison.
                let keep_incumbent = match &best_next {
                    Some((_, _, best_grounded, best_open_slots, best_score)) => {
                        grounded < *best_grounded
                            || (grounded == *best_grounded
                                && (open_slots < *best_open_slots
                                    || (open_slots == *best_open_slots
                                        && score <= *best_score)))
                    }
                    None => false,
                };
                if !keep_incumbent {
                    best_next = Some((next_tree, token.clone(), grounded, open_slots, score));
                }
            }
        }

        match best_next {
            Some((next_tree, chosen_token, _, _, _)) => {
                path.push(chosen_token);
                tree = next_tree;
            }
            None => break,
        }
    }
    None
}
/// Expands `state` into its not-yet-visited successor states with scores.
///
/// Every token offered by the synthesizer for the state's text is applied via
/// `extend_states_with_token`. Each successor whose text has not been seen
/// before is recorded in `visited` / `visited_states`, counted in
/// `states_explored`, and scored.
///
/// The returned list is then pruned and ordered:
/// * if any successor has at least one grounded (non-`Any`) root, successors
///   with none are dropped;
/// * the rest are sorted by grounded-root count, then score, both descending;
/// * the list is truncated to `max_children_per_state` unless that is `0`.
///
/// NOTE(review): successors removed by the pruning above are still marked as
/// visited and still count towards `states_explored`; confirm that this is
/// the intended budget/dedup behaviour.
fn build_children(
    synth: &mut Synthesizer,
    state: &SearchState,
    ctx: &Context,
    max_depth: usize,
    max_token_examples: usize,
    max_children_per_state: usize,
    visited: &mut HashSet<String>,
    visited_states: &mut VecDeque<String>,
    states_explored: &mut usize,
) -> Vec<(SearchState, f64)> {
    synth.set_input(state.tree.text());
    let tokens = synth.completions_ctx(ctx);
    debug_info!(
        "search",
        "Expanding state: depth={} input='{}' tokens: {}",
        state.depth,
        state.tree.text(),
        tokens
            .iter()
            .map(|t| t.to_string())
            .collect::<Vec<_>>()
            .join(", ")
    );

    let mut local_terms = Vec::new();
    for root in &state.tree.roots {
        local_terms.extend(gather_terminals_typed(root));
    }

    // (successor, grounded-root count, overall score)
    let mut children: Vec<(SearchState, usize, f64)> = Vec::new();
    for token in tokens.iter() {
        for child in
            extend_states_with_token(synth, state, token, ctx, &local_terms, max_token_examples)
        {
            // Build the text once and reuse it for the lookup and both records.
            let text = child.tree.text();
            if visited.contains(&text) {
                continue;
            }
            visited.insert(text.clone());
            visited_states.push_back(text);
            *states_explored += 1;

            let score = scoring::calculate_score(&child.tree, child.depth, max_depth).overall;
            let grounded = grounded_root_count(&child.tree);
            children.push((child, grounded, score));
        }
    }

    let has_grounded = children.iter().any(|(_, grounded, _)| *grounded > 0);
    if has_grounded {
        children.retain(|(_, grounded, _)| *grounded > 0);
    }

    // `total_cmp` keeps the comparator a total order even if a score is NaN;
    // `sort_by` may panic when handed an inconsistent comparator.
    children.sort_by(|(_, grounded_a, score_a), (_, grounded_b, score_b)| {
        grounded_b
            .cmp(grounded_a)
            .then_with(|| score_b.total_cmp(score_a))
    });

    if max_children_per_state > 0 && children.len() > max_children_per_state {
        children.truncate(max_children_per_state);
    }

    children
        .into_iter()
        .map(|(child, _grounded, score)| (child, score))
        .collect()
}
/// Produces the successor states obtained by applying `token` to `state`.
///
/// The synthesizer is pointed at the state's text, then asked for all
/// candidate extensions of that text by `token`, given a sorted,
/// de-duplicated copy of `local_terms`. Every candidate tree that has at
/// least one root becomes a `SearchState` one level deeper, with `token`
/// appended to the parent's path.
fn extend_states_with_token(
    synth: &mut Synthesizer,
    state: &SearchState,
    token: &DerivativeRegex,
    ctx: &Context,
    local_terms: &[String],
    max_token_examples: usize,
) -> Vec<SearchState> {
    synth.set_input(state.tree.text().to_string());

    let mut unique_terms = local_terms.to_vec();
    unique_terms.sort();
    unique_terms.dedup();

    let child_depth = state.depth + 1;
    synth
        .extend_all_with_regex_candidates(token, ctx, &unique_terms, max_token_examples)
        .into_iter()
        .filter(|(candidate, _)| has_well_typed_root(candidate))
        .map(|(candidate, _)| {
            let mut path = state.path.clone();
            path.push(token.clone());
            SearchState {
                tree: candidate,
                depth: child_depth,
                path,
            }
        })
        .collect()
}
/// Reports whether `ast` has at least one root.
///
/// Only the presence of a root is checked here; no further typing check is
/// performed by this function.
fn has_well_typed_root(ast: &TypedAST) -> bool {
    let no_roots = ast.roots.is_empty();
    !no_roots
}
/// Counts the roots of `ast` whose type is anything other than `Type::Any`.
fn grounded_root_count(ast: &TypedAST) -> usize {
    let mut grounded = 0;
    for root in ast.roots.iter() {
        if !matches!(root.ty(), Type::Any) {
            grounded += 1;
        }
    }
    grounded
}
/// Returns a clone of the first root of `ast` that reports itself complete,
/// or `None` when no root does.
fn find_valid_completion(ast: &TypedAST) -> Option<TypedNode> {
    for root in ast.roots.iter() {
        if root.is_complete() {
            return Some(root.clone());
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{set_debug_level, testing::load_example_grammar};

    /// Default search limits with only the depth bound overridden.
    fn config_with_depth(max_depth: usize) -> SearchConfig {
        SearchConfig {
            max_depth,
            ..Default::default()
        }
    }

    /// Raises the debug level to `Trace` and adds a module filter for "search".
    fn trace_search() {
        set_debug_level(crate::DebugLevel::Trace);
        crate::add_module_filter("search");
    }

    /// Any completion the search reports must re-parse as a partial typed
    /// tree and then finish as a complete one. A non-success result is not
    /// treated as a failure by this test.
    #[test]
    fn search_never_returns_syntactically_or_typedly_invalid_completion() {
        let grammar = load_example_grammar("fun");
        let cfg = config_with_depth(6);
        let prefix = "let";
        let complete_input = match search_complete(&grammar, prefix, &cfg, &Context::new()) {
            SearchResult::Success { complete_input, .. } => complete_input,
            _ => return,
        };

        let mut mp = crate::logic::partial::MetaParser::new(grammar).with_max_depth(62);
        let typed = match mp.partial_typed(&complete_input) {
            Ok(typed) => typed,
            Err(e) => panic!("invalid completion '{}': {}", complete_input, e),
        };
        assert!(
            typed.clone().complete().is_ok(),
            "completion is not a complete typed tree: {}",
            complete_input
        );
    }

    /// Expects `"let x"` to be completable within six steps.
    #[test]
    #[ignore = "search broken after max_states cap, needs fixing"]
    fn search_fun_let_name_prefix_depth6() {
        trace_search();
        let grammar = load_example_grammar("fun");
        let cfg = config_with_depth(6);
        let result = search_complete(&grammar, "let x", &cfg, &Context::new());
        assert!(
            matches!(result, SearchResult::Success { .. }),
            "expected completion success for 'let x' with depth 6, got {:?}",
            result
        );
    }

    /// Expects `"let"` to be completable within seven steps.
    #[test]
    #[ignore = "search broken after max_states cap, needs fixing"]
    fn search_fun_let_prefix_depth7() {
        trace_search();
        let grammar = load_example_grammar("fun");
        let cfg = config_with_depth(7);
        let result = search_complete(&grammar, "let", &cfg, &Context::new());
        assert!(matches!(result, SearchResult::Success { .. }));
    }
}