use crate::license_detection::index::LicenseIndex;
use crate::license_detection::index::dictionary::{KnownToken, QueryToken, TokenId, TokenKind};
use crate::license_detection::tokenize::tokenize_as_ids;
use bit_set::BitSet;
use std::cell::{OnceCell, RefCell};
use std::collections::{HashMap, HashSet};

/// An inclusive range of token positions `[start, end]`.
#[derive(Debug, Clone)]
pub struct PositionSpan {
    start: usize,
    end: usize,
}

impl PositionSpan {
    /// Builds a span covering positions `start` through `end`, inclusive.
    pub fn new(start: usize, end: usize) -> Self {
        PositionSpan { start, end }
    }

    /// Returns true when `pos` lies inside the inclusive span.
    pub fn contains(&self, pos: usize) -> bool {
        (self.start..=self.end).contains(&pos)
    }

    /// Yields every position of the span in ascending order.
    pub fn iter(&self) -> impl Iterator<Item = usize> + '_ {
        self.start..=self.end
    }
}
38
/// A tokenized text to be matched against the license index, together with
/// all the positional side-tables the matching pipeline needs.
///
/// "Positions" throughout refer to indices into `tokens`, i.e. only known
/// tokens occupy positions; unknown tokens and stopwords are tracked in the
/// side maps instead.
#[derive(Debug)]
pub struct Query<'a> {
    /// The full original query text.
    pub text: String,

    /// Ids of the known tokens, in order of appearance.
    pub tokens: Vec<TokenId>,

    /// 1-based source line number for each known-token position.
    pub line_by_pos: Vec<usize>,

    /// Count of unknown tokens, keyed by `Some(position)` of the known token
    /// immediately preceding them; `None` collects unknowns seen before the
    /// first known token.
    pub unknowns_by_pos: HashMap<Option<i32>, usize>,

    /// Count of stopword tokens, keyed the same way as `unknowns_by_pos`.
    pub stopwords_by_pos: HashMap<Option<i32>, usize>,

    /// Positions of known tokens the tokenizer flagged as short or digit.
    pub shorts_and_digits_pos: HashSet<usize>,

    /// Positions whose token kind is `TokenKind::Legalese`.
    pub high_matchables: BitSet,

    /// Positions whose token kind is `TokenKind::Regular`.
    pub low_matchables: BitSet,

    /// Whether the text was derived from (or detected as) binary content.
    pub is_binary: bool,

    /// `(start, inclusive_end)` position ranges of the query runs; built by
    /// `compute_query_runs`.
    pub(crate) query_run_ranges: Vec<(usize, Option<usize>)>,

    /// SPDX declaration lines as `(trimmed line text, start position,
    /// one-past-last position)` over known-token positions.
    pub spdx_lines: Vec<(String, usize, usize)>,

    /// The license index this query was tokenized against.
    pub index: &'a LicenseIndex,
}
131
/// Extracts the text of the inclusive, 1-based line range
/// `[start_line, end_line]`, joined with `\n`.
///
/// Returns an empty string when either bound is `0` or the range is
/// inverted. An `end_line` past the end of `text` simply clamps to the
/// last line.
pub fn matched_text_from_text(text: &str, start_line: usize, end_line: usize) -> String {
    if start_line == 0 || end_line == 0 || start_line > end_line {
        return String::new();
    }

    // Jump straight to `start_line` and stop after `end_line` instead of
    // filtering every line, so large texts are not scanned past the range.
    text.lines()
        .skip(start_line - 1)
        .take(end_line - start_line + 1)
        .collect::<Vec<_>>()
        .join("\n")
}
150
151impl<'a> Query<'a> {
152 const TEXT_LINE_THRESHOLD: usize = 15;
167 const BINARY_LINE_THRESHOLD: usize = 50;
168 const MAX_TOKEN_PER_LINE: usize = 25;
169
170 fn compute_spdx_offset(
171 tokens: &[QueryToken],
172 dictionary: &crate::license_detection::index::dictionary::TokenDictionary,
173 ) -> Option<usize> {
174 let get_known_id = |i: usize| -> Option<TokenId> {
175 match tokens.get(i)? {
176 QueryToken::Known(known) => Some(known.id),
177 _ => None,
178 }
179 };
180
181 let spdx_id = dictionary.get("spdx")?;
182 let license_id = dictionary.get("license")?;
183 let identifier_id = dictionary.get("identifier")?;
184 let licence_id = dictionary.get("licence");
185
186 let licenses_id = dictionary.get("licenses");
187 let nuget_id = dictionary.get("nuget");
188 let org_id = dictionary.get("org");
189
190 let is_spdx_prefix = |ids: [Option<TokenId>; 3]| -> bool {
191 ids.iter().all(|id| id.is_some())
192 && ids[0] == Some(spdx_id)
193 && (ids[1] == Some(license_id) || ids[1] == licence_id)
194 && ids[2] == Some(identifier_id)
195 };
196
197 let is_nuget_prefix = |ids: [Option<TokenId>; 3]| -> bool {
198 licenses_id.is_some()
199 && nuget_id.is_some()
200 && org_id.is_some()
201 && ids[0] == licenses_id
202 && ids[1] == Some(nuget_id.unwrap())
203 && ids[2] == Some(org_id.unwrap())
204 };
205
206 if tokens.len() >= 3 {
207 let first_three = [get_known_id(0), get_known_id(1), get_known_id(2)];
208 if is_spdx_prefix(first_three) || is_nuget_prefix(first_three) {
209 return Some(0);
210 }
211 }
212
213 if tokens.len() >= 4 {
214 let second_three = [get_known_id(1), get_known_id(2), get_known_id(3)];
215 if is_spdx_prefix(second_three) || is_nuget_prefix(second_three) {
216 return Some(1);
217 }
218 }
219
220 if tokens.len() >= 5 {
221 let third_three = [get_known_id(2), get_known_id(3), get_known_id(4)];
222 if is_spdx_prefix(third_three) || is_nuget_prefix(third_three) {
223 return Some(2);
224 }
225 }
226
227 None
228 }
229
230 pub fn from_extracted_text(
231 text: &str,
232 index: &'a LicenseIndex,
233 binary_derived: bool,
234 ) -> Result<Self, anyhow::Error> {
235 let line_threshold = if binary_derived {
236 Self::BINARY_LINE_THRESHOLD
237 } else {
238 Self::TEXT_LINE_THRESHOLD
239 };
240
241 Self::with_source_options(text, index, line_threshold, Some(binary_derived))
242 }
243
244 pub fn query_runs(&self) -> Vec<QueryRun<'_>> {
248 self.query_run_ranges
249 .iter()
250 .map(|&(start, end)| QueryRun::new(self, start, end))
251 .collect()
252 }
253
254 fn with_source_options(
255 text: &str,
256 index: &'a LicenseIndex,
257 line_threshold: usize,
258 binary_derived: Option<bool>,
259 ) -> Result<Self, anyhow::Error> {
260 let is_binary = match binary_derived {
261 Some(is_binary) => is_binary,
262 None => Self::detect_binary(text)?,
263 };
264 let has_long_lines = Self::detect_long_lines(text);
265
266 let mut tokens = Vec::new();
267 let mut line_by_pos = Vec::new();
268 let mut unknowns_by_pos: HashMap<Option<i32>, usize> = HashMap::new();
269 let mut stopwords_by_pos: HashMap<Option<i32>, usize> = HashMap::new();
270 let mut shorts_and_digits_pos = HashSet::new();
271 let mut spdx_lines: Vec<(String, usize, usize)> = Vec::new();
272
273 let mut known_pos = -1i32;
274 let mut started = false;
275 let mut current_line = 1usize;
276
277 let mut tokens_by_line: Vec<Vec<Option<KnownToken>>> = Vec::new();
278
279 for line in text.lines() {
280 let line_trimmed = line.trim();
281 let mut line_tokens: Vec<Option<KnownToken>> = Vec::new();
282
283 let mut line_first_known_pos = None;
284
285 let line_query_tokens = tokenize_as_ids(line_trimmed, &index.dictionary);
286
287 for query_token in &line_query_tokens {
288 match query_token {
289 QueryToken::Known(known_token) => {
290 known_pos += 1;
291 started = true;
292 tokens.push(known_token.id);
293 line_by_pos.push(current_line);
294 line_tokens.push(Some(*known_token));
295
296 if line_first_known_pos.is_none() {
297 line_first_known_pos = Some(known_pos);
298 }
299
300 if known_token.is_short_or_digit {
301 let _ = shorts_and_digits_pos.insert(known_pos as usize);
302 }
303 }
304 QueryToken::Unknown if !started => {
305 *unknowns_by_pos.entry(None).or_insert(0) += 1;
306 line_tokens.push(None);
307 }
308 QueryToken::Unknown => {
309 *unknowns_by_pos.entry(Some(known_pos)).or_insert(0) += 1;
310 line_tokens.push(None);
311 }
312 QueryToken::Stopword if !started => {
313 *stopwords_by_pos.entry(None).or_insert(0) += 1;
314 }
315 QueryToken::Stopword => {
316 *stopwords_by_pos.entry(Some(known_pos)).or_insert(0) += 1;
317 }
318 }
319 }
320
321 let line_last_known_pos = known_pos;
322
323 let spdx_start_offset =
324 Self::compute_spdx_offset(&line_query_tokens, &index.dictionary);
325
326 if let Some(offset) = spdx_start_offset
327 && let Some(line_first_known_pos) = line_first_known_pos
328 {
329 let spdx_start_known_pos = line_first_known_pos + offset as i32;
330 if spdx_start_known_pos <= line_last_known_pos {
331 let spdx_start = spdx_start_known_pos as usize;
332 let spdx_end = (line_last_known_pos + 1) as usize;
333 spdx_lines.push((line_trimmed.to_string(), spdx_start, spdx_end));
334 }
335 }
336
337 tokens_by_line.push(line_tokens);
338 current_line += 1;
339 }
340
341 let high_matchables: BitSet = tokens
342 .iter()
343 .enumerate()
344 .filter(|(_pos, tid)| index.dictionary.token_kind(**tid) == TokenKind::Legalese)
345 .map(|(pos, _tid)| pos)
346 .collect();
347
348 let low_matchables: BitSet = tokens
349 .iter()
350 .enumerate()
351 .filter(|(_pos, tid)| index.dictionary.token_kind(**tid) == TokenKind::Regular)
352 .map(|(pos, _tid)| pos)
353 .collect();
354
355 let query_runs = Self::compute_query_runs(&tokens_by_line, line_threshold, has_long_lines);
356
357 Ok(Query {
358 text: text.to_string(),
359 tokens,
360 line_by_pos,
361 unknowns_by_pos,
362 stopwords_by_pos,
363 shorts_and_digits_pos,
364 high_matchables,
365 low_matchables,
366 is_binary,
367 query_run_ranges: query_runs,
368 spdx_lines,
369 index,
370 })
371 }
372
373 fn detect_binary(text: &str) -> Result<bool, anyhow::Error> {
387 let null_byte_count = text.bytes().filter(|&b| b == 0).count();
388
389 if null_byte_count > 0 {
390 return Ok(true);
391 }
392
393 let non_printable_ratio = text
394 .chars()
395 .filter(|&c| {
396 !c.is_ascii() && !c.is_ascii_graphic() && c != '\n' && c != '\r' && c != '\t'
397 })
398 .count() as f64
399 / text.len().max(1) as f64;
400
401 Ok(non_printable_ratio > 0.3)
402 }
403
404 fn detect_long_lines(text: &str) -> bool {
414 text.lines()
415 .any(|line| crate::license_detection::tokenize::count_tokens(line) > 25)
416 }
417
418 fn break_long_lines(lines: &[Vec<Option<KnownToken>>]) -> Vec<Vec<Option<KnownToken>>> {
419 lines
420 .iter()
421 .flat_map(|line| {
422 if line.is_empty() {
423 return Vec::new();
424 }
425
426 if line.len() <= Self::MAX_TOKEN_PER_LINE {
427 vec![line.clone()]
428 } else {
429 line.chunks(Self::MAX_TOKEN_PER_LINE)
430 .map(|chunk| chunk.to_vec())
431 .collect()
432 }
433 })
434 .collect()
435 }
436
437 fn compute_query_runs(
438 tokens_by_line: &[Vec<Option<KnownToken>>],
439 line_threshold: usize,
440 has_long_lines: bool,
441 ) -> Vec<(usize, Option<usize>)> {
442 let processed_lines = if has_long_lines {
443 Self::break_long_lines(tokens_by_line)
444 } else {
445 tokens_by_line.to_vec()
446 };
447
448 let mut query_runs = Vec::new();
449 let mut query_run_start = 0usize;
450 let mut query_run_end = None;
451 let mut empty_lines = 0usize;
452 let mut pos = 0usize;
453 let mut query_run_is_all_digit = true;
454
455 for line_tokens in processed_lines {
456 if query_run_end.is_some() && empty_lines >= line_threshold {
457 if !query_run_is_all_digit {
458 query_runs.push((query_run_start, query_run_end));
459 }
460 query_run_start = pos;
461 query_run_end = None;
462 empty_lines = 0;
463 query_run_is_all_digit = true;
464 }
465
466 if query_run_end.is_none() {
467 query_run_start = pos;
468 }
469
470 if line_tokens.is_empty() {
471 empty_lines += 1;
472 continue;
473 }
474
475 let line_is_all_digit = line_tokens
476 .iter()
477 .all(|token_id| token_id.map(|known| known.is_digit_only).unwrap_or(true));
478 let mut line_has_known_tokens = false;
479 let mut line_has_good_tokens = false;
480
481 for known in line_tokens.into_iter().flatten() {
482 line_has_known_tokens = true;
483 if known.kind == TokenKind::Legalese {
484 line_has_good_tokens = true;
485 }
486 if !known.is_digit_only {
487 query_run_is_all_digit = false;
488 }
489 query_run_end = Some(pos);
490 pos += 1;
491 }
492
493 if line_is_all_digit || !line_has_known_tokens {
494 empty_lines += 1;
495 continue;
496 }
497
498 if line_has_good_tokens {
499 empty_lines = 0;
500 } else {
501 empty_lines += 1;
502 }
503 }
504
505 if let Some(end) = query_run_end
506 && !query_run_is_all_digit
507 {
508 query_runs.push((query_run_start, Some(end)));
509 }
510
511 query_runs
512 }
513
514 #[inline]
524 pub fn line_for_pos(&self, pos: usize) -> Option<usize> {
525 self.line_by_pos.get(pos).copied()
526 }
527
528 #[inline]
530 pub fn is_empty(&self) -> bool {
531 self.tokens.is_empty()
532 }
533
534 pub fn whole_query_run(&self) -> QueryRun<'a> {
538 QueryRun::whole_query_snapshot(self)
539 }
540
541 pub fn subtract(&mut self, span: &PositionSpan) {
550 for pos in span.iter() {
551 self.high_matchables.remove(pos);
552 self.low_matchables.remove(pos);
553 }
554 }
555
556 pub fn matched_text(&self, start_line: usize, end_line: usize) -> String {
570 matched_text_from_text(&self.text, start_line, end_line)
571 }
572}
573
/// Owned copies of the `Query` fields a `QueryRun` reads, letting a
/// whole-query run exist without borrowing the `Query` itself (see
/// `Query::whole_query_run`).
#[derive(Debug, Clone)]
struct WholeQueryRunSnapshot<'a> {
    /// Shared license index — still borrowed; only per-query data is cloned.
    index: &'a LicenseIndex,
    /// Known-token ids, cloned from the query.
    tokens: Vec<TokenId>,
    /// 1-based line number per known-token position.
    line_by_pos: Vec<usize>,
    /// Positions of legalese tokens, cloned from the query.
    high_matchables: BitSet,
    /// Positions of regular tokens, cloned from the query.
    low_matchables: BitSet,
}
582
/// A contiguous sub-range of a query's known-token positions.
///
/// A run is backed either by a borrowed `Query` (`query` is `Some`) or by
/// an owned `WholeQueryRunSnapshot` (for whole-query runs built by
/// `Query::whole_query_run`); the accessors expect exactly one of the two
/// to be populated.
#[derive(Debug, Clone)]
pub struct QueryRun<'a> {
    /// Borrowed backing query, if any.
    query: Option<&'a Query<'a>>,
    /// Owned backing data, consulted when `query` is `None`.
    whole_query_snapshot: Option<WholeQueryRunSnapshot<'a>>,
    /// First token position of the run (inclusive).
    pub start: usize,
    /// Last token position (inclusive); `None` means the run is empty.
    pub end: Option<usize>,
    /// Lazily built high-matchable positions restricted to this run.
    cached_high_matchables: OnceCell<BitSet>,
    /// Lazily built low-matchable positions restricted to this run.
    cached_low_matchables: OnceCell<BitSet>,
    /// Cached union of low and high matchables (see `matchables`).
    combined_matchables: RefCell<Option<BitSet>>,
}
600
601impl<'a> QueryRun<'a> {
602 pub fn new(query: &'a Query<'a>, start: usize, end: Option<usize>) -> Self {
611 Self {
612 query: Some(query),
613 whole_query_snapshot: None,
614 start,
615 end,
616 cached_high_matchables: OnceCell::new(),
617 cached_low_matchables: OnceCell::new(),
618 combined_matchables: RefCell::new(None),
619 }
620 }
621
622 fn whole_query_snapshot(query: &Query<'a>) -> Self {
623 let end = if query.is_empty() {
624 None
625 } else {
626 Some(query.tokens.len() - 1)
627 };
628
629 Self {
630 query: None,
631 whole_query_snapshot: Some(WholeQueryRunSnapshot {
632 index: query.index,
633 tokens: query.tokens.clone(),
634 line_by_pos: query.line_by_pos.clone(),
635 high_matchables: query.high_matchables.clone(),
636 low_matchables: query.low_matchables.clone(),
637 }),
638 start: 0,
639 end,
640 cached_high_matchables: OnceCell::new(),
641 cached_low_matchables: OnceCell::new(),
642 combined_matchables: RefCell::new(None),
643 }
644 }
645
646 fn source_tokens(&self) -> &[TokenId] {
647 if let Some(query) = self.query {
648 &query.tokens
649 } else {
650 &self
651 .whole_query_snapshot
652 .as_ref()
653 .expect("snapshot-backed whole query run should have snapshot data")
654 .tokens
655 }
656 }
657
658 fn source_line_by_pos(&self) -> &[usize] {
659 if let Some(query) = self.query {
660 &query.line_by_pos
661 } else {
662 &self
663 .whole_query_snapshot
664 .as_ref()
665 .expect("snapshot-backed whole query run should have snapshot data")
666 .line_by_pos
667 }
668 }
669
670 fn source_high_matchables(&self) -> &BitSet {
671 if let Some(query) = self.query {
672 &query.high_matchables
673 } else {
674 &self
675 .whole_query_snapshot
676 .as_ref()
677 .expect("snapshot-backed whole query run should have snapshot data")
678 .high_matchables
679 }
680 }
681
682 fn source_low_matchables(&self) -> &BitSet {
683 if let Some(query) = self.query {
684 &query.low_matchables
685 } else {
686 &self
687 .whole_query_snapshot
688 .as_ref()
689 .expect("snapshot-backed whole query run should have snapshot data")
690 .low_matchables
691 }
692 }
693
694 pub fn get_index(&self) -> &LicenseIndex {
696 if let Some(query) = self.query {
697 query.index
698 } else {
699 self.whole_query_snapshot
700 .as_ref()
701 .expect("snapshot-backed whole query run should have snapshot data")
702 .index
703 }
704 }
705
706 pub fn line_for_pos(&self, pos: usize) -> Option<usize> {
714 self.source_line_by_pos().get(pos).copied()
715 }
716
717 pub fn tokens(&self) -> &[TokenId] {
723 match self.end {
724 Some(end) => &self.source_tokens()[self.start..=end],
725 None => &[],
726 }
727 }
728
729 pub fn tokens_with_pos(&self) -> impl Iterator<Item = (usize, TokenId)> + '_ {
733 self.tokens()
734 .iter()
735 .copied()
736 .enumerate()
737 .map(|(i, tid)| (self.start + i, tid))
738 }
739
740 pub fn is_digits_only(&self) -> bool {
744 self.tokens()
745 .iter()
746 .all(|&tid| self.get_index().dictionary.is_digit_only_token(tid))
747 }
748
749 pub fn is_matchable(&self, include_low: bool, exclude_positions: &[PositionSpan]) -> bool {
759 if self.is_digits_only() {
760 return false;
761 }
762
763 let matchables = self.matchables(include_low);
764
765 if exclude_positions.is_empty() {
766 return !matchables.is_empty();
767 }
768
769 let mut matchable_set = matchables;
770 for span in exclude_positions {
771 for pos in span.iter() {
772 matchable_set.remove(pos);
773 }
774 }
775
776 !matchable_set.is_empty()
777 }
778
779 pub fn matchables(&self, include_low: bool) -> BitSet {
780 if include_low {
781 if let Some(ref cached) = *self.combined_matchables.borrow() {
782 return cached.clone();
783 }
784 let combined: BitSet = self
785 .low_matchables()
786 .union(&self.high_matchables())
787 .collect();
788 *self.combined_matchables.borrow_mut() = Some(combined.clone());
789 combined
790 } else {
791 self.high_matchables()
792 }
793 }
794
795 pub fn matchable_tokens(&self) -> Vec<i32> {
796 let high_matchables = self.high_matchables();
797 if high_matchables.is_empty() {
798 return Vec::new();
799 }
800
801 let matchables = self.matchables(true);
802 self.tokens_with_pos()
803 .map(|(pos, tid)| {
804 if matchables.contains(pos) {
805 tid.raw() as i32
806 } else {
807 -1
808 }
809 })
810 .collect()
811 }
812
813 pub fn high_matchables(&self) -> BitSet {
814 self.cached_high_matchables
815 .get_or_init(|| {
816 let start = self.start;
817 let end = self.end;
818 let source = self.source_high_matchables();
819 let live_span = PositionSpan::new(start, end.unwrap_or(usize::MAX));
820 source
821 .iter()
822 .filter(|&pos| live_span.contains(pos))
823 .collect()
824 })
825 .clone()
826 }
827
828 pub fn low_matchables(&self) -> BitSet {
829 self.cached_low_matchables
830 .get_or_init(|| {
831 let start = self.start;
832 let end = self.end;
833 let source = self.source_low_matchables();
834 let live_span = PositionSpan::new(start, end.unwrap_or(usize::MAX));
835 source
836 .iter()
837 .filter(|&pos| live_span.contains(pos))
838 .collect()
839 })
840 .clone()
841 }
842}
843
844#[cfg(test)]
845mod test;