use std::{
any::Any,
ops::Range,
path::{Path, PathBuf},
};
use anyhow::Context;
use crate::streaming::{self, Executor, Stream};
use super::{parsing, validate};
/// An ordered sequence of [`Stage`]s that can be executed against a [`Stream`].
///
/// Construct one with [`RunBook::load`]. Construction dry-runs every stage through
/// a validator to compute [`RunBook::max_points`] and [`RunBook::max_tag`].
#[derive(Debug)]
pub struct RunBook {
    /// The stages, in execution order.
    stages: Vec<Stage>,
    /// The larger of the bound supplied at construction and the validator's
    /// reported maximum active count.
    max_points: usize,
    /// The maximum tag reported by the validator, if any.
    max_tag: Option<usize>,
}
impl RunBook {
    /// Load the runbook for `dataset` from the file at `path`.
    ///
    /// Parsing is delegated to `parsing::load`. `groundtruth` is handed to the
    /// parser so it can locate groundtruth files (see [`FindGroundtruth`]).
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `parsing::load`.
    pub fn load(
        path: &Path,
        dataset: &str,
        groundtruth: &mut dyn FindGroundtruth,
    ) -> anyhow::Result<Self> {
        parsing::load(path, dataset, groundtruth)
    }

    /// Build a runbook from already-parsed `stages`.
    ///
    /// Every stage is dry-run through a `validate::Validate`. Afterwards:
    /// - `max_points` is raised to the validator's `max_active()` if that is larger.
    /// - `max_tag` is taken from the validator's `max_tag()`.
    ///
    /// # Errors
    ///
    /// Returns an error if the validator rejects any stage.
    pub(super) fn new(stages: Vec<Stage>, max_points: usize) -> anyhow::Result<Self> {
        let mut runbook = Self {
            stages,
            max_points,
            max_tag: None,
        };

        let mut checker = validate::Validate::new();
        runbook.run_with(&mut checker, |_| Ok(()))?;

        let observed_peak = checker.max_active();
        runbook.max_points = std::cmp::max(runbook.max_points, observed_peak);
        runbook.max_tag = checker.max_tag();
        Ok(runbook)
    }

    /// The point bound: the value supplied at construction, raised to the
    /// validator's maximum active count when that is larger.
    pub fn max_points(&self) -> usize {
        self.max_points
    }

    /// The maximum tag reported by the validator during construction, if any.
    pub fn max_tag(&self) -> Option<usize> {
        self.max_tag
    }

    /// The number of stages.
    pub fn len(&self) -> usize {
        self.stages.len()
    }

    /// Returns `true` if the runbook contains no stages.
    pub fn is_empty(&self) -> bool {
        self.stages.is_empty()
    }

    /// The stages, in execution order.
    #[cfg(test)]
    pub(super) fn stages(&self) -> &[Stage] {
        &self.stages
    }

    /// Type-erased stage loop shared by every `Executor::run_with` instantiation.
    ///
    /// For each stage, in order:
    /// 1. If `stream.needs_maintenance()` returns `true`, run `stream.maintain(())`
    ///    and pass its output to `collect`.
    /// 2. Run the stage's operation and pass its output to `collect`.
    ///
    /// Every error is annotated with the stage's (zero-based) index and the total
    /// stage count.
    fn run_with_internal(
        &self,
        stream: &mut dyn streaming::Stream<Args, Output = Box<dyn Any>>,
        collect: &mut dyn FnMut(Box<dyn Any>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
        // Error context of the form "on stage <index> of <total>".
        #[derive(Clone, Copy)]
        struct OnStage(usize, usize);

        impl std::fmt::Display for OnStage {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, "on stage {} of {}", self.0, self.1)
            }
        }

