Skip to main content

provenant/askalono/
strategy.rs

1// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::borrow::Cow;
5use std::fmt;
6
7use anyhow::Error;
8use log::{info, trace};
9use serde::Serialize;
10
11use super::{
12    license::{LicenseType, TextData},
13    store::{Match, Store},
14};
15
/// A license that was identified by a scan, together with its type.
///
/// Both the name and the text data are borrowed for the lifetime `'a`; the
/// struct itself owns no license text.
#[derive(Serialize, Clone)]
pub struct IdentifiedLicense<'a> {
    /// The identifier of the license.
    pub name: &'a str,
    /// The type of the license that was matched.
    pub kind: LicenseType,
    /// A reference to the license data inside the store.
    pub data: &'a TextData,
}
26
27impl fmt::Debug for IdentifiedLicense<'_> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("IdentifiedLicense")
30            .field("name", &self.name)
31            .field("kind", &self.kind)
32            .finish()
33    }
34}
35
/// Information about scanned content.
///
/// Produced by `ScanStrategy.scan`.
///
/// In `ScanMode::TopDown`, `score` is always 0.0 and `license` is always
/// `None`; every finding is reported through `containing`.
#[derive(Serialize, Debug)]
pub struct ScanResult<'a> {
    /// The confidence of the match from 0.0 to 1.0.
    pub score: f32,
    /// The identified license of the overall text, or None if nothing met the
    /// confidence threshold.
    pub license: Option<IdentifiedLicense<'a>>,
    /// Any licenses discovered inside the text. Elimination mode only fills
    /// this when `optimize` was enabled; TopDown mode always reports its
    /// findings here.
    pub containing: Vec<ContainedResult<'a>>,
}
49
/// A single license identified within a portion of a larger text.
///
/// Collected in `ScanResult::containing`.
#[derive(Serialize, Debug, Clone)]
pub struct ContainedResult<'a> {
    /// The confidence of the match within the line range from 0.0 to 1.0.
    pub score: f32,
    /// The license identified in this portion of the text.
    pub license: IdentifiedLicense<'a>,
    /// A 0-indexed (inclusive, exclusive) range of line numbers identifying
    /// where in the overall text a license was identified.
    ///
    /// See `TextData.lines_view()` for more information.
    pub line_range: (usize, usize),
}
63
/// A `ScanStrategy` can be used as a high-level wrapper over a `Store`'s
/// analysis logic.
///
/// A strategy configured here can be run repeatedly to scan a document for
/// multiple licenses, or to automatically optimize to locate texts within a
/// larger text.
///
/// # Examples
///
/// ```rust,should_panic
/// # use std::error::Error;
/// use provenant::askalono::{ScanStrategy, Store};
///
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let store = Store::new();
/// // [...]
/// let strategy = ScanStrategy::new(&store)
///     .confidence_threshold(0.9)
///     .optimize(true);
/// let results = strategy.scan(&"my text to scan".into())?;
/// # Ok(())
/// # }
/// ```
pub struct ScanStrategy<'a> {
    // The license store every analysis is run against.
    store: &'a Store,
    // Which scanning algorithm `scan` dispatches to.
    mode: ScanMode,
    // Minimum score a match needs in order to be reported.
    confidence_threshold: f32,
    // Score at which an Elimination scan returns early without digging deeper.
    shallow_limit: f32,
    // Whether an Elimination scan searches for licenses contained in the text.
    optimize: bool,
    // Upper bound on the number of optimize passes in an Elimination scan.
    max_passes: u16,
    // Line interval used when sliding the window in a TopDown scan.
    step_size: usize,
}
96
/// Available scanning strategy modes.
///
/// The enum is a plain, field-less selector, so it derives the common traits
/// (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`) to make it easy to log,
/// store, and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanMode {
    /// Elimination is a general-purpose strategy that iteratively locates the
    /// highest license match in a file, then the next, and so on until not
    /// finding any more strong matches.
    Elimination,

    /// TopDown is a strategy intended for use with attribution documents, or
    /// text files containing multiple licenses (and not much else). It's more
    /// accurate than Elimination, but significantly slower.
    TopDown,
}
109
110impl<'a> ScanStrategy<'a> {
111    /// Construct a new scanning strategy tied to the given `Store`.
112    ///
113    /// By default, the strategy has conservative defaults and won't perform
114    /// any deeper investigaton into the contents of files.
115    pub fn new(store: &'a Store) -> ScanStrategy<'a> {
116        Self {
117            store,
118            mode: ScanMode::Elimination,
119            confidence_threshold: 0.9,
120            shallow_limit: 0.99,
121            optimize: false,
122            max_passes: 10,
123            step_size: 5,
124        }
125    }
126
127    /// Set the scanning mode.
128    ///
129    /// See ScanMode for a description of options. The default mode is
130    /// Elimination, which is a fast, good general-purpose matcher.
131    pub fn mode(mut self, mode: ScanMode) -> Self {
132        self.mode = mode;
133        self
134    }
135
136    /// Set the confidence threshold for this strategy.
137    ///
138    /// The overall license match must meet this number in order to be
139    /// reported. Additionally, if contained licenses are reported in the scan
140    /// (when `optimize` is enabled), they'll also need to meet this bar.
141    ///
142    /// Set this to 1.0 for only exact matches, and 0.0 to report even the
143    /// weakest match.
144    pub fn confidence_threshold(mut self, confidence_threshold: f32) -> Self {
145        self.confidence_threshold = confidence_threshold;
146        self
147    }
148
149    /// Set a fast-exit parameter that allows the strategy to skip the rest of
150    /// a scan for strong matches.
151    ///
152    /// This should be set higher than the confidence threshold; ideally close
153    /// to 1.0. If the overall match score is above this limit, the scanner
154    /// will return early and not bother performing deeper checks.
155    ///
156    /// This is really only useful in conjunction with `optimize`. A value of
157    /// 0.0 will fast-return on any match meeting the confidence threshold,
158    /// while a value of 1.0 will only stop on a perfect match.
159    pub fn shallow_limit(mut self, shallow_limit: f32) -> Self {
160        self.shallow_limit = shallow_limit;
161        self
162    }
163
164    /// Indicate whether a deeper scan should be performed.
165    ///
166    /// This is ignored if the shallow limit is met. It's not enabled by
167    /// default, however, so if you want deeper results you should set
168    /// `shallow_limit` fairly high and enable this.
169    pub fn optimize(mut self, optimize: bool) -> Self {
170        self.optimize = optimize;
171        self
172    }
173
174    /// The maximum number of identifications to perform before exiting a scan
175    /// of a single text.
176    ///
177    /// This is largely to prevent misconfigurations and infinite loop
178    /// scenarios, but if you have a document with a large number of licenses
179    /// then you may want to tune this to a value above the number of licenses
180    /// you expect to be identified.
181    pub fn max_passes(mut self, max_passes: u16) -> Self {
182        self.max_passes = max_passes;
183        self
184    }
185
186    /// Configure the scanning interval (in lines) for TopDown mode.
187    ///
188    /// A smaller step size will be more accurate at a significant cost of
189    /// speed.
190    pub fn step_size(mut self, step_size: usize) -> Self {
191        self.step_size = step_size;
192        self
193    }
194
195    /// Returns `true` if the underlying store has any licenses loaded.
196    pub fn store_has_licenses(&self) -> bool {
197        !self.store.is_empty()
198    }
199
200    /// Scan the given text content using this strategy's configured
201    /// preferences.
202    ///
203    /// Returns a `ScanResult` containing all discovered information.
204    pub fn scan(&self, text: &TextData) -> Result<ScanResult<'_>, Error> {
205        match self.mode {
206            ScanMode::Elimination => Ok(self.scan_elimination(text)),
207            ScanMode::TopDown => Ok(self.scan_topdown(text)),
208        }
209    }
210
211    fn scan_elimination(&self, text: &TextData) -> ScanResult<'_> {
212        let mut analysis = self.store.analyze(text);
213        let score = analysis.score;
214        let mut license = None;
215        let mut containing = Vec::new();
216        info!("Elimination top-level analysis: {:?}", analysis);
217
218        // meets confidence threshold? record that
219        if analysis.score > self.confidence_threshold {
220            license = Some(IdentifiedLicense {
221                name: analysis.name,
222                kind: analysis.license_type,
223                data: analysis.data,
224            });
225
226            // above the shallow limit -> exit
227            if analysis.score > self.shallow_limit {
228                return ScanResult {
229                    score,
230                    license,
231                    containing,
232                };
233            }
234        }
235
236        if self.optimize {
237            // repeatedly try to dig deeper
238            // this loop effectively iterates once for each license it finds
239            let mut current_text: Cow<'_, TextData> = Cow::Borrowed(text);
240            for _n in 0..self.max_passes {
241                let (optimized, optimized_score) = current_text.optimize_bounds(analysis.data);
242
243                // stop if we didn't find anything acceptable
244                if optimized_score < self.confidence_threshold {
245                    break;
246                }
247
248                // otherwise, save it
249                info!(
250                    "Optimized to {} lines ({}, {})",
251                    optimized_score,
252                    optimized.lines_view().0,
253                    optimized.lines_view().1
254                );
255                containing.push(ContainedResult {
256                    score: optimized_score,
257                    license: IdentifiedLicense {
258                        name: analysis.name,
259                        kind: analysis.license_type,
260                        data: analysis.data,
261                    },
262                    line_range: optimized.lines_view(),
263                });
264
265                // and white-out + reanalyze for next iteration
266                current_text = Cow::Owned(optimized.white_out());
267                analysis = self.store.analyze(&current_text);
268            }
269        }
270
271        ScanResult {
272            score,
273            license,
274            containing,
275        }
276    }
277
278    fn scan_topdown(&self, text: &TextData) -> ScanResult<'_> {
279        let (_, text_end) = text.lines_view();
280        let mut containing = Vec::new();
281
282        // find licenses working down thru the text's lines
283        let mut current_start = 0usize;
284        while current_start < text_end {
285            let result = self.topdown_find_contained_license(text, current_start);
286
287            let contained = match result {
288                Some(c) => c,
289                None => break,
290            };
291
292            current_start = contained.line_range.1 + 1;
293            containing.push(contained);
294        }
295
296        ScanResult {
297            score: 0.0,
298            license: None,
299            containing,
300        }
301    }
302
303    fn topdown_find_contained_license(
304        &self,
305        text: &TextData,
306        starting_at: usize,
307    ) -> Option<ContainedResult<'_>> {
308        let (_, text_end) = text.lines_view();
309        let mut found: (usize, usize, Option<Match<'_>>) = (0, 0, None);
310
311        trace!(
312            "topdown_find_contained_license starting at line {}",
313            starting_at
314        );
315
316        // speed: only start tracking once conf is met, and bail out after
317        let mut hit_threshold = false;
318
319        // move the start of window...
320        'start: for start in (starting_at..text_end).step_by(self.step_size) {
321            // ...and also the end of window to find high scores.
322            for end in (start..=text_end).step_by(self.step_size) {
323                let view = text.with_view(start, end);
324                let analysis = self.store.analyze(&view);
325
326                // just getting a feel for the data at this point, not yet
327                // optimizing the view.
328
329                // entering threshold: save the starting location
330                if !hit_threshold && analysis.score >= self.confidence_threshold {
331                    hit_threshold = true;
332                    trace!(
333                        "hit_threshold at ({}, {}) with score {}",
334                        start, end, analysis.score
335                    );
336                }
337
338                if hit_threshold {
339                    if analysis.score < self.confidence_threshold {
340                        // exiting threshold
341                        trace!(
342                            "exiting threshold at ({}, {}) with score {}",
343                            start, end, analysis.score
344                        );
345                        break 'start;
346                    } else {
347                        // maintaining threshold (also true for entering)
348                        found = (start, end, Some(analysis));
349                    }
350                }
351            }
352        }
353
354        // at this point we have a *rough* bounds for a match.
355        // now we can optimize to find the best one
356        let matched = found.2?;
357        let check = matched.data;
358        let view = text.with_view(found.0, found.1);
359        let (optimized, optimized_score) = view.optimize_bounds(check);
360
361        trace!(
362            "optimized {} {} at ({:?})",
363            optimized_score,
364            matched.name,
365            optimized.lines_view()
366        );
367
368        if optimized_score < self.confidence_threshold {
369            return None;
370        }
371
372        Some(ContainedResult {
373            score: optimized_score,
374            license: IdentifiedLicense {
375                name: matched.name,
376                kind: matched.license_type,
377                data: matched.data,
378            },
379            line_range: optimized.lines_view(),
380        })
381    }
382}
383
384#[cfg(test)]
385mod tests {
386    use super::*;
387
388    #[test]
389    fn can_construct() {
390        let store = Store::new();
391        ScanStrategy::new(&store);
392        ScanStrategy::new(&store).confidence_threshold(0.5);
393        ScanStrategy::new(&store)
394            .shallow_limit(0.99)
395            .optimize(true)
396            .max_passes(100);
397    }
398
399    #[test]
400    fn shallow_scan() {
401        let store = create_dummy_store();
402        let test_data = TextData::new("lorem ipsum\naaaaa bbbbb\nccccc\nhello");
403
404        // the above text should have a result with a confidence minimum of 0.5
405        let strategy = ScanStrategy::new(&store)
406            .confidence_threshold(0.5)
407            .shallow_limit(0.0);
408        let result = strategy.scan(&test_data).unwrap();
409        assert!(
410            result.score > 0.5,
411            "score must meet threshold; was {}",
412            result.score
413        );
414        assert_eq!(
415            result.license.expect("result has a license").name,
416            "license-1"
417        );
418
419        // but it won't pass with a threshold of 0.8
420        let strategy = ScanStrategy::new(&store)
421            .confidence_threshold(0.8)
422            .shallow_limit(0.0);
423        let result = strategy.scan(&test_data).unwrap();
424        assert!(result.license.is_none(), "result license is None");
425    }
426
427    #[test]
428    fn single_optimize() {
429        let store = create_dummy_store();
430        // this TextData matches license-2 with an overall score of ~0.46 and optimized
431        // score of ~0.57
432        let test_data = TextData::new(
433            "lorem\nipsum abc def ghi jkl\n1234 5678 1234\n0000\n1010101010\n\n8888 9999\nwhatsit hello\narst neio qwfp colemak is the best keyboard layout",
434        );
435
436        // check that we can spot the gibberish license in the sea of other gibberish
437        let strategy = ScanStrategy::new(&store)
438            .confidence_threshold(0.5)
439            .optimize(true)
440            .shallow_limit(1.0);
441        let result = strategy.scan(&test_data).unwrap();
442        assert!(result.license.is_none(), "result license is None");
443        assert_eq!(result.containing.len(), 1);
444        let contained = &result.containing[0];
445        assert_eq!(contained.license.name, "license-2");
446        assert!(
447            contained.score > 0.5,
448            "contained score is greater than threshold"
449        );
450    }
451
452    #[test]
453    fn find_multiple_licenses_elimination() {
454        let store = create_dummy_store();
455        // this TextData matches license-2 with an overall score of ~0.46 and optimized
456        // score of ~0.57
457        let test_data = TextData::new(
458            "lorem\nipsum abc def ghi jkl\n1234 5678 1234\n0000\n1010101010\n\n8888 9999\nwhatsit hello\narst neio qwfp colemak is the best keyboard layout\naaaaa\nbbbbb\nccccc",
459        );
460
461        // check that we can spot the gibberish license in the sea of other gibberish
462        let strategy = ScanStrategy::new(&store)
463            .mode(ScanMode::Elimination)
464            .confidence_threshold(0.5)
465            .optimize(true)
466            .shallow_limit(1.0);
467        let result = strategy.scan(&test_data).unwrap();
468        assert!(result.license.is_none(), "result license is None");
469        assert_eq!(2, result.containing.len());
470
471        // inspect the array and ensure we got both licenses
472        let mut found1 = 0;
473        let mut found2 = 0;
474        for contained in result.containing.iter() {
475            match contained.license.name {
476                "license-1" => {
477                    assert!(contained.score > 0.5, "license-1 score meets threshold");
478                    found1 += 1;
479                }
480                "license-2" => {
481                    assert!(contained.score > 0.5, "license-2 score meets threshold");
482                    found2 += 1;
483                }
484                _ => {
485                    panic!("somehow got an unknown license name");
486                }
487            }
488        }
489
490        assert!(
491            found1 == 1 && found2 == 1,
492            "found both licenses exactly once"
493        );
494    }
495
496    #[test]
497    fn find_multiple_licenses_topdown() {
498        env_logger::init();
499
500        let store = create_dummy_store();
501        // this TextData matches license-2 with an overall score of ~0.46 and optimized
502        // score of ~0.57
503        let test_data = TextData::new(
504            "lorem\nipsum abc def ghi jkl\n1234 5678 1234\n0000\n1010101010\n\n8888 9999\nwhatsit hello\narst neio qwfp colemak is the best keyboard layout\naaaaa\nbbbbb\nccccc",
505        );
506
507        // check that we can spot the gibberish license in the sea of other gibberish
508        let strategy = ScanStrategy::new(&store)
509            .mode(ScanMode::TopDown)
510            .confidence_threshold(0.5)
511            .step_size(1);
512        let result = strategy.scan(&test_data).unwrap();
513        assert!(result.license.is_none(), "result license is None");
514        assert_eq!(2, result.containing.len());
515
516        // inspect the array and ensure we got both licenses
517        let mut found1 = 0;
518        let mut found2 = 0;
519        for contained in result.containing.iter() {
520            match contained.license.name {
521                "license-1" => {
522                    assert!(contained.score > 0.5, "license-1 score meets threshold");
523                    found1 += 1;
524                }
525                "license-2" => {
526                    assert!(contained.score > 0.5, "license-2 score meets threshold");
527                    found2 += 1;
528                }
529                _ => {
530                    panic!("somehow got an unknown license name");
531                }
532            }
533        }
534
535        assert!(
536            found1 == 1 && found2 == 1,
537            "found both licenses exactly once"
538        );
539    }
540
541    fn create_dummy_store() -> Store {
542        let mut store = Store::new();
543        store.add_license("license-1".into(), "aaaaa\nbbbbb\nccccc".into());
544        store.add_license(
545            "license-2".into(),
546            "1234 5678 1234\n0000\n1010101010\n\n8888 9999".into(),
547        );
548        store
549    }
550}