use fancy_regex::RegexBuilder;
use phf_codegen::Map as PhfMap;
use serde::Deserialize;
/// Rewrites PCRE2 syntax in a heuristics regex into a form handed to `fancy-regex`.
///
/// The only rewrite performed is replacing every `\g<` with `\k<` (a plain textual
/// substitution; an escaped backslash followed by `g<` is rewritten too).
///
/// NOTE(review): in PCRE2, `\g<name>` is an Oniguruma-style subroutine *call* while
/// `\k<name>` is a named *back-reference*; the two only coincide for some inputs.
/// Confirm this rewrite is intended for every pattern that uses `\g<`.
fn translate_pcre2_to_fancy_regex(pattern: &str) -> String {
    const PCRE2_NAMED_GROUP: &str = r"\g<";
    const FANCY_NAMED_BACKREF: &str = r"\k<";
    pattern.replace(PCRE2_NAMED_GROUP, FANCY_NAMED_BACKREF)
}
/// Checks that `pattern` compiles with `fancy-regex` when prefixed with the multi-line
/// flag `(?m)`; the compiled regex itself is discarded.
///
/// # Errors
/// Returns the `fancy_regex::Error` produced by `RegexBuilder::build` if compilation fails.
fn validate_fancy_regex(pattern: &str) -> Result<(), fancy_regex::Error> {
    let multiline_pattern = format!("(?m){pattern}");
    let _compiled = RegexBuilder::new(&multiline_pattern).build()?;
    Ok(())
}
use std::{
collections::HashMap,
fs::{self, File},
io::{BufWriter, Write},
iter,
path::Path,
};
use palate::FileType;
/// Reusable regexes from `heuristics.yml`, keyed by name. A rule refers to one through
/// `PatternDTO::Named`, and the reference is inlined when code is generated.
type NamedPatterns = HashMap<String, MaybeMany<String>>;
/// Top-level shape of `heuristics.yml`, deserialized with `serde_norway` in `main`.
#[derive(Deserialize)]
struct Heuristics {
    /// One entry per group of extensions that share a list of rules.
    disambiguations: Vec<Disambiguation>,
    /// Patterns that rules may reference by name.
    named_patterns: NamedPatterns,
}
/// A set of file extensions together with the rules used to choose a language for them.
#[derive(Deserialize)]
struct Disambiguation {
    /// Extensions as written in the YAML; the code compares them against `.h`, so they
    /// presumably include the leading dot. They are lowercased when used as map keys.
    extensions: Vec<String>,
    /// Rules in file order; they are emitted into the generated slice in this order.
    rules: Vec<RuleDTO>,
}
impl Disambiguation {
    /// Renders the rules as a Rust slice expression of the form `&[Rule {..},Rule {..},]`.
    ///
    /// Each element is followed by a comma. Rules that render to an empty string (none of
    /// their languages maps to a `FileType`) are left out entirely.
    fn to_domain_object_code(&self, named_patterns: &NamedPatterns) -> String {
        let elements: String = self
            .rules
            .iter()
            .map(|rule| rule.to_domain_object_code(named_patterns))
            .filter(|rendered| !rendered.is_empty())
            .map(|rendered| rendered + ",")
            .collect();
        format!("&[{elements}]")
    }
}
/// One rule as it appears in `heuristics.yml`: the language(s) to report and an optional
/// pattern. A rule without a pattern is rendered with `pattern: None`.
#[derive(Deserialize)]
struct RuleDTO {
    /// A single language name or a list of names, mapped to `FileType` via `filetype_for_language`.
    language: MaybeMany<String>,
    /// Flattened, so the pattern key (`pattern`, `negative_pattern`, `named_pattern` or `and`)
    /// sits directly in the rule's YAML map.
    #[serde(flatten)]
    pattern: Option<PatternDTO>,
}
impl RuleDTO {
    /// Renders this rule as a `Rule { languages: &[..], pattern: ..}` struct literal.
    ///
    /// Language names that `filetype_for_language` cannot map are dropped. If none remain,
    /// an empty string is returned so the caller can omit the rule.
    ///
    /// # Panics
    /// Panics (via `PatternDTO::to_domain_object_code`) on an invalid regex or an unknown
    /// named pattern. The pattern is rendered before the language filter runs, so this also
    /// happens for rules that end up being dropped.
    fn to_domain_object_code(&self, named_patterns: &NamedPatterns) -> String {
        let names: &[String] = match &self.language {
            MaybeMany::One(single) => std::slice::from_ref(single),
            MaybeMany::Many(several) => several,
        };
        let rendered_pattern = self.pattern.as_ref().map_or_else(
            || String::from("None"),
            |pattern| format!("Some({})", pattern.to_domain_object_code(named_patterns)),
        );
        let file_types = names
            .iter()
            .filter_map(|name| filetype_for_language(name))
            .map(|ft| format!("FileType::{ft:?}"))
            .collect::<Vec<String>>();
        if file_types.is_empty() {
            String::new()
        } else {
            format!(
                "Rule {{ languages: &[{}], pattern: {}}}",
                file_types.join(", "),
                rendered_pattern
            )
        }
    }
}
/// A pattern node from `heuristics.yml`. It is externally tagged: the YAML key named in each
/// `rename` selects the variant.
#[derive(Clone, Deserialize)]
enum PatternDTO {
    /// `and`: a list of sub-patterns, rendered as `Pattern::And`.
    #[serde(rename = "and")]
    And(Vec<PatternDTO>),
    /// `named_pattern`: the name of an entry in `named_patterns`, inlined at generation time.
    #[serde(rename = "named_pattern")]
    Named(String),
    /// `negative_pattern`: a single regex, rendered as `Pattern::Negative`.
    #[serde(rename = "negative_pattern")]
    Negative(String),
    /// `pattern`: one regex (rendered as `Pattern::Positive`) or a list of them (rendered as a
    /// `Pattern::Or` of positives).
    #[serde(rename = "pattern")]
    Positive(MaybeMany<String>),
}
impl PatternDTO {
    /// Translates `raw` for fancy-regex, validates it, and renders `Pattern::<variant>("...")`.
    /// The regex is embedded as a Rust string literal via `Debug` formatting.
    fn render_leaf(variant: &str, raw: &str) -> String {
        let translated = translate_pcre2_to_fancy_regex(raw);
        if let Err(e) = validate_fancy_regex(&translated) {
            panic!("Invalid regex pattern: {}\n{}", translated, e);
        }
        format!("Pattern::{}({:?})", variant, translated)
    }

    /// Renders this pattern tree as a Rust `Pattern` expression for the generated map.
    ///
    /// Named references are resolved through `named_patterns` and inlined as positive
    /// patterns; lists become `Pattern::Or` / `Pattern::And` slices.
    ///
    /// # Panics
    /// Panics if any regex fails fancy-regex validation, or if a named pattern is missing
    /// from `named_patterns`.
    fn to_domain_object_code(&self, named_patterns: &NamedPatterns) -> String {
        match self {
            PatternDTO::Positive(MaybeMany::One(raw)) => Self::render_leaf("Positive", raw),
            PatternDTO::Negative(raw) => Self::render_leaf("Negative", raw),
            PatternDTO::Positive(MaybeMany::Many(alternatives)) => {
                let elements: String = alternatives
                    .iter()
                    .map(|raw| Self::render_leaf("Positive", raw) + ",")
                    .collect();
                format!("Pattern::Or(&[{elements}])")
            }
            PatternDTO::And(children) => {
                let elements: String = children
                    .iter()
                    .map(|child| child.to_domain_object_code(named_patterns) + ",")
                    .collect();
                format!("Pattern::And(&[{elements}])")
            }
            PatternDTO::Named(pattern_name) => match named_patterns.get(pattern_name) {
                Some(resolved) => {
                    PatternDTO::Positive(resolved.clone()).to_domain_object_code(named_patterns)
                }
                None => panic!(
                    "Named pattern: {} not found in named pattern map",
                    pattern_name
                ),
            },
        }
    }
}
/// A YAML value that may be written either as a single item or as a list of items.
///
/// With `untagged`, serde tries the variants in declaration order. A sequence therefore
/// deserializes as `Many`, and anything else falls back to `One`.
#[derive(Clone, Deserialize)]
#[serde(untagged)]
enum MaybeMany<T> {
    Many(Vec<T>),
    One(T),
}
/// Generated Rust source holding the extension -> rules `phf` map (recreated on every run).
const DISAMBIGUATION_HEURISTICS_FILE: &str = "src/codegen/disambiguation-heuristics-map.rs";
/// Generated Rust source holding the per-language token log-probability `phf` maps.
const TOKEN_LOG_PROBABILITY_FILE: &str = "src/codegen/token-log-probabilities.rs";
/// YAML input describing the disambiguation rules and named patterns.
const HEURISTICS_SOURCE_FILE: &str = "heuristics.yml";
/// Tokens longer than this many bytes are ignored when training the classifier.
const MAX_TOKEN_BYTES: usize = 32;
/// Code-generator entry point.
///
/// Steps, in order:
/// 1. Parse `heuristics.yml`.
/// 2. Validate every regex it contains.
/// 3. Generate the disambiguation map.
/// 4. If a `samples/` directory exists, train the token classifier and generate its tables.
///
/// # Panics
/// Panics with the file name and cause if `heuristics.yml` cannot be read or parsed, or if
/// any later step panics.
fn main() {
    let source = fs::read_to_string(HEURISTICS_SOURCE_FILE)
        .unwrap_or_else(|e| panic!("failed to read {HEURISTICS_SOURCE_FILE}: {e}"));
    let heuristics: Heuristics = serde_norway::from_str(&source)
        .unwrap_or_else(|e| panic!("failed to parse {HEURISTICS_SOURCE_FILE}: {e}"));
    // Fail fast with a precise message before any generated file is touched.
    validate_all_patterns(&heuristics);
    create_disambiguation_heuristics_map(heuristics);
    if Path::new("samples").exists() {
        train_classifier();
    } else {
        println!("Note: Skipping classifier training - 'samples' directory not found");
        println!(" Copy/link samples from hyperpolyglot to enable classifier training");
    }
}
/// Validates every regex in `heuristics` with fancy-regex and prints how many checks passed.
///
/// Regexes reachable from rules are checked first (named references included), then every
/// named pattern is checked on its own. The printed number therefore counts validations
/// performed, not distinct patterns: a named pattern referenced by several rules is counted
/// once per reference plus once more.
///
/// # Panics
/// Panics on the first regex that fails to compile, or on an unknown named pattern.
fn validate_all_patterns(heuristics: &Heuristics) {
    let mut validated = 0usize;

    let rule_patterns = heuristics
        .disambiguations
        .iter()
        .flat_map(|dis| dis.rules.iter())
        .filter_map(|rule| rule.pattern.as_ref());
    for pattern in rule_patterns {
        validate_pattern_dto(pattern, &heuristics.named_patterns, &mut validated);
    }

    for (name, entry) in &heuristics.named_patterns {
        let sources: &[String] = match entry {
            MaybeMany::One(single) => std::slice::from_ref(single),
            MaybeMany::Many(several) => several,
        };
        for source in sources {
            let translated = translate_pcre2_to_fancy_regex(source);
            if let Err(e) = validate_fancy_regex(&translated) {
                panic!("Invalid named pattern '{}': {}\n{}", name, translated, e);
            }
            validated += 1;
        }
    }

    println!("✓ Validated {} regex patterns with fancy-regex", validated);
}
/// Recursively validates every regex reachable from `pattern`, resolving named references
/// through `named_patterns`. Adds one to `count` for each regex that compiles.
///
/// # Panics
/// Panics on the first regex that fails fancy-regex validation, or if a named pattern is
/// missing from `named_patterns`.
fn validate_pattern_dto(
    pattern: &PatternDTO,
    named_patterns: &NamedPatterns,
    count: &mut usize,
) {
    /// Translates and compiles one regex source, panicking with the translated text on failure.
    fn check(source: &str, count: &mut usize) {
        let translated = translate_pcre2_to_fancy_regex(source);
        if let Err(e) = validate_fancy_regex(&translated) {
            panic!("Invalid regex pattern: {}\n{}", translated, e);
        }
        *count += 1;
    }

    match pattern {
        PatternDTO::And(children) => {
            for child in children {
                validate_pattern_dto(child, named_patterns, count);
            }
        }
        PatternDTO::Named(name) => match named_patterns.get(name) {
            Some(MaybeMany::One(source)) => check(source, count),
            Some(MaybeMany::Many(sources)) => {
                for source in sources {
                    check(source, count);
                }
            }
            None => panic!("Named pattern '{}' not found", name),
        },
        PatternDTO::Positive(MaybeMany::Many(sources)) => {
            for source in sources {
                check(source, count);
            }
        }
        PatternDTO::Positive(MaybeMany::One(source)) | PatternDTO::Negative(source) => {
            check(source, count);
        }
    }
}
/// Renders every disambiguation in `heuristics` as Rust source and writes it to
/// `DISAMBIGUATION_HEURISTICS_FILE` as
/// `static DISAMBIGUATIONS: phf::Map<&'static str, &'static [Rule]>`, keyed by lowercased
/// extension.
///
/// Behavior:
/// - The `.h` extension, and only that extension, gets one extra rule appended after the
///   rules from the YAML. The extra rule selects C and has no pattern.
/// - If the same lowercased extension appears in several entries, the last one wins.
/// - Entries are sorted before being handed to `phf_codegen`, so repeated runs emit the
///   same file.
/// - The output file is created only after all code has been rendered, so a panic during
///   rendering does not leave a truncated file behind.
///
/// # Panics
/// Panics on I/O failure, or (via `to_domain_object_code`) on an invalid regex or an unknown
/// named pattern.
fn create_disambiguation_heuristics_map(heuristics: Heuristics) {
    let named_patterns = &heuristics.named_patterns;
    let mut temp_map: HashMap<String, String> = HashMap::new();
    for mut dis in heuristics.disambiguations {
        // The rendered rules are the same for every extension of this entry, so render once.
        let shared_code = dis.to_domain_object_code(named_patterns);
        // Only `.h` gets the C fallback. Pushing it into the shared rule list while iterating
        // the extensions would leak it into every extension listed after `.h`.
        let header_code = if dis.extensions.iter().any(|ext| ext == ".h") {
            dis.rules.push(RuleDTO {
                language: MaybeMany::One(String::from("C")),
                pattern: None,
            });
            Some(dis.to_domain_object_code(named_patterns))
        } else {
            None
        };
        for ext in &dis.extensions {
            let code = match &header_code {
                Some(code) if ext == ".h" => code.clone(),
                _ => shared_code.clone(),
            };
            temp_map.insert(ext.to_ascii_lowercase(), code);
        }
    }

    // HashMap iteration order differs between runs; sort so the generated file is reproducible.
    let mut entries: Vec<(&String, &String)> = temp_map.iter().collect();
    entries.sort_unstable();
    let mut disambiguation_heuristic_map = PhfMap::new();
    for (key, value) in entries {
        disambiguation_heuristic_map.entry(key.as_str(), value.as_str());
    }

    let mut file = BufWriter::new(
        File::create(DISAMBIGUATION_HEURISTICS_FILE)
            .unwrap_or_else(|e| panic!("failed to create {DISAMBIGUATION_HEURISTICS_FILE}: {e}")),
    );
    writeln!(
        &mut file,
        "static DISAMBIGUATIONS: phf::Map<&'static str, &'static [Rule]> =\n{};\n",
        disambiguation_heuristic_map.build()
    )
    .unwrap_or_else(|e| panic!("failed to write {DISAMBIGUATION_HEURISTICS_FILE}: {e}"));
    // Dropping a BufWriter also flushes, but silently discards any write error.
    file.flush()
        .unwrap_or_else(|e| panic!("failed to flush {DISAMBIGUATION_HEURISTICS_FILE}: {e}"));
}
/// Computes per-language token log-probabilities from the files under `samples/` and writes
/// them to `TOKEN_LOG_PROBABILITY_FILE` as a nested `phf::Map` (language -> token -> ln(p)).
///
/// Input layout and counting:
/// - Each immediate subdirectory of `samples/` is one language; the directory `Fstar` is
///   reported as `F*`.
/// - Every regular file inside a language directory is one sample.
/// - Tokens come from `palate_polyglot_tokenizer::get_key_tokens`. Tokens longer than
///   `MAX_TOKEN_BYTES` bytes are ignored, and files that are not valid UTF-8 contribute no
///   tokens.
/// - For each language, `p(token) = count(token) / total counted tokens of that language`.
///
/// Languages and tokens are sorted before code generation so the output is reproducible.
/// The output file is created only after every probability has been computed.
///
/// # Panics
/// Panics on any I/O failure while reading samples or writing the output file.
fn train_classifier() {
    // language -> token -> occurrences across all of that language's samples
    let mut token_counts: HashMap<String, HashMap<String, u64>> = HashMap::new();

    let language_dirs = fs::read_dir("samples")
        .unwrap_or_else(|e| panic!("failed to read 'samples' directory: {e}"))
        .map(|entry| entry.unwrap_or_else(|e| panic!("failed to read entry of 'samples': {e}")))
        .filter(|entry| entry.path().is_dir());
    for language_dir in language_dirs {
        let path = language_dir.path();
        let language = path
            .file_name()
            .expect("read_dir entries always have a file name")
            .to_string_lossy()
            .into_owned();
        let language = match &language[..] {
            "Fstar" => String::from("F*"),
            _ => language,
        };
        let sample_files = fs::read_dir(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()))
            .map(|entry| {
                entry
                    .unwrap_or_else(|e| panic!("failed to read entry of {}: {e}", path.display()))
                    .path()
            })
            .filter(|file| file.is_file());
        for sample in sample_files {
            let content = fs::read(&sample)
                .unwrap_or_else(|e| panic!("failed to read {}: {e}", sample.display()));
            // Non-UTF-8 samples are tokenized as "", i.e. they contribute nothing.
            let text = std::str::from_utf8(&content[..]).unwrap_or("");
            for token in palate_polyglot_tokenizer::get_key_tokens(text) {
                if token.len() <= MAX_TOKEN_BYTES {
                    *token_counts
                        .entry(language.clone())
                        .or_default()
                        .entry(String::from(token))
                        .or_insert(0) += 1;
                }
            }
        }
    }

    // HashMap iteration order differs between runs; sort so the generated file is reproducible.
    let mut languages: Vec<(&String, &HashMap<String, u64>)> = token_counts.iter().collect();
    languages.sort_unstable_by(|a, b| a.0.cmp(b.0));
    let mut language_entries: Vec<(&str, String)> = Vec::with_capacity(languages.len());
    for (language, counts) in languages {
        // Every counted token lands in exactly one entry of `counts`, so their sum is the total.
        let total_tokens = counts.values().sum::<u64>() as f64;
        let mut tokens: Vec<(&String, &u64)> = counts.iter().collect();
        tokens.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let token_entries: Vec<(&str, String)> = tokens
            .into_iter()
            .map(|(token, &count)| {
                let log_probability = (count as f64 / total_tokens).ln();
                (token.as_str(), format!("{}f64", log_probability))
            })
            .collect();
        let mut token_log_probabilities = PhfMap::new();
        for (token, value) in &token_entries {
            token_log_probabilities.entry(*token, value.as_str());
        }
        language_entries.push((
            language.as_str(),
            token_log_probabilities.build().to_string(),
        ));
    }

    let mut language_token_log_probabilities = PhfMap::new();
    for (language, map) in &language_entries {
        language_token_log_probabilities.entry(*language, map.as_str());
    }

    let mut file = BufWriter::new(
        File::create(TOKEN_LOG_PROBABILITY_FILE)
            .unwrap_or_else(|e| panic!("failed to create {TOKEN_LOG_PROBABILITY_FILE}: {e}")),
    );
    writeln!(
        &mut file,
        "static TOKEN_LOG_PROBABILITIES: phf::Map<&'static str, phf::Map<&'static str, f64>> =\n{};\n",
        language_token_log_probabilities.build()
    )
    .unwrap_or_else(|e| panic!("failed to write {TOKEN_LOG_PROBABILITY_FILE}: {e}"));
    // Dropping a BufWriter also flushes, but silently discards any write error.
    file.flush()
        .unwrap_or_else(|e| panic!("failed to flush {TOKEN_LOG_PROBABILITY_FILE}: {e}"));
}
/// Maps a language name from the heuristics (e.g. `C++`, `Common Lisp`) to a `FileType`.
///
/// Candidates are tried in this order:
/// 1. A small alias table for names whose punctuation matters (`C#`, `C++`, `F#`, `F*`,
///    `Objective-C`, `Objective-C++`).
/// 2. `FileType::from_str` on the trimmed, lowercased name.
/// 3. `FileType::from_str` on a dash-separated slug of that name (`common-lisp`).
/// 4. `FileType::from_str` on the name with spaces, dashes and underscores removed
///    (`commonlisp`).
///
/// Returns `None` for blank input or when no candidate parses.
fn filetype_for_language(language: &str) -> Option<FileType> {
    use std::str::FromStr;

    let trimmed = language.trim();
    if trimmed.is_empty() {
        return None;
    }
    let lowered = trimmed.to_lowercase();

    let alias = match lowered.as_str() {
        "c#" => Some("csharp"),
        "c++" => Some("cpp"),
        "f#" => Some("fsharp"),
        "f*" => Some("fstar"),
        "objective-c" => Some("objc"),
        "objective-c++" => Some("objcpp"),
        _ => None,
    };
    // An alias that does not parse falls through to the generic spellings below.
    if let Some(ft) = alias.and_then(|name| FileType::from_str(name).ok()) {
        return Some(ft);
    }

    // Runs of ASCII letters/digits joined by single dashes: "common lisp" -> "common-lisp".
    let slug = lowered
        .split(|c: char| !c.is_ascii_alphanumeric())
        .filter(|run| !run.is_empty())
        .collect::<Vec<_>>()
        .join("-");
    let squashed = lowered.replace([' ', '-', '_'], "");

    [lowered.as_str(), slug.as_str(), squashed.as_str()]
        .iter()
        .find_map(|candidate| FileType::from_str(candidate).ok())
}