        let total = self.stages.len();
        for (index, stage) in self.stages.iter().enumerate() {
            let context = OnStage(index, total);

            if stream.needs_maintenance() {
                let maintenance_output = stream.maintain(()).context(context)?;
                collect(maintenance_output).context(context)?;
            }

            let output = match stage {
                Stage::Search { groundtruth } => stream.search(Search { groundtruth }),
                Stage::Insert {
                    dataset_offsets_and_ids: range,
                } => stream.insert(Insert {
                    offsets: range.clone(),
                    ids: range.clone(),
                }),
                Stage::Replace {
                    dataset_offsets,
                    ids,
                } => stream.replace(Replace {
                    offsets: dataset_offsets.clone(),
                    ids: ids.clone(),
                }),
                Stage::Delete { ids } => stream.delete(Delete { ids: ids.clone() }),
            }
            .context(context)?;

            collect(output).context(context)?;
        }
        Ok(())
    }
}
/// A single step of a [`RunBook`].
///
/// Each variant is forwarded to the stream as the argument struct of the same
/// name ([`Search`], [`Insert`], [`Replace`], [`Delete`]).
///
/// `Eq` is derived alongside `PartialEq` because every field type is `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Stage {
    /// Run a search. `groundtruth` is the groundtruth file path passed to the
    /// stream in [`Search::groundtruth`].
    Search {
        groundtruth: PathBuf,
    },
    /// Insert the dataset rows in `dataset_offsets_and_ids`.
    ///
    /// The same range is used for both the dataset offsets and the assigned ids.
    Insert {
        dataset_offsets_and_ids: Range<usize>,
    },
    /// Replace the points identified by `ids` using the dataset rows in
    /// `dataset_offsets` (forwarded as [`Replace::offsets`] / [`Replace::ids`]).
    Replace {
        dataset_offsets: Range<usize>,
        ids: Range<usize>,
    },
    /// Delete the points identified by `ids` (forwarded as [`Delete::ids`]).
    Delete {
        ids: Range<usize>,
    },
}
/// Arguments for a search stage.
///
/// Borrows the groundtruth path from the [`Stage::Search`] that produced it. The
/// struct only holds a reference, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Search<'a> {
    /// Path of the groundtruth file for this search.
    pub groundtruth: &'a Path,
}
/// Arguments for an insert stage.
///
/// When built from a [`Stage::Insert`], `offsets` and `ids` are the same range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Insert {
    /// Dataset offsets of the rows to insert.
    pub offsets: Range<usize>,
    /// Ids to assign to the inserted rows.
    pub ids: Range<usize>,
}
/// Arguments for a replace stage, built from a [`Stage::Replace`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Replace {
    /// Dataset offsets of the replacement rows (from `Stage::Replace::dataset_offsets`).
    pub offsets: Range<usize>,
    /// Ids of the points being replaced (from `Stage::Replace::ids`).
    pub ids: Range<usize>,
}
/// Arguments for a delete stage, built from a [`Stage::Delete`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Delete {
    /// Ids of the points to delete.
    pub ids: Range<usize>,
}
/// Marker type binding this module's stage argument structs ([`Search`],
/// [`Insert`], [`Replace`], [`Delete`]) to [`streaming::Arguments`].
#[derive(Debug, Clone, Copy)]
pub struct Args;

