use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;
use anyhow::anyhow;
use anyhow::Context;
use arc_interner::ArcIntern;
use byteorder::WriteBytesExt;
use glob::Pattern;
use log::warn;
use serde::{Deserialize, Serialize};
use crate::labelling::State::{Excluded, Labelled, Split};
use crate::tree::{FileTree, FileTree1};
pub fn load_labelling_rules(path: &Path, source_name: &str) -> anyhow::Result<LabellingRules> {
let rule_path = path.join("labelling").join(format!("{}.zst", source_name));
if rule_path.exists() {
let rule_file = File::open(&rule_path)?;
let rule_reader = zstd::stream::read::Decoder::new(rule_file)?;
let buf_reader = BufReader::new(rule_reader);
Ok(LabellingRules::load(buf_reader)?)
} else {
Ok(LabellingRules::default())
}
}
pub fn save_labelling_rules(
path: &Path,
source_name: &str,
rules: &LabellingRules,
) -> anyhow::Result<()> {
let rule_path = path.join("labelling").join(format!("{}.zst", source_name));
if rule_path.exists() {
let backup_rule_path = path.join("labelling").join(format!("{}.zst~", source_name));
std::fs::rename(&rule_path, &backup_rule_path)?;
}
let rule_file = File::create(rule_path)?;
let mut zstd_writer = zstd::stream::write::Encoder::new(rule_file, 18)?;
rules.save(&mut zstd_writer)?;
zstd_writer.finish()?; Ok(())
}
/// The name of a label, held as an interned string (`ArcIntern<String>`).
///
/// Use [`str_to_label`] to build one from a string slice.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
#[derive(Serialize, Deserialize)]
pub struct Label(pub ArcIntern<String>);
/// The labelling outcome assigned to a path.
///
/// In the text rule format handled by `LabellingRules::load`/`save`, `Split` is
/// written as `?`, `Excluded` as `!`, and `Labelled` as the label text itself.
///
/// Variant order is significant for the derived `PartialOrd` and for serde's
/// variant indices, so do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
#[derive(Serialize, Deserialize)]
pub enum State {
    /// The path carries this label. Descendants without a rule of their own inherit it.
    Labelled(Label),
    /// The path is split. Descendants do not inherit from it and are labelled individually.
    Split,
    /// The path is marked excluded. Descendants without a rule of their own inherit this.
    Excluded,
}
impl State {
pub fn should_inherit(&self) -> bool {
match self {
Labelled(_) => true,
Split => false,
Excluded => true,
}
}
}
/// A rule that assigns `outcome` to every path matching a glob.
#[derive(Debug, Clone)]
pub struct GlobRule {
    /// Source text of the glob. This is what gets saved, so it must correspond to `glob`.
    pub pattern: String,
    /// The compiled form of `pattern`, used for matching.
    pub glob: Pattern,
    /// State assigned to paths that match.
    pub outcome: State,
}
/// The complete set of labelling rules for one source.
///
/// `apply` consults `position_based_rules` first (exact path match), then tries
/// `glob_based_rules` in order and uses the first glob that matches.
#[derive(Default, Debug, Clone)]
pub struct LabellingRules {
    /// Rules keyed by exact path.
    pub position_based_rules: HashMap<String, State>,
    /// Glob rules. Their order matters, because the first match wins.
    pub glob_based_rules: Vec<GlobRule>,
}
impl LabellingRules {
pub fn load<R: BufRead>(mut input: R) -> anyhow::Result<Self> {
let mut result = LabellingRules {
position_based_rules: Default::default(),
glob_based_rules: Default::default(),
};
let mut str = String::new();
loop {
str.clear();
let line_len = input.read_line(&mut str)?;
if line_len == 0 {
break;
}
if &str == "---\n" {
break;
}
let pieces: Vec<&str> = str.trim_end_matches('\n').split('\t').collect();
if pieces.len() == 2 {
match pieces[1] {
"?" => {
result
.position_based_rules
.insert(pieces[0].to_owned(), Split);
}
"!" => {
result
.position_based_rules
.insert(pieces[0].to_owned(), Excluded);
}
label_str => {
result.position_based_rules.insert(
pieces[0].to_owned(),
Labelled(Label(ArcIntern::new(label_str.to_owned()))),
);
}
}
} else {
warn!("not 2 pieces: {:?}", str);
}
}
loop {
str.clear();
let line_len = input.read_line(&mut str)?;
if line_len == 0 {
break;
}
let pieces: Vec<&str> = str.trim().split('\t').collect();
if pieces.len() == 2 {
let outcome = match pieces[1] {
"?" => Split,
"!" => Excluded,
label_str => Labelled(Label(ArcIntern::new(label_str.to_owned()))),
};
let pattern = pieces[0].to_owned();
let glob = Pattern::new(&pattern)
.with_context(|| anyhow!("Whilst compiling glob: {:?}", pattern))?;
result.glob_based_rules.push(GlobRule {
pattern,
glob,
outcome,
});
} else {
warn!("not 2 pieces: {:?}", str);
}
}
Ok(result)
}
pub fn save<W: Write>(&self, mut output: W) -> anyhow::Result<()> {
for (path, rule) in self.position_based_rules.iter() {
output.write_all(path.as_bytes())?;
output.write_u8('\t' as u8)?;
match rule {
Labelled(label) => {
output.write_all(label.0.as_bytes())?;
}
Split => {
output.write_u8('?' as u8)?;
}
Excluded => {
output.write_u8('!' as u8)?;
}
}
output.write_u8('\n' as u8)?;
}
output.write_all("---\n".as_bytes())?;
for glob_rule in self.glob_based_rules.iter() {
output.write_all(glob_rule.pattern.as_bytes())?;
output.write_u8('\t' as u8)?;
match &glob_rule.outcome {
Labelled(label) => {
output.write_all(label.0.as_bytes())?;
}
Split => {
output.write_u8('?' as u8)?;
}
Excluded => {
output.write_u8('!' as u8)?;
}
}
output.write_u8('\n' as u8)?;
}
output.flush()?;
Ok(())
}
pub fn apply(&self, path: &str) -> Option<State> {
if let Some(rule_state) = self.position_based_rules.get(path) {
return Some(rule_state.clone());
}
for glob_rule in self.glob_based_rules.iter() {
if glob_rule.glob.matches(path) {
return Some(glob_rule.outcome.clone());
}
}
None
}
}
pub fn label_node(
path: String,
current_state: Option<State>,
node: &mut FileTree1<Option<State>>,
labels: &Vec<Label>,
rules: &LabellingRules,
) -> anyhow::Result<()> {
let mut next_state = current_state;
if let Some(rule_state) = rules.apply(&path) {
next_state = Some(rule_state.clone());
} else if !next_state
.as_ref()
.map(|s| s.should_inherit())
.unwrap_or(false)
{
next_state = None;
}
match node {
FileTree::NormalFile { meta, .. } => {
*meta = next_state;
}
FileTree::Directory { meta, children, .. } => {
*meta = next_state.clone();
for (child_name, child) in children.iter_mut() {
let child_path = format!("{}/{}", path, child_name);
label_node(child_path, next_state.clone(), child, labels, rules)?;
}
}
FileTree::SymbolicLink { meta, .. } => {
*meta = next_state;
}
FileTree::Other(_) => {
panic!("Other() nodes shouldn't be present here.");
}
}
Ok(())
}
/// Builds an interned [`Label`] from anything that can be viewed as a string slice.
pub fn str_to_label<I: AsRef<str>>(input: I) -> Label {
    let text: &str = input.as_ref();
    Label(ArcIntern::new(String::from(text)))
}