use crate::{Constant, Constraints, Measurement, Private, Public};
use core::fmt::Debug;
use std::{
cmp::Ordering,
collections::{BTreeSet, HashMap},
env,
fmt::Display,
fs,
ops::Range,
path::{Path, PathBuf},
sync::{LazyLock, Mutex, OnceLock},
};
// Source files being rewritten because `UPDATE_COUNT` matched them, keyed by the `file!()` path
// recorded in the invocation. Each entry holds that file's original text, its rewritten text, and
// the updates applied so far (see `FileUpdates`).
static FILES: LazyLock<Mutex<HashMap<&'static str, FileUpdates>>> = LazyLock::new(Default::default);
// Directory against which relative `file!()` paths are resolved. It is computed lazily on first use
// in `FileUpdates::new` as the outermost ancestor of `CARGO_MANIFEST_DIR` containing a `Cargo.toml`.
static WORKSPACE_ROOT: OnceLock<PathBuf> = OnceLock::new();
/// Creates an `UpdatableCount` that expects exactly the given numbers of constants, public values,
/// private values, and constraints, in that order.
///
/// Writing `<=` before the first argument makes the constant count an upper bound while the other
/// three stay exact. The invocation's `file!()`, `line!()`, and `column!()` are recorded so the
/// arguments can be rewritten in place when `UPDATE_COUNT` is set.
#[macro_export]
macro_rules! count_is {
    ($constants:literal, $public:literal, $private:literal, $constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::Exact($constants),
            public: $crate::Measurement::Exact($public),
            private: $crate::Measurement::Exact($private),
            constraints: $crate::Measurement::Exact($constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
    (<= $constants:literal, $public:literal, $private:literal, $constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::UpperBound($constants),
            public: $crate::Measurement::Exact($public),
            private: $crate::Measurement::Exact($private),
            constraints: $crate::Measurement::Exact($constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
/// Creates an `UpdatableCount` in which all four counts (constants, public values, private values,
/// constraints, in that order) are upper bounds rather than exact values.
///
/// Like `count_is!`, it records the invocation's `file!()`, `line!()`, and `column!()` so the
/// arguments can be rewritten in place when `UPDATE_COUNT` is set.
#[macro_export]
macro_rules! count_less_than {
    ($constants:literal, $public:literal, $private:literal, $constraints:literal) => {
        $crate::UpdatableCount {
            constant: $crate::Measurement::UpperBound($constants),
            public: $crate::Measurement::UpperBound($public),
            private: $crate::Measurement::UpperBound($private),
            constraints: $crate::Measurement::UpperBound($constraints),
            file: file!(),
            line: line!(),
            column: column!(),
        }
    };
}
/// An expected set of counts (constants, public values, private values, constraints) together with
/// the source location of the `count_is!` / `count_less_than!` invocation that created it, so the
/// invocation's arguments can be rewritten in place when the observed counts differ.
#[derive(Copy, Clone, Debug)]
pub struct UpdatableCount {
/// Expected number of constants.
pub constant: Constant,
/// Expected number of public values.
pub public: Public,
/// Expected number of private values.
pub private: Private,
/// Expected number of constraints.
pub constraints: Constraints,
/// Source file of the invocation, as reported by `file!()`.
#[doc(hidden)]
pub file: &'static str,
/// 1-based line of the invocation, as reported by `line!()`.
#[doc(hidden)]
pub line: u32,
/// 1-based column of the invocation, as reported by `column!()`.
#[doc(hidden)]
pub column: u32,
}
impl Display for UpdatableCount {
    /// Writes the four expected measurements on a single line. The order is constants, public,
    /// private, then constraints; the source location is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { constant, public, private, constraints, .. } = self;
        write!(
            f,
            "Constants: {}, Public: {}, Private: {}, Constraints: {}",
            constant, public, private, constraints
        )
    }
}
impl UpdatableCount {
pub fn matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) -> bool {
self.constant.matches(num_constants)
&& self.public.matches(num_public)
&& self.private.matches(num_private)
&& self.constraints.matches(num_constraints)
}
pub fn assert_matches(&self, num_constants: u64, num_public: u64, num_private: u64, num_constraints: u64) {
if !self.matches(num_constants, num_public, num_private, num_constraints) {
let mut files = FILES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
match env::var("UPDATE_COUNT") {
Ok(query_string) if self.file.contains(&query_string) => {
files.entry(self.file).or_insert_with(|| FileUpdates::new(self)).update_count(
self,
num_constants,
num_public,
num_private,
num_constraints,
);
}
_ => {
println!(
"\n
\x1b[1m\x1b[91merror\x1b[97m: Count does not match\x1b[0m
\x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
\x1b[1mExpected\x1b[0m:
----
{}
----
\x1b[1mActual\x1b[0m:
----
Constants: {}, Public: {}, Private: {}, Constraints: {}
----
",
self.file,
self.line,
self.column,
self,
num_constants,
num_public,
num_private,
num_constraints,
);
std::panic::resume_unwind(Box::new(()));
}
}
}
}
fn locate(&self, file: &str) -> Range<usize> {
let mut line_start = 0;
let mut starting_index = None;
let mut ending_index = None;
for (i, line) in LinesWithEnds::from(file).enumerate() {
if i == self.line as usize - 1 {
let mut argument_character_indices = line.char_indices().skip((self.column - 1).try_into().unwrap())
.skip_while(|&(_, c)| c != '!') .skip(1) .skip_while(|(_, c)| c.is_whitespace());
starting_index = Some(
line_start
+ argument_character_indices
.next()
.expect("Could not find the beginning of the macro invocation.")
.0,
);
}
if starting_index.is_some() {
match line.char_indices().find(|&(_, c)| c == ')') {
None => (), Some((offset, _)) => {
ending_index = Some(line_start + offset + 1);
break;
}
}
}
line_start += line.len();
}
Range {
start: starting_index.expect("Could not find the beginning of the macro invocation."),
end: ending_index.expect("Could not find the ending of the macro invocation."),
}
}
pub fn difference_between(&self, other: &Self) -> (i64, i64, i64, i64) {
let difference = |self_measurement, other_measurement| match (self_measurement, other_measurement) {
(Measurement::Exact(self_value), Measurement::Exact(other_value))
| (Measurement::UpperBound(self_value), Measurement::UpperBound(other_value)) => {
(self_value as i64) - (other_value as i64)
}
_ => panic!(
"Cannot compute difference for `Measurement::Range` or if both measurements are of different types."
),
};
(
difference(self.constant, other.constant),
difference(self.public, other.public),
difference(self.private, other.private),
difference(self.constraints, other.constraints),
)
}
fn dummy(constant: Constant, public: Public, private: Private, constraints: Constraints) -> Self {
Self {
constant,
public,
private,
constraints,
file: Default::default(),
line: Default::default(),
column: Default::default(),
}
}
fn as_argument_string(&self) -> String {
let generate_arg = |measurement| match measurement {
Measurement::Exact(value) => value,
Measurement::UpperBound(bound) => bound,
Measurement::Range(..) => panic!(
"Cannot create an argument string from an `UpdatableCount` that contains a `Measurement::Range`."
),
};
format!(
"({}, {}, {}, {})",
generate_arg(self.constant),
generate_arg(self.public),
generate_arg(self.private),
generate_arg(self.constraints)
)
}
}
struct FileUpdates {
absolute_path: PathBuf,
original_text: String,
modified_text: String,
updates: BTreeSet<Update>,
}
impl FileUpdates {
fn new(count: &UpdatableCount) -> Self {
let path = Path::new(count.file);
let absolute_path = match path.is_absolute() {
true => path.to_owned(),
false => {
WORKSPACE_ROOT
.get_or_init(|| {
Path::new(&env!("CARGO_MANIFEST_DIR"))
.ancestors()
.filter(|it| it.join("Cargo.toml").exists())
.last()
.unwrap()
.to_path_buf()
})
.join(path)
}
};
let original_text = fs::read_to_string(&absolute_path).unwrap();
let modified_text = original_text.clone();
let updates = Default::default();
Self { absolute_path, original_text, modified_text, updates }
}
fn update_count(
&mut self,
count: &UpdatableCount,
num_constants: u64,
num_public: u64,
num_private: u64,
num_constraints: u64,
) {
let range = count.locate(&self.original_text);
let mut new_range = range.clone();
let mut update_with_same_start = None;
for previous_update in &self.updates {
let amount_deleted = previous_update.end - previous_update.start;
let amount_inserted = previous_update.argument_string.len();
match previous_update.start.cmp(&range.start) {
Ordering::Less => {
new_range.start = new_range.start - amount_deleted + amount_inserted;
new_range.end = new_range.end - amount_deleted + amount_inserted;
}
Ordering::Equal => {
new_range.end = new_range.end - amount_deleted + amount_inserted;
update_with_same_start = Some(previous_update);
}
Ordering::Greater => {
break;
}
}
}
if let Some(update) = update_with_same_start {
if update.count.matches(num_constants, num_public, num_private, num_constraints) {
return;
}
}
let new_update = match update_with_same_start {
None => Update::new(&range, count, num_constants, num_public, num_private, num_constraints),
Some(update) => Update::new(&range, &update.count, num_constants, num_public, num_private, num_constraints),
};
self.modified_text.replace_range(new_range, &new_update.argument_string);
let difference = new_update.count.difference_between(count);
println!(
"\n
\x1b[1m\x1b[33mwarning\x1b[97m: Updated count\x1b[0m
\x1b[1m\x1b[34m-->\x1b[0m {}:{}:{}
\x1b[1mOriginal count\x1b[0m:
----
{}
----
\x1b[1mUpdated count\x1b[0m:
----
{}
----
\x1b[1mDifference between updated and original\x1b[0m:
----
Constants: {}, Public: {}, Private: {}, Constraints: {}
----
",
count.file,
count.line,
count.column,
count,
new_update.count,
difference.0,
difference.1,
difference.2,
difference.3
);
self.updates.replace(new_update);
fs::write(&self.absolute_path, &self.modified_text).unwrap()
}
}
/// A rewrite of one macro invocation's argument list, positioned in the file's original text.
#[derive(Debug)]
struct Update {
    /// Byte offset in the original text where the argument list starts (normally its `(`).
    start: usize,
    /// Byte offset in the original text just past the argument list's closing `)`.
    end: usize,
    /// The count that the replacement arguments encode.
    count: UpdatableCount,
    /// Replacement text for `start..end`, including the parentheses.
    argument_string: String,
}
impl Update {
    /// Builds the update for the invocation at `range`, deriving each new measurement from the one
    /// in `old_count`. Exact counts take the observed value. Upper bounds are raised to the observed
    /// value when it exceeds them and kept otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `old_count` contains a `Measurement::Range`.
    fn new(
        range: &Range<usize>,
        old_count: &UpdatableCount,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
    ) -> Self {
        fn next_measurement(previous: Measurement<u64>, observed: u64) -> Measurement<u64> {
            match previous {
                Measurement::Exact(_) => Measurement::Exact(observed),
                Measurement::UpperBound(bound) => Measurement::UpperBound(observed.max(bound)),
                Measurement::Range(..) => panic!("UpdatableCount does not support ranges."),
            }
        }

        let count = UpdatableCount::dummy(
            next_measurement(old_count.constant, num_constants),
            next_measurement(old_count.public, num_public),
            next_measurement(old_count.private, num_private),
            next_measurement(old_count.constraints, num_constraints),
        );
        let argument_string = count.as_argument_string();
        Self { start: range.start, end: range.end, count, argument_string }
    }
}
// `Update`s are identified and ordered solely by their start offset. A `BTreeSet<Update>` therefore
// holds at most one update per invocation, and `replace` swaps in a newer update for the same
// invocation.
impl PartialEq for Update {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other).is_eq()
    }
}
impl Eq for Update {}
impl PartialOrd for Update {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for Update {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.start, &other.start)
    }
}
/// Iterator over the lines of a string that, unlike `str::lines`, keeps each line's terminator.
/// Any `'\n'`, and a `'\r'` before it, stays attached to the line. Summing the lengths of the
/// yielded slices therefore gives byte offsets into the original text.
struct LinesWithEnds<'a> {
    text: &'a str,
}
impl<'a> Iterator for LinesWithEnds<'a> {
    type Item = &'a str;
    fn next(&mut self) -> Option<&'a str> {
        if self.text.is_empty() {
            return None;
        }
        // Split just past the next newline, or take the rest of the text if none remains.
        let line_len = match self.text.find('\n') {
            Some(newline) => newline + 1,
            None => self.text.len(),
        };
        let (line, remainder) = self.text.split_at(line_len);
        self.text = remainder;
        Some(line)
    }
}
impl<'a> From<&'a str> for LinesWithEnds<'a> {
    fn from(text: &'a str) -> Self {
        Self { text }
    }
}
#[cfg(test)]
mod test {
    use serial_test::serial;
    use std::env;

    /// Sets `UPDATE_COUNT` for as long as the guard is alive and removes it on drop. Without this,
    /// a panicking test (including the `#[should_panic]` ones) would leak the variable into every
    /// test that runs afterwards in this process.
    struct UpdateCountVar;

    impl UpdateCountVar {
        fn set(value: &str) -> Self {
            env::set_var("UPDATE_COUNT", value);
            Self
        }
    }

    impl Drop for UpdateCountVar {
        fn drop(&mut self) {
            env::remove_var("UPDATE_COUNT");
        }
    }

    #[test]
    fn check_position() {
        // Capture `line!()` and `column!()` in the same expression as the invocation. The test then
        // does not depend on where in this file it happens to sit.
        let (count, line, column) = (count_is!(0, 0, 0, 0), line!(), column!());
        // `column!()` starts exactly this many characters to the right of `count_is!`.
        let distance = "count_is!(0, 0, 0, 0), line!(), ".len() as u32;
        assert_eq!(count.file, "circuit/environment/src/helpers/updatable_count.rs");
        assert_eq!(count.line, line);
        assert_eq!(count.column + distance, column);
    }

    #[test]
    #[serial]
    fn check_count_passes() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 1;
        let num_public = 2;
        let num_private = 3;
        let num_inputs = 4;
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_fails() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    #[should_panic]
    fn check_count_does_not_update_if_env_var_is_not_set_correctly() {
        let count = count_is!(1, 2, 3, 4);
        let num_constants = 5;
        let num_public = 6;
        let num_private = 7;
        let num_inputs = 8;
        // A value that cannot be a substring of any real source path. A short value like "1" would
        // also enable rewriting for concurrently running tests in files whose path contains it.
        let _update_count = UpdateCountVar::set("updatable_count.rs.does-not-exist");
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly() {
        let count = count_is!(11, 12, 13, 14);
        let num_constants = 11;
        let num_public = 12;
        let num_private = 13;
        let num_inputs = 14;
        let _update_count = UpdateCountVar::set("updatable_count.rs");
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_updates_correctly_multiple_times() {
        let count = count_is!(17, 18, 19, 20);
        let _update_count = UpdateCountVar::set("updatable_count.rs");
        let (num_constants, num_public, num_private, num_inputs) = (5, 6, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
        let (num_constants, num_public, num_private, num_inputs) = (9, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
        let (num_constants, num_public, num_private, num_inputs) = (13, 14, 15, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
        let (num_constants, num_public, num_private, num_inputs) = (17, 18, 19, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }

    #[test]
    #[serial]
    fn check_count_less_than_selects_maximum() {
        let count = count_less_than!(17, 18, 19, 20);
        let _update_count = UpdateCountVar::set("updatable_count.rs");
        let (num_constants, num_public, num_private, num_inputs) = (5, 18, 7, 8);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
        let (num_constants, num_public, num_private, num_inputs) = (17, 10, 11, 12);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
        let (num_constants, num_public, num_private, num_inputs) = (13, 6, 19, 16);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
        let (num_constants, num_public, num_private, num_inputs) = (9, 18, 15, 20);
        count.assert_matches(num_constants, num_public, num_private, num_inputs);
    }
}