impl streaming::Arguments for Args {
    type Search<'a> = Search<'a>;
    type Insert<'a> = Insert;
    type Replace<'a> = Replace;
    type Delete<'a> = Delete;
    // Maintenance carries no arguments.
    type Maintain<'a> = ();
}
impl streaming::Executor for RunBook {
type Args = Args;
fn run_with<S, F, O>(&mut self, stream: &mut S, mut collect: F) -> anyhow::Result<()>
where
S: Stream<Args, Output = O>,
O: 'static,
F: FnMut(O) -> anyhow::Result<()>,
{
self.run_with_internal(&mut streaming::AnyStream::new(stream), &mut |any| {
let typed = *any
.downcast::<S::Output>()
.expect("the dynamic type should be configured correctly");
collect(typed)
})
}
}
/// Resolves the groundtruth file for a search stage of a runbook.
pub trait FindGroundtruth {
    /// Return the path of the groundtruth file for stage number `stage`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when no groundtruth file can be
    /// determined for `stage`.
    fn find_groundtruth(&mut self, stage: usize) -> anyhow::Result<PathBuf>;
}
/// A [`FindGroundtruth`] that looks up groundtruth files in one directory.
///
/// The directory listing is captured once, when the value is constructed; files
/// created afterwards are not seen.
#[derive(Debug)]
pub struct ScanDirectory {
    /// The scanned directory; matching file names are joined onto it.
    directory: PathBuf,
    /// Names of the files found in `directory` when it was scanned,
    /// converted lossily to UTF-8.
    files: Vec<String>,
}
impl ScanDirectory {
    /// Scan `directory` and remember the names of the files it contains.
    ///
    /// Recorded entries:
    /// - regular files;
    /// - symbolic links that resolve to regular files.
    ///
    /// Entries that cannot be read, or whose type cannot be determined, are
    /// skipped. File names are converted lossily to UTF-8.
    ///
    /// # Errors
    ///
    /// Returns an error if `directory` itself cannot be read.
    pub fn new(directory: impl Into<PathBuf>) -> anyhow::Result<Self> {
        Self::new_(directory.into())
    }

    /// Non-generic body of [`ScanDirectory::new`], compiled once rather than once
    /// per `Into<PathBuf>` argument type.
    fn new_(directory: PathBuf) -> anyhow::Result<Self> {
        let read_dir = std::fs::read_dir(&directory).with_context(|| {
            format!(
                "while trying to read the contents of {}",
                directory.display()
            )
        })?;
        let files = read_dir
            // Best effort: an unreadable entry is skipped rather than failing the scan.
            .filter_map(Result::ok)
            .filter(Self::is_file_entry)
            .map(|entry| entry.file_name().to_string_lossy().into_owned())
            .collect();
        Ok(Self { directory, files })
    }

