use std::{borrow::Cow, fmt};
use crate::detection::{
Store,
detect::Match,
license::{LicenseType, TextData},
};
/// A license from the store that a scan matched against.
///
/// `Debug` for this type is written by hand and omits `data`.
#[derive(Clone, Copy)]
pub struct IdentifiedLicense<'a> {
    /// Name the license is registered under in the store.
    pub name: &'a str,
    /// License type copied from the store's match result.
    pub kind: LicenseType,
    /// The stored license text the input was compared against.
    pub data: &'a TextData,
}
/// Prints only the license `name` and `kind`; the stored license text (`data`)
/// is intentionally left out of the debug output.
impl fmt::Debug for IdentifiedLicense<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("IdentifiedLicense");
        out.field("name", &self.name);
        out.field("kind", &self.kind);
        out.finish()
    }
}
/// Outcome of a single [`Scanner::scan`] call.
#[derive(Debug)]
pub struct ScanResult<'a> {
    /// Score of the whole-text analysis (the top-down strategy always reports 0.0).
    pub score: f32,
    /// The whole-text match, present only when it cleared the confidence threshold.
    pub license: Option<IdentifiedLicense<'a>>,
    /// Licenses located inside sub-ranges of the text.
    pub containing: Vec<ContainedResult<'a>>,
}
/// A license found within a sub-range of the scanned text.
#[derive(Clone, Copy, Debug)]
pub struct ContainedResult<'a> {
    /// Score of the match after its bounds were optimized.
    pub score: f32,
    /// Which license matched.
    pub license: IdentifiedLicense<'a>,
    /// Line bounds `(start, end)` of the match, as reported by `lines_view()`.
    pub line_range: (usize, usize),
}
/// Configurable license scanner backed by a license [`Store`].
///
/// Build with `Scanner::new` / `Scanner::with_scan_mode`, adjust with the
/// builder methods, then call `scan`.
pub struct Scanner<'a> {
    /// Licenses that input text is compared against.
    store: &'a Store,
    /// Which scanning strategy `scan` dispatches to.
    mode: ScanMode,
    /// Minimum score for a match to be reported.
    confidence_threshold: f32,
    /// Whole-text score above which elimination mode skips the deeper search.
    shallow_limit: f32,
    /// Whether elimination mode searches for contained licenses.
    optimize: bool,
    /// Maximum number of narrowing passes in elimination mode.
    max_passes: u16,
}
/// Strategy used by `Scanner::scan` to locate licenses in a text.
///
/// Derives `Debug` so scanner configuration can be logged and inspected, and
/// `Copy`/`PartialEq` since it is a small plain-data configuration value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanMode {
    /// Analyze the whole text, then (when optimization is enabled) repeatedly
    /// extract the best-matching region and re-analyze the remainder.
    Elimination,
    /// Slide a line window down the text, collecting each license found.
    TopDown {
        /// Number of lines the window's start and end advance per step
        /// (the `top_down` constructor uses 5).
        step_size: usize,
    },
}
impl ScanMode {
#[inline]
pub fn top_down() -> Self {
Self::TopDown { step_size: 5 }
}
}
impl<'a> Scanner<'a> {
#[inline]
pub fn new(store: &'a Store) -> Self {
Self::with_scan_mode(store, ScanMode::Elimination)
}
#[inline]
pub fn with_scan_mode(store: &'a Store, mode: ScanMode) -> Self {
Self {
store,
mode,
confidence_threshold: 0.9,
shallow_limit: 0.99,
optimize: false,
max_passes: 10,
}
}
}
impl Scanner<'_> {
    /// Sets the minimum score a match must reach to be reported.
    ///
    /// In elimination mode the whole-text match is kept only when its score is
    /// strictly greater than this value. In both modes, a contained match whose
    /// optimized score is below this value is discarded.
    pub fn confidence_threshold(mut self, confidence_threshold: f32) -> Self {
        self.confidence_threshold = confidence_threshold;
        self
    }

    /// Sets the score above which elimination mode stops after the whole-text
    /// match instead of looking for contained licenses.
    ///
    /// Only consulted in `ScanMode::Elimination`, and only when the whole-text
    /// score also exceeds the confidence threshold.
    pub fn shallow_limit(mut self, shallow_limit: f32) -> Self {
        self.shallow_limit = shallow_limit;
        self
    }

    /// Enables or disables the search for contained licenses in elimination mode.
    ///
    /// `ScanMode::TopDown` does not read this flag.
    pub fn optimize(mut self, optimize: bool) -> Self {
        self.optimize = optimize;
        self
    }

    /// Caps the number of narrowing passes elimination mode runs; each pass
    /// records at most one contained license.
    pub fn max_passes(mut self, max_passes: u16) -> Self {
        self.max_passes = max_passes;
        self
    }

    /// Scans `text` using the configured `ScanMode`.
    #[inline]
    pub fn scan(&self, text: &TextData) -> ScanResult<'_> {
        match self.mode {
            ScanMode::Elimination => self.scan_elimination(text),
            ScanMode::TopDown { step_size } => self.scan_topdown(text, step_size),
        }
    }

    /// Elimination strategy.
    ///
    /// Analyzes the whole text first. Unless that match is conclusive (above
    /// both the confidence threshold and the shallow limit) or optimization is
    /// disabled, it then repeatedly narrows the text to the region that best
    /// matches the current candidate license, records it, whites that region
    /// out and re-analyzes the remainder, for at most `max_passes` passes.
    fn scan_elimination(&self, text: &TextData) -> ScanResult<'_> {
        let mut analysis = self.store.analyze(text);
        let score = analysis.score;
        let mut containing = Vec::new();

        // The whole-text match is only reported when it clears the threshold.
        let license = if score > self.confidence_threshold {
            Some(IdentifiedLicense {
                name: analysis.name,
                kind: analysis.license_type,
                data: analysis.data,
            })
        } else {
            None
        };

        // A confident whole-text match above the shallow limit is final, and
        // with optimization disabled there is no deeper search to run.
        let conclusive = license.is_some() && score > self.shallow_limit;
        if conclusive || !self.optimize {
            return ScanResult {
                score,
                license,
                containing,
            };
        }

        let mut current_text: Cow<'_, TextData> = Cow::Borrowed(text);
        for _ in 0..self.max_passes {
            let (optimized, optimized_score) = current_text.optimize_bounds(analysis.data);
            if optimized_score < self.confidence_threshold {
                break;
            }

            containing.push(ContainedResult {
                score: optimized_score,
                license: IdentifiedLicense {
                    name: analysis.name,
                    kind: analysis.license_type,
                    data: analysis.data,
                },
                line_range: optimized.lines_view(),
            });

            // NOTE(review): `white_out` is assumed to blank the lines just
            // claimed so the next analysis can surface a different license.
            current_text = Cow::Owned(optimized.white_out());
            analysis = self.store.analyze(&current_text);
        }

        ScanResult {
            score,
            license,
            containing,
        }
    }

    /// Top-down strategy: repeatedly searches for the next contained license,
    /// resuming after each match, until nothing more is found or the end of
    /// the text is reached. Never reports a whole-text license (`score` is 0.0).
    fn scan_topdown(&self, text: &TextData, step_size: usize) -> ScanResult<'_> {
        // `Iterator::step_by` panics on a zero step and `ScanMode::TopDown` is
        // public, so clamp to the finest valid step instead of panicking.
        let step_size = step_size.max(1);
        let (_, text_end) = text.lines_view();
        let mut containing = Vec::new();
        let mut current_start = 0usize;

        while current_start < text_end {
            match self.topdown_find_contained_license(text, current_start, step_size) {
                Some(contained) => {
                    // Continue searching on the lines after this match.
                    current_start = contained.line_range.1 + 1;
                    containing.push(contained);
                }
                None => break,
            }
        }

        ScanResult {
            score: 0.0,
            license: None,
            containing,
        }
    }

    /// Finds the next license at or after line `starting_at`.
    ///
    /// Candidate line windows `(start, end)` are tried in increments of
    /// `step_size` (which must be non-zero). Once a window's score reaches the
    /// confidence threshold, later windows replace it as the candidate while
    /// they stay at or above the threshold; the first window that drops below
    /// ends the search. The surviving candidate's bounds are then refined with
    /// `optimize_bounds`, and the match is returned only if the refined score
    /// still meets the threshold.
    ///
    /// Returns `None` when no window reaches the threshold or when refinement
    /// drops the score below it.
    fn topdown_find_contained_license(
        &self,
        text: &TextData,
        starting_at: usize,
        step_size: usize,
    ) -> Option<ContainedResult<'_>> {
        let (_, text_end) = text.lines_view();
        // Rough bounds and analysis of the latest window that met the threshold.
        let mut candidate: Option<(usize, usize, Match<'_>)> = None;

        'start: for start in (starting_at..text_end).step_by(step_size) {
            for end in (start..=text_end).step_by(step_size) {
                let view = text.with_view(start, end);
                let analysis = self.store.analyze(&view);
                if analysis.score >= self.confidence_threshold {
                    candidate = Some((start, end, analysis));
                } else if candidate.is_some() {
                    // Left the above-threshold region: the rough bounds are settled.
                    break 'start;
                }
            }
        }

        // Refine the rough bounds to the best-scoring sub-range.
        let (best_start, best_end, matched) = candidate?;
        let view = text.with_view(best_start, best_end);
        let (optimized, optimized_score) = view.optimize_bounds(matched.data);
        if optimized_score < self.confidence_threshold {
            return None;
        }

        Some(ContainedResult {
            score: optimized_score,
            license: IdentifiedLicense {
                name: matched.name,
                kind: matched.license_type,
                data: matched.data,
            },
            line_range: optimized.lines_view(),
        })
    }
}
// Unit tests for `Scanner`, run against the two-license store built by
// `create_dummy_store`.
#[cfg(test)]
mod tests {
use super::*;
/// Smoke test: the constructor and every builder method can be called and chained.
#[test]
fn can_construct() {
let store = Store::new();
Scanner::new(&store);
Scanner::new(&store).confidence_threshold(0.5);
Scanner::new(&store)
.shallow_limit(0.99)
.optimize(true)
.max_passes(100);
}
/// Elimination mode with `shallow_limit(0.0)`: per the assertions below, the
/// whole-text match is license-1 and clears a 0.5 threshold, while the same
/// text does not clear a 0.8 threshold.
#[test]
fn shallow_scan() {
let store = create_dummy_store();
// Contains license-1's words, but not in its exact line layout.
let test_data = TextData::new("lorem ipsum\naaaaa bbbbb\nccccc\nhello");
let strategy = Scanner::new(&store)
.confidence_threshold(0.5)
.shallow_limit(0.0);
let result = strategy.scan(&test_data);
assert!(
result.score > 0.5,
"score must meet threshold; was {}",
result.score
);
assert_eq!(
result.license.expect("result has a license").name,
"license-1"
);
let strategy = Scanner::new(&store)
.confidence_threshold(0.8)
.shallow_limit(0.0);
let result = strategy.scan(&test_data);
assert!(result.license.is_none(), "result license is None");
}
/// Elimination mode with optimization: license-2's text appears verbatim
/// between unrelated lines. The whole text does not clear 0.5, but the
/// optimize pass isolates exactly one contained license-2 match.
#[test]
fn single_optimize() {
let store = create_dummy_store();
let test_data = TextData::new(
"lorem\nipsum abc def ghi jkl\n1234 5678 1234\n0000\n1010101010\n\n8888 9999\nwhatsit hello\narst neio qwfp colemak is the best keyboard layout",
);
let strategy = Scanner::new(&store)
.confidence_threshold(0.5)
.optimize(true)
.shallow_limit(1.0);
let result = strategy.scan(&test_data);
assert!(result.license.is_none(), "result license is None");
assert_eq!(result.containing.len(), 1);
let contained = &result.containing[0];
assert_eq!(contained.license.name, "license-2");
assert!(
contained.score > 0.5,
"contained score is greater than threshold"
);
}
/// Elimination mode with optimization: both licenses appear verbatim (license-2
/// in the middle, license-1 at the end) and each must be found exactly once.
#[test]
fn find_multiple_licenses_elimination() {
let store = create_dummy_store();
let test_data = TextData::new(
"lorem\nipsum abc def ghi jkl\n1234 5678 1234\n0000\n1010101010\n\n8888 9999\nwhatsit hello\narst neio qwfp colemak is the best keyboard layout\naaaaa\nbbbbb\nccccc",
);
let strategy = Scanner::new(&store)
.confidence_threshold(0.5)
.optimize(true)
.shallow_limit(1.0);
let result = strategy.scan(&test_data);
assert!(result.license.is_none(), "result license is None");
assert_eq!(2, result.containing.len());
// Count matches per license name; any other name is a failure.
let mut found1 = 0;
let mut found2 = 0;
for contained in &result.containing {
match contained.license.name {
"license-1" => {
assert!(contained.score > 0.5, "license-1 score meets threshold");
found1 += 1;
}
"license-2" => {
assert!(contained.score > 0.5, "license-2 score meets threshold");
found2 += 1;
}
_ => {
panic!("somehow got an unknown license name");
}
}
}
assert!(
found1 == 1 && found2 == 1,
"found both licenses exactly once"
);
}
/// Same input and expectations as the elimination test, but using the
/// top-down strategy with a 1-line step.
#[test]
fn find_multiple_licenses_topdown() {
let store = create_dummy_store();
let test_data = TextData::new(
"lorem\nipsum abc def ghi jkl\n1234 5678 1234\n0000\n1010101010\n\n8888 9999\nwhatsit hello\narst neio qwfp colemak is the best keyboard layout\naaaaa\nbbbbb\nccccc",
);
let strategy = Scanner::with_scan_mode(&store, ScanMode::TopDown { step_size: 1 })
.confidence_threshold(0.5);
let result = strategy.scan(&test_data);
assert!(result.license.is_none(), "result license is None");
// Printed so the result is visible in the test output when an assertion fails.
println!("{:?}", result);
assert_eq!(2, result.containing.len());
let mut found1 = 0;
let mut found2 = 0;
for contained in &result.containing {
match contained.license.name {
"license-1" => {
assert!(contained.score > 0.5, "license-1 score meets threshold");
found1 += 1;
}
"license-2" => {
assert!(contained.score > 0.5, "license-2 score meets threshold");
found2 += 1;
}
_ => {
panic!("somehow got an unknown license name");
}
}
}
assert!(
found1 == 1 && found2 == 1,
"found both licenses exactly once"
);
}
/// Builds a store with two small synthetic licenses: "license-1" (three
/// one-word lines) and "license-2" (a five-line numeric block that includes a
/// blank line).
fn create_dummy_store() -> Store {
let mut store = Store::new();
store.add_license("license-1".into(), "aaaaa\nbbbbb\nccccc".into());
store.add_license(
"license-2".into(),
"1234 5678 1234\n0000\n1010101010\n\n8888 9999".into(),
);
store
}
}