    /// Returns `true` if `entry` is a regular file or a symlink that resolves to one.
    ///
    /// `DirEntry::file_type` does not follow symlinks. Without the explicit
    /// `Path::is_file` check, a symlinked groundtruth file would be ignored and
    /// later reported as missing.
    fn is_file_entry(entry: &std::fs::DirEntry) -> bool {
        match entry.file_type() {
            Ok(file_type) if file_type.is_file() => true,
            Ok(file_type) if file_type.is_symlink() => entry.path().is_file(),
            _ => false,
        }
    }
}
impl FindGroundtruth for ScanDirectory {
    /// Find the unique scanned file named `step{stage}.gt` followed by zero or more
    /// ASCII digits (for example `step3.gt` or `step3.gt100`).
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - no file matches;
    /// - more than one file matches (the error lists every match).
    fn find_groundtruth(&mut self, stage: usize) -> anyhow::Result<PathBuf> {
        let prefix = format!("step{}.gt", stage);
        let is_candidate = |name: &str| {
            name.strip_prefix(prefix.as_str())
                .is_some_and(|digits| digits.chars().all(|c| c.is_ascii_digit()))
        };

        let candidates: Vec<&str> = self
            .files
            .iter()
            .map(String::as_str)
            .filter(|name| is_candidate(*name))
            .collect();

        match candidates.as_slice() {
            [] => Err(anyhow::anyhow!(
                "No groundtruth found for step {} in \"{}\", expected pattern: \"step{}.gt[0-9]*\"",
                stage,
                self.directory.display(),
                stage,
            )),
            [only] => Ok(self.directory.join(*only)),
            several => Err(anyhow::anyhow!(
                "Multiple groundtruth files found for step {} in \"{}\": {:?}",
                stage,
                self.directory.display(),
                several,
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use tempfile::TempDir;
    use crate::streaming::Executor;

    /// A `Stream` that checks every call against an expected list of stages.
    ///
    /// Behaviour:
    /// - Each matching stage call returns `Some(index)` of the stage it handled.
    /// - `maintain` returns `None`.
    /// - A call of the wrong kind returns an error.
    /// - Mismatched arguments fail an `assert_eq!`.
    struct MockStream {
        // Stages the stream expects, in order.
        stages: Vec<Stage>,
        // Index of the next expected stage.
        current_stage: usize,
        // Set by `needs_maintenance`, cleared by `maintain`.
        asked_for_maintenance: bool,
    }

    impl MockStream {
        fn new(stages: Vec<Stage>) -> Self {
            Self {
                stages,
                current_stage: 0,
                asked_for_maintenance: false,
            }
        }

        // Advance to the next expected stage and return the index just handled.
        fn increment(&mut self) -> usize {
            let output = self.current_stage;
            self.current_stage += 1;
            output
        }

        // The stage the next call must match; indexing panics once all stages are used.
        fn current(&self) -> &Stage {
            &self.stages[self.current_stage]
        }
    }

    impl streaming::Stream<Args> for MockStream {
        type Output = Option<usize>;

        fn search(&mut self, args: Search<'_>) -> anyhow::Result<Option<usize>> {
            if let Stage::Search { groundtruth } = self.current() {
                assert_eq!(args.groundtruth, groundtruth.as_path());
                Ok(Some(self.increment()))
            } else {
                Err(anyhow::anyhow!(
                    "Expected Search stage, instead got {:?}",
                    self.current()
                ))
            }
        }

        fn insert(&mut self, args: Insert) -> anyhow::Result<Option<usize>> {
            if let Stage::Insert {
                dataset_offsets_and_ids,
            } = self.current()
            {
                // `RunBook` uses the same range for offsets and ids.
                assert_eq!(&args.offsets, dataset_offsets_and_ids);
                assert_eq!(&args.ids, dataset_offsets_and_ids);
                Ok(Some(self.increment()))
            } else {
                Err(anyhow::anyhow!(
                    "Expected Insert stage, instead got {:?}",
                    self.current()
                ))
            }
        }

        fn replace(&mut self, args: Replace) -> anyhow::Result<Option<usize>> {
            if let Stage::Replace {
                dataset_offsets,
                ids,
            } = self.current()
            {
                assert_eq!(&args.offsets, dataset_offsets);
                assert_eq!(&args.ids, ids);
                Ok(Some(self.increment()))
            } else {
                Err(anyhow::anyhow!(
                    "Expected Replace stage, instead got {:?}",
                    self.current()
                ))
            }
        }

        fn delete(&mut self, args: Delete) -> anyhow::Result<Option<usize>> {
            if let Stage::Delete { ids } = self.current() {
                assert_eq!(&args.ids, ids);
                Ok(Some(self.increment()))
            } else {
                Err(anyhow::anyhow!(
                    "Expected Delete stage, instead got {:?}",
                    self.current()
                ))
            }
        }

        // Only valid after `needs_maintenance` has returned `true`; clears the request.
        // Returns `None` so tests can tell maintenance outputs apart from stage outputs.
        fn maintain(&mut self, _args: ()) -> anyhow::Result<Option<usize>> {
            assert!(
                self.asked_for_maintenance,
                "Stream was not expected to need maintenance"
            );
            self.asked_for_maintenance = false;
            Ok(None)
        }

        // Returns the current flag and then sets it. The first call therefore reports
        // `false`, and later calls report `true` until `maintain` clears the flag.
        // Under `RunBook`'s stage loop this requests maintenance before every other
        // stage, starting with stage 1.
        fn needs_maintenance(&mut self) -> bool {
            let needs = self.asked_for_maintenance;
            self.asked_for_maintenance = true;
            needs
        }
    }

    // Runs a hand-built runbook against `MockStream`. Every stage must execute in
    // order; filtering out maintenance outputs (`None`) leaves the stage indices.
    #[test]
    fn test_runbook() {
        let stages = vec![
            Stage::Insert {
                dataset_offsets_and_ids: 0..100,
            },
            Stage::Search {
                groundtruth: PathBuf::from("gt0"),
            },
            Stage::Replace {
                dataset_offsets: 100..200,
                ids: 0..100,
            },
            Stage::Delete { ids: 50..75 },
            Stage::Search {
                groundtruth: PathBuf::from("gt1"),
            },
        ];
        let mut runbook = RunBook::new(stages.clone(), 1000).unwrap();
        assert_eq!(runbook.len(), stages.len());
        assert!(!runbook.is_empty());
        // `new` keeps the larger of the supplied bound and the validator's peak;
        // here the supplied 1000 wins.
        assert_eq!(runbook.max_points(), 1000);
        let mut stream = MockStream::new(stages.clone());
        let outputs = runbook.run(&mut stream).unwrap();
        let expected_outputs: Vec<usize> = (0..stages.len()).collect();
        let non_maintenance: Vec<_> = outputs.iter().filter_map(|o| *o).collect();
        assert_eq!(non_maintenance, expected_outputs);
    }

    // Loads a YAML runbook from disk and checks the parsed stages, groundtruth
    // resolution via `ScanDirectory`, and the derived limits.
    #[test]
    fn test_load_runbook_from_yaml() {
        use std::io::Write;
        let temp_dir = TempDir::new().unwrap();
        // Groundtruth files for the two search steps (1 and 7).
        File::create(temp_dir.path().join("step1.gt100")).unwrap();
        File::create(temp_dir.path().join("step7.gt100")).unwrap();
        let yaml_content = r#"
test_dataset:
max_pts: 100
gt_url: "http://example.com/groundtruth"
0:
operation: "insert"
start: 0
end: 1000
1:
operation: "search"
2:
operation: "insert"
start: 1000
end: 2000
3:
operation: "delete"
start: 200
end: 400
4:
operation: "replace"
ids_start: 2000
ids_end: 2500
tags_start: 400
tags_end: 900
5:
operation: "insert"
start: 2500
end: 3000
6:
operation: "delete"
start: 500
end: 700
7:
operation: "search"
"#;
        let yaml_path = temp_dir.path().join("runbook.yaml");
        {
            let mut file = File::create(&yaml_path).unwrap();
            file.write_all(yaml_content.as_bytes()).unwrap();
        }
        let mut groundtruth_finder = ScanDirectory::new(temp_dir.path()).unwrap();
        let runbook = RunBook::load(&yaml_path, "test_dataset", &mut groundtruth_finder).unwrap();
        assert_eq!(runbook.len(), 8);
        // Exceeds `max_pts: 100`, so this presumably comes from the validator's peak
        // active count (1000 + 1000 - 200 + 500 after step 5).
        assert_eq!(runbook.max_points(), 2300);
        // Presumably the largest id inserted (step 5 inserts 2500..3000).
        assert_eq!(runbook.max_tag(), Some(2999));
        let stages = runbook.stages();
        assert_eq!(
            stages[0],
            Stage::Insert {
                dataset_offsets_and_ids: 0..1000
            }
        );
        assert!(
            matches!(&stages[1], Stage::Search { groundtruth } if groundtruth.file_name().unwrap() == "step1.gt100")
        );
        assert_eq!(
            stages[2],
            Stage::Insert {
                dataset_offsets_and_ids: 1000..2000
            }
        );
        assert_eq!(stages[3], Stage::Delete { ids: 200..400 });
        // In the YAML, `ids_*` become dataset offsets and `tags_*` become ids.
        assert_eq!(
            stages[4],
            Stage::Replace {
                dataset_offsets: 2000..2500,
                ids: 400..900
            }
        );
        assert_eq!(
            stages[5],
            Stage::Insert {
                dataset_offsets_and_ids: 2500..3000
            }
        );
        assert_eq!(stages[6], Stage::Delete { ids: 500..700 });
        assert!(
            matches!(&stages[7], Stage::Search { groundtruth } if groundtruth.file_name().unwrap() == "step7.gt100")
        );
    }

    // A single `stepN.gt<digits>` file is found and joined onto the directory.
    #[test]
    fn scan_directory_finds_groundtruth_file() {
        let temp_dir = TempDir::new().unwrap();
        File::create(temp_dir.path().join("step0.gt100")).unwrap();
        let mut scanner = ScanDirectory::new(temp_dir.path()).unwrap();
        let result = scanner.find_groundtruth(0).unwrap();
        assert_eq!(result, temp_dir.path().join("step0.gt100"));
    }

    // The digit suffix after `.gt` is optional.
    #[test]
    fn scan_directory_finds_groundtruth_without_suffix_digits() {
        let temp_dir = TempDir::new().unwrap();
        File::create(temp_dir.path().join("step5.gt")).unwrap();
        let mut scanner = ScanDirectory::new(temp_dir.path()).unwrap();
        let result = scanner.find_groundtruth(5).unwrap();
        assert_eq!(result, temp_dir.path().join("step5.gt"));
    }

    // Files that do not match `stepN.gt[0-9]*` yield a "No groundtruth found" error.
    #[test]
    fn scan_directory_errors_when_no_groundtruth_found() {
        let temp_dir = TempDir::new().unwrap();
        File::create(temp_dir.path().join("other_file.bin")).unwrap();
        File::create(temp_dir.path().join("step0.other")).unwrap();
        let mut scanner = ScanDirectory::new(temp_dir.path()).unwrap();
        let err = scanner.find_groundtruth(0).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("No groundtruth found"), "Got: {}", msg);
    }

    // More than one match for the same step is an error rather than an arbitrary pick.
    #[test]
    fn scan_directory_errors_when_multiple_groundtruth_files() {
        let temp_dir = TempDir::new().unwrap();
        File::create(temp_dir.path().join("step0.gt100")).unwrap();
        File::create(temp_dir.path().join("step0.gt200")).unwrap();
        File::create(temp_dir.path().join("step0.gt300")).unwrap();
        let mut scanner = ScanDirectory::new(temp_dir.path()).unwrap();
        let err = scanner.find_groundtruth(0).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Multiple groundtruth files"), "Got: {}", msg);
    }

    // A non-digit suffix does not count as a match, so only `step0.gt100` is found.
    #[test]
    fn scan_directory_ignores_non_digit_suffix() {
        let temp_dir = TempDir::new().unwrap();
        File::create(temp_dir.path().join("step0.gtabc")).unwrap();
        File::create(temp_dir.path().join("step0.gt100")).unwrap();
        let mut scanner = ScanDirectory::new(temp_dir.path()).unwrap();
        let result = scanner.find_groundtruth(0).unwrap();
        assert_eq!(result, temp_dir.path().join("step0.gt100"));
    }

    // Constructing a scanner for a missing directory fails.
    #[test]
    fn scan_directory_errors_on_nonexistent_directory() {
        let _ = ScanDirectory::new("/nonexistent/path/that/does/not/exist").unwrap_err();
    }

    // Step numbers are matched exactly: `step1` does not match `step10.gt`.
    #[test]
    fn scan_directory_handles_different_stage_indices() {
        let temp_dir = TempDir::new().unwrap();
        File::create(temp_dir.path().join("step0.gt")).unwrap();
        File::create(temp_dir.path().join("step5.gt")).unwrap();
        File::create(temp_dir.path().join("step10.gt")).unwrap();
        let mut scanner = ScanDirectory::new(temp_dir.path()).unwrap();
        assert_eq!(
            scanner.find_groundtruth(0).unwrap(),
            temp_dir.path().join("step0.gt")
        );
        assert_eq!(
            scanner.find_groundtruth(5).unwrap(),
            temp_dir.path().join("step5.gt")
        );
        assert_eq!(
            scanner.find_groundtruth(10).unwrap(),
            temp_dir.path().join("step10.gt")
        );
        assert!(scanner.find_groundtruth(1).is_err());
    